mirror of
https://github.com/spf13/cobra
synced 2024-11-24 14:47:12 +00:00
enforce required flags (#502)
This commit is contained in:
parent
50204810fd
commit
4d6af280c7
2 changed files with 70 additions and 0 deletions
22
command.go
22
command.go
|
@ -693,6 +693,9 @@ func (c *Command) execute(a []string) (err error) {
|
||||||
c.PreRun(c, argWoFlags)
|
c.PreRun(c, argWoFlags)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := c.validateRequiredFlags(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if c.RunE != nil {
|
if c.RunE != nil {
|
||||||
if err := c.RunE(c, argWoFlags); err != nil {
|
if err := c.RunE(c, argWoFlags); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -810,6 +813,25 @@ func (c *Command) ValidateArgs(args []string) error {
|
||||||
return c.Args(c, args)
|
return c.Args(c, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Command) validateRequiredFlags() error {
|
||||||
|
flags := c.Flags()
|
||||||
|
missingFlagNames := []string{}
|
||||||
|
flags.VisitAll(func(pflag *flag.Flag) {
|
||||||
|
requiredAnnotation, found := pflag.Annotations[BashCompOneRequiredFlag]
|
||||||
|
if !found {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if (requiredAnnotation[0] == "true") && !pflag.Changed {
|
||||||
|
missingFlagNames = append(missingFlagNames, pflag.Name)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if len(missingFlagNames) > 0 {
|
||||||
|
return fmt.Errorf(`Required flag(s) "%s" have/has not been set`, strings.Join(missingFlagNames, `", "`))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// InitDefaultHelpFlag adds default help flag to c.
|
// InitDefaultHelpFlag adds default help flag to c.
|
||||||
// It is called automatically by executing the c or by calling help and usage.
|
// It is called automatically by executing the c or by calling help and usage.
|
||||||
// If c already has help flag, it will do nothing.
|
// If c already has help flag, it will do nothing.
|
||||||
|
|
|
@ -438,3 +438,51 @@ func TestTraverseWithBadChildFlag(t *testing.T) {
|
||||||
t.Fatalf("wrong command %q expected %q", c.Name(), sub.Name())
|
t.Fatalf("wrong command %q expected %q", c.Name(), sub.Name())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRequiredFlags(t *testing.T) {
|
||||||
|
c := &Command{Use: "c", Run: func(*Command, []string) {}}
|
||||||
|
output := new(bytes.Buffer)
|
||||||
|
c.SetOutput(output)
|
||||||
|
c.Flags().String("foo1", "", "required foo1")
|
||||||
|
c.MarkFlagRequired("foo1")
|
||||||
|
c.Flags().String("foo2", "", "required foo2")
|
||||||
|
c.MarkFlagRequired("foo2")
|
||||||
|
c.Flags().String("bar", "", "optional bar")
|
||||||
|
|
||||||
|
expected := fmt.Sprintf("Required flag(s) %q, %q have/has not been set", "foo1", "foo2")
|
||||||
|
|
||||||
|
if err := c.Execute(); err != nil {
|
||||||
|
if err.Error() != expected {
|
||||||
|
t.Errorf("expected %v, got %v", expected, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPersistentRequiredFlags(t *testing.T) {
|
||||||
|
parent := &Command{Use: "parent", Run: func(*Command, []string) {}}
|
||||||
|
output := new(bytes.Buffer)
|
||||||
|
parent.SetOutput(output)
|
||||||
|
parent.PersistentFlags().String("foo1", "", "required foo1")
|
||||||
|
parent.MarkPersistentFlagRequired("foo1")
|
||||||
|
parent.PersistentFlags().String("foo2", "", "required foo2")
|
||||||
|
parent.MarkPersistentFlagRequired("foo2")
|
||||||
|
parent.Flags().String("foo3", "", "optional foo3")
|
||||||
|
|
||||||
|
child := &Command{Use: "child", Run: func(*Command, []string) {}}
|
||||||
|
child.Flags().String("bar1", "", "required bar1")
|
||||||
|
child.MarkFlagRequired("bar1")
|
||||||
|
child.Flags().String("bar2", "", "required bar2")
|
||||||
|
child.MarkFlagRequired("bar2")
|
||||||
|
child.Flags().String("bar3", "", "optional bar3")
|
||||||
|
|
||||||
|
parent.AddCommand(child)
|
||||||
|
parent.SetArgs([]string{"child"})
|
||||||
|
|
||||||
|
expected := fmt.Sprintf("Required flag(s) %q, %q, %q, %q have/has not been set", "bar1", "bar2", "foo1", "foo2")
|
||||||
|
|
||||||
|
if err := parent.Execute(); err != nil {
|
||||||
|
if err.Error() != expected {
|
||||||
|
t.Errorf("expected %v, got %v", expected, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue