From aa02e412acee718e8e556d367cbff73960157b96 Mon Sep 17 00:00:00 2001 From: Niels Claeys Date: Wed, 13 Oct 2021 16:43:52 +0200 Subject: [PATCH] Handle requiredFlags errors in the same way as parsed flag errors. At the moment they are handled differently, which means that you will get inconsistent output when you do not pass a required flag compared to when you pass in an incorrect flag. --- command.go | 2 +- command_test.go | 54 ++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/command.go b/command.go index 2cc18891..59482658 100644 --- a/command.go +++ b/command.go @@ -850,7 +850,7 @@ func (c *Command) execute(a []string) (err error) { } if err := c.validateRequiredFlags(); err != nil { - return err + return c.FlagErrorFunc()(c, err) } if c.RunE != nil { if err := c.RunE(c, argWoFlags); err != nil { diff --git a/command_test.go b/command_test.go index 583cb023..ffef46ae 100644 --- a/command_test.go +++ b/command_test.go @@ -3,6 +3,7 @@ package cobra import ( "bytes" "context" + "errors" "fmt" "io/ioutil" "os" @@ -781,7 +782,6 @@ func TestRequiredFlags(t *testing.T) { c.Flags().String("foo2", "", "") assertNoErr(t, c.MarkFlagRequired("foo2")) c.Flags().String("bar", "", "") - expected := fmt.Sprintf("required flag(s) %q, %q not set", "foo1", "foo2") _, err := executeCommand(c) @@ -792,6 +792,58 @@ func TestRequiredFlags(t *testing.T) { } } +func TestRequiredFlagsWithCustomFlagErrorFunc(t *testing.T) { + usageFunc := func(c *Command) error { + c.Println("usage string") + return nil + } + c := &Command{Use: "c", Run: emptyRun, SilenceUsage: true, usageFunc: usageFunc} + c.Flags().String("foo1", "", "") + assertNoErr(t, c.MarkFlagRequired("foo1")) + silentError := "failed flag parsing" + c.SetFlagErrorFunc(func(c *Command, err error) error { + c.Println(err) + c.Println(c.UsageString()) + return errors.New(silentError) + }) + requiredFlagErrorMessage := fmt.Sprintf("required flag(s) %q not set", "foo1") + + output, err := executeCommand(c) + + got := err.Error() + checkStringContains(t, output, requiredFlagErrorMessage) + checkStringContains(t, output, c.UsageString()) + if got != silentError { + t.Errorf("Expected error %s but got %s", silentError, got) + } +} + +func TestUnexistingFlagsWithCustomFlagErrorFunc(t *testing.T) { + usageFunc := func(c *Command) error { + c.Println("usage string") + return nil + } + c := &Command{Use: "c", Run: emptyRun, SilenceUsage: true, usageFunc: usageFunc} + c.Flags().String("foo1", "", "") + assertNoErr(t, c.MarkFlagRequired("foo1")) + silentError := "failed flag parsing" + c.SetFlagErrorFunc(func(c *Command, err error) error { + c.Println(err) + c.Println(c.UsageString()) + return errors.New(silentError) + }) + unknownFlagErrorMessage := fmt.Sprintf("unknown flag: %s", "--unknownflag") + + output, err := executeCommand(c, "--unknownflag") + + got := err.Error() + checkStringContains(t, output, unknownFlagErrorMessage) + checkStringContains(t, output, c.UsageString()) + if got != silentError { + t.Errorf("Expected error %s but got %s", silentError, got) + } +} + func TestPersistentRequiredFlags(t *testing.T) { parent := &Command{Use: "parent", Run: emptyRun} parent.PersistentFlags().String("foo1", "", "")