Handle requiredFlags errors in the same way as parsed flag errors.

At the moment they are handled differently, which means that you will get inconsistent output when you do not pass a required flag compared to when you pass in an incorrect flag.
This commit is contained in:
Niels Claeys 2021-10-13 16:43:52 +02:00 committed by Niels Claeys
parent c1973d31bf
commit aa02e412ac
2 changed files with 54 additions and 2 deletions

View file

@ -850,7 +850,7 @@ func (c *Command) execute(a []string) (err error) {
}
if err := c.validateRequiredFlags(); err != nil {
return err
return c.FlagErrorFunc()(c, err)
}
if c.RunE != nil {
if err := c.RunE(c, argWoFlags); err != nil {

View file

@ -3,6 +3,7 @@ package cobra
import (
"bytes"
"context"
"errors"
"fmt"
"io/ioutil"
"os"
@ -781,7 +782,6 @@ func TestRequiredFlags(t *testing.T) {
c.Flags().String("foo2", "", "")
assertNoErr(t, c.MarkFlagRequired("foo2"))
c.Flags().String("bar", "", "")
expected := fmt.Sprintf("required flag(s) %q, %q not set", "foo1", "foo2")
_, err := executeCommand(c)
@ -792,6 +792,58 @@ func TestRequiredFlags(t *testing.T) {
}
}
func TestRequiredFlagsWithCustomFlagErrorFunc(t *testing.T) {
usageFunc := func(c *Command) error {
c.Println("usage string")
return nil
}
c := &Command{Use: "c", Run: emptyRun, SilenceUsage: true, usageFunc: usageFunc}
c.Flags().String("foo1", "", "")
assertNoErr(t, c.MarkFlagRequired("foo1"))
silentError := "failed flag parsing"
c.SetFlagErrorFunc(func(c *Command, err error) error {
c.Println(err)
c.Println(c.UsageString())
return errors.New(silentError)
})
requiredFlagErrorMessage := fmt.Sprintf("required flag(s) %q not set", "foo1")
output, err := executeCommand(c)
got := err.Error()
checkStringContains(t, output, requiredFlagErrorMessage)
checkStringContains(t, output, c.UsageString())
if got != silentError {
t.Errorf("Expected error %s but got %s", silentError, got)
}
}
func TestUnexistingFlagsWithCustomFlagErrorFunc(t *testing.T) {
usageFunc := func(c *Command) error {
c.Println("usage string")
return nil
}
c := &Command{Use: "c", Run: emptyRun, SilenceUsage: true, usageFunc: usageFunc}
c.Flags().String("foo1", "", "")
assertNoErr(t, c.MarkFlagRequired("foo1"))
silentError := "failed flag parsing"
c.SetFlagErrorFunc(func(c *Command, err error) error {
c.Println(err)
c.Println(c.UsageString())
return errors.New(silentError)
})
unknownFlagErrorMessage := fmt.Sprintf("unknown flag: %s", "--unknownflag")
output, err := executeCommand(c, "--unknownflag")
got := err.Error()
checkStringContains(t, output, unknownFlagErrorMessage)
checkStringContains(t, output, c.UsageString())
if got != silentError {
t.Errorf("Expected error %s but got %s", silentError, got)
}
}
func TestPersistentRequiredFlags(t *testing.T) {
parent := &Command{Use: "parent", Run: emptyRun}
parent.PersistentFlags().String("foo1", "", "")