mirror of
https://github.com/spf13/cobra
synced 2024-11-16 18:57:08 +00:00
Handle requiredFlags errors in the same way as parsed flag errors.
At the moment they are handled differently, which means that you will get inconsistent output when you do not pass a required flag compared to when you pass in an incorrect flag.
This commit is contained in:
parent
c1973d31bf
commit
aa02e412ac
2 changed files with 54 additions and 2 deletions
|
@ -850,7 +850,7 @@ func (c *Command) execute(a []string) (err error) {
|
|||
}
|
||||
|
||||
if err := c.validateRequiredFlags(); err != nil {
|
||||
return err
|
||||
return c.FlagErrorFunc()(c, err)
|
||||
}
|
||||
if c.RunE != nil {
|
||||
if err := c.RunE(c, argWoFlags); err != nil {
|
||||
|
|
|
@ -3,6 +3,7 @@ package cobra
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
@ -781,7 +782,6 @@ func TestRequiredFlags(t *testing.T) {
|
|||
c.Flags().String("foo2", "", "")
|
||||
assertNoErr(t, c.MarkFlagRequired("foo2"))
|
||||
c.Flags().String("bar", "", "")
|
||||
|
||||
expected := fmt.Sprintf("required flag(s) %q, %q not set", "foo1", "foo2")
|
||||
|
||||
_, err := executeCommand(c)
|
||||
|
@ -792,6 +792,58 @@ func TestRequiredFlags(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRequiredFlagsWithCustomFlagErrorFunc(t *testing.T) {
|
||||
usageFunc := func(c *Command) error {
|
||||
c.Println("usage string")
|
||||
return nil
|
||||
}
|
||||
c := &Command{Use: "c", Run: emptyRun, SilenceUsage: true, usageFunc: usageFunc}
|
||||
c.Flags().String("foo1", "", "")
|
||||
assertNoErr(t, c.MarkFlagRequired("foo1"))
|
||||
silentError := "failed flag parsing"
|
||||
c.SetFlagErrorFunc(func(c *Command, err error) error {
|
||||
c.Println(err)
|
||||
c.Println(c.UsageString())
|
||||
return errors.New(silentError)
|
||||
})
|
||||
requiredFlagErrorMessage := fmt.Sprintf("required flag(s) %q not set", "foo1")
|
||||
|
||||
output, err := executeCommand(c)
|
||||
|
||||
got := err.Error()
|
||||
checkStringContains(t, output, requiredFlagErrorMessage)
|
||||
checkStringContains(t, output, c.UsageString())
|
||||
if got != silentError {
|
||||
t.Errorf("Expected error %s but got %s", silentError, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnexistingFlagsWithCustomFlagErrorFunc(t *testing.T) {
|
||||
usageFunc := func(c *Command) error {
|
||||
c.Println("usage string")
|
||||
return nil
|
||||
}
|
||||
c := &Command{Use: "c", Run: emptyRun, SilenceUsage: true, usageFunc: usageFunc}
|
||||
c.Flags().String("foo1", "", "")
|
||||
assertNoErr(t, c.MarkFlagRequired("foo1"))
|
||||
silentError := "failed flag parsing"
|
||||
c.SetFlagErrorFunc(func(c *Command, err error) error {
|
||||
c.Println(err)
|
||||
c.Println(c.UsageString())
|
||||
return errors.New(silentError)
|
||||
})
|
||||
unknownFlagErrorMessage := fmt.Sprintf("unknown flag: %s", "--unknownflag")
|
||||
|
||||
output, err := executeCommand(c, "--unknownflag")
|
||||
|
||||
got := err.Error()
|
||||
checkStringContains(t, output, unknownFlagErrorMessage)
|
||||
checkStringContains(t, output, c.UsageString())
|
||||
if got != silentError {
|
||||
t.Errorf("Expected error %s but got %s", silentError, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPersistentRequiredFlags(t *testing.T) {
|
||||
parent := &Command{Use: "parent", Run: emptyRun}
|
||||
parent.PersistentFlags().String("foo1", "", "")
|
||||
|
|
Loading…
Reference in a new issue