diff --git a/command.go b/command.go index 2cc18891..59482658 100644 --- a/command.go +++ b/command.go @@ -850,7 +850,7 @@ func (c *Command) execute(a []string) (err error) { } if err := c.validateRequiredFlags(); err != nil { - return err + return c.FlagErrorFunc()(c, err) } if c.RunE != nil { if err := c.RunE(c, argWoFlags); err != nil { diff --git a/command_test.go b/command_test.go index 583cb023..ffef46ae 100644 --- a/command_test.go +++ b/command_test.go @@ -3,6 +3,7 @@ package cobra import ( "bytes" "context" + "errors" "fmt" "io/ioutil" "os" @@ -781,7 +782,6 @@ func TestRequiredFlags(t *testing.T) { c.Flags().String("foo2", "", "") assertNoErr(t, c.MarkFlagRequired("foo2")) c.Flags().String("bar", "", "") - expected := fmt.Sprintf("required flag(s) %q, %q not set", "foo1", "foo2") _, err := executeCommand(c) @@ -792,6 +792,58 @@ func TestRequiredFlags(t *testing.T) { } } +func TestRequiredFlagsWithCustomFlagErrorFunc(t *testing.T) { + usageFunc := func(c *Command) error { + c.Println("usage string") + return nil + } + c := &Command{Use: "c", Run: emptyRun, SilenceUsage: true, usageFunc: usageFunc} + c.Flags().String("foo1", "", "") + assertNoErr(t, c.MarkFlagRequired("foo1")) + silentError := "failed flag parsing" + c.SetFlagErrorFunc(func(c *Command, err error) error { + c.Println(err) + c.Println(c.UsageString()) + return errors.New(silentError) + }) + requiredFlagErrorMessage := fmt.Sprintf("required flag(s) %q not set", "foo1") + + output, err := executeCommand(c) + + got := err.Error() + checkStringContains(t, output, requiredFlagErrorMessage) + checkStringContains(t, output, c.UsageString()) + if got != silentError { + t.Errorf("Expected error %s but got %s", silentError, got) + } +} + +func TestUnexistingFlagsWithCustomFlagErrorFunc(t *testing.T) { + usageFunc := func(c *Command) error { + c.Println("usage string") + return nil + } + c := &Command{Use: "c", Run: emptyRun, SilenceUsage: true, usageFunc: usageFunc} + c.Flags().String("foo1", "", "") + assertNoErr(t, c.MarkFlagRequired("foo1")) + silentError := "failed flag parsing" + c.SetFlagErrorFunc(func(c *Command, err error) error { + c.Println(err) + c.Println(c.UsageString()) + return errors.New(silentError) + }) + unknownFlagErrorMessage := fmt.Sprintf("unknown flag: %s", "--unknownflag") + + output, err := executeCommand(c, "--unknownflag") + + got := err.Error() + checkStringContains(t, output, unknownFlagErrorMessage) + checkStringContains(t, output, c.UsageString()) + if got != silentError { + t.Errorf("Expected error %s but got %s", silentError, got) + } +} + func TestPersistentRequiredFlags(t *testing.T) { parent := &Command{Use: "parent", Run: emptyRun} parent.PersistentFlags().String("foo1", "", "")