From 4d6af280c76ff7d266434f2dba207c4b75dfc076 Mon Sep 17 00:00:00 2001 From: Di Xu Date: Mon, 9 Oct 2017 22:44:33 -0500 Subject: [PATCH] enforce required flags (#502) --- command.go | 22 ++++++++++++++++++++++ command_test.go | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/command.go b/command.go index a23bb6d9..87d0791c 100644 --- a/command.go +++ b/command.go @@ -693,6 +693,9 @@ func (c *Command) execute(a []string) (err error) { c.PreRun(c, argWoFlags) } + if err := c.validateRequiredFlags(); err != nil { + return err + } if c.RunE != nil { if err := c.RunE(c, argWoFlags); err != nil { return err @@ -810,6 +813,25 @@ func (c *Command) ValidateArgs(args []string) error { return c.Args(c, args) } +func (c *Command) validateRequiredFlags() error { + flags := c.Flags() + missingFlagNames := []string{} + flags.VisitAll(func(pflag *flag.Flag) { + requiredAnnotation, found := pflag.Annotations[BashCompOneRequiredFlag] + if !found { + return + } + if (requiredAnnotation[0] == "true") && !pflag.Changed { + missingFlagNames = append(missingFlagNames, pflag.Name) + } + }) + + if len(missingFlagNames) > 0 { + return fmt.Errorf(`Required flag(s) "%s" have/has not been set`, strings.Join(missingFlagNames, `", "`)) + } + return nil +} + // InitDefaultHelpFlag adds default help flag to c. // It is called automatically by executing the c or by calling help and usage. // If c already has help flag, it will do nothing. diff --git a/command_test.go b/command_test.go index 938a7417..0deb87c7 100644 --- a/command_test.go +++ b/command_test.go @@ -438,3 +438,51 @@ func TestTraverseWithBadChildFlag(t *testing.T) { t.Fatalf("wrong command %q expected %q", c.Name(), sub.Name()) } } + +func TestRequiredFlags(t *testing.T) { + c := &Command{Use: "c", Run: func(*Command, []string) {}} + output := new(bytes.Buffer) + c.SetOutput(output) + c.Flags().String("foo1", "", "required foo1") + c.MarkFlagRequired("foo1") + c.Flags().String("foo2", "", "required foo2") + c.MarkFlagRequired("foo2") + c.Flags().String("bar", "", "optional bar") + + expected := fmt.Sprintf("Required flag(s) %q, %q have/has not been set", "foo1", "foo2") + + if err := c.Execute(); err != nil { + if err.Error() != expected { + t.Errorf("expected %v, got %v", expected, err.Error()) + } + } +} + +func TestPersistentRequiredFlags(t *testing.T) { + parent := &Command{Use: "parent", Run: func(*Command, []string) {}} + output := new(bytes.Buffer) + parent.SetOutput(output) + parent.PersistentFlags().String("foo1", "", "required foo1") + parent.MarkPersistentFlagRequired("foo1") + parent.PersistentFlags().String("foo2", "", "required foo2") + parent.MarkPersistentFlagRequired("foo2") + parent.Flags().String("foo3", "", "optional foo3") + + child := &Command{Use: "child", Run: func(*Command, []string) {}} + child.Flags().String("bar1", "", "required bar1") + child.MarkFlagRequired("bar1") + child.Flags().String("bar2", "", "required bar2") + child.MarkFlagRequired("bar2") + child.Flags().String("bar3", "", "optional bar3") + + parent.AddCommand(child) + parent.SetArgs([]string{"child"}) + + expected := fmt.Sprintf("Required flag(s) %q, %q, %q, %q have/has not been set", "bar1", "bar2", "foo1", "foo2") + + if err := parent.Execute(); err != nil { + if err.Error() != expected { + t.Errorf("expected %v, got %v", expected, err.Error()) + } + } +}