mirror of
https://github.com/spf13/cobra
synced 2024-11-24 14:47:12 +00:00
Improving the required flags error by using the pluralize util
This commit is contained in:
parent
02a0d2fbc9
commit
75499b99cb
2 changed files with 29 additions and 3 deletions
13
command.go
13
command.go
|
@ -1004,7 +1004,10 @@ func (c *Command) validateRequiredFlags() error {
|
||||||
})
|
})
|
||||||
|
|
||||||
if len(missingFlagNames) > 0 {
|
if len(missingFlagNames) > 0 {
|
||||||
return fmt.Errorf(`required flag(s) "%s" not set`, strings.Join(missingFlagNames, `", "`))
|
return fmt.Errorf(`required %s "%s" not set`,
|
||||||
|
pluralize("flag", len(missingFlagNames)),
|
||||||
|
strings.Join(missingFlagNames, `", "`),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1662,3 +1665,11 @@ func (c *Command) updateParentsPflags() {
|
||||||
c.parentsPflags.AddFlagSet(parent.PersistentFlags())
|
c.parentsPflags.AddFlagSet(parent.PersistentFlags())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func pluralize(name string, length int) string {
|
||||||
|
if length == 1 {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
return name + "s"
|
||||||
|
}
|
||||||
|
|
|
@ -742,6 +742,21 @@ func TestPersistentFlagsOnChild(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRequiredFlag(t *testing.T) {
|
||||||
|
c := &Command{Use: "c", Run: emptyRun}
|
||||||
|
c.Flags().String("foo1", "", "")
|
||||||
|
c.MarkFlagRequired("foo1")
|
||||||
|
|
||||||
|
expected := fmt.Sprintf("required flag %q not set", "foo1")
|
||||||
|
|
||||||
|
_, err := executeCommand(c)
|
||||||
|
got := err.Error()
|
||||||
|
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("Expected error: %q, got: %q", expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRequiredFlags(t *testing.T) {
|
func TestRequiredFlags(t *testing.T) {
|
||||||
c := &Command{Use: "c", Run: emptyRun}
|
c := &Command{Use: "c", Run: emptyRun}
|
||||||
c.Flags().String("foo1", "", "")
|
c.Flags().String("foo1", "", "")
|
||||||
|
@ -750,7 +765,7 @@ func TestRequiredFlags(t *testing.T) {
|
||||||
c.MarkFlagRequired("foo2")
|
c.MarkFlagRequired("foo2")
|
||||||
c.Flags().String("bar", "", "")
|
c.Flags().String("bar", "", "")
|
||||||
|
|
||||||
expected := fmt.Sprintf("required flag(s) %q, %q not set", "foo1", "foo2")
|
expected := fmt.Sprintf("required flags %q, %q not set", "foo1", "foo2")
|
||||||
|
|
||||||
_, err := executeCommand(c)
|
_, err := executeCommand(c)
|
||||||
got := err.Error()
|
got := err.Error()
|
||||||
|
@ -777,7 +792,7 @@ func TestPersistentRequiredFlags(t *testing.T) {
|
||||||
|
|
||||||
parent.AddCommand(child)
|
parent.AddCommand(child)
|
||||||
|
|
||||||
expected := fmt.Sprintf("required flag(s) %q, %q, %q, %q not set", "bar1", "bar2", "foo1", "foo2")
|
expected := fmt.Sprintf("required flags %q, %q, %q, %q not set", "bar1", "bar2", "foo1", "foo2")
|
||||||
|
|
||||||
_, err := executeCommand(parent, "child")
|
_, err := executeCommand(parent, "child")
|
||||||
if err.Error() != expected {
|
if err.Error() != expected {
|
||||||
|
|
Loading…
Reference in a new issue