diff --git a/README.md b/README.md index 02075833..3dbb1871 100644 --- a/README.md +++ b/README.md @@ -467,36 +467,34 @@ A flag can also be assigned locally which will only apply to that specific comma RootCmd.Flags().StringVarP(&Source, "source", "s", "", "Source directory to read from") ``` -### Specify if you command takes arguments +## Positional and Custom Arguments -There are multiple options for how a command can handle unknown arguments which can be set in `TakesArgs` -- `Legacy` -- `None` -- `Arbitrary` -- `ValidOnly` +Validation of positional arguments can be specified using the `Args` field. -`Legacy` (or default) the rules are as follows: -- root commands with no subcommands can take arbitrary arguments -- root commands with subcommands will do subcommand validity checking -- subcommands will always accept arbitrary arguments and do no subsubcommand validity checking +The follow validators are built in: -`None` the command will be rejected if there are any left over arguments after parsing flags. +- `NoArgs` - the command will report an error if there are any positional args. +- `ArbitraryArgs` - the command will accept any args. +- `OnlyValidArgs` - the command will report an error if there are any positional args that are not in the ValidArgs list. +- `MinimumNArgs(int)` - the command will report an error if there are not at least N positional args. +- `MaximumNArgs(int)` - the command will report an error if there are more than N positional args. +- `ExactArgs(int)` - the command will report an error if there are not exactly N positional args. +- `RangeArgs(min, max)` - the command will report an error if the number of args is not between the minimum and maximum number of expected args. -`Arbitrary` any additional values left after parsing flags will be passed in to your `Run` function. - -`ValidOnly` you must define all valid (non-subcommand) arguments to your command. These are defined in a slice name ValidArgs. For example a command which only takes the argument "one" or "two" would be defined as: +A custom validator can be provided like this: ```go -var HugoCmd = &cobra.Command{ - Use: "hugo", - Short: "Hugo is a very fast static site generator", - ValidArgs: []string{"one", "two", "three", "four"} - TakesArgs: cobra.ValidOnly - Run: func(cmd *cobra.Command, args []string) { - // args will only have the values one, two, three, four - // or the cmd.Execute() will fail. - }, - } + +Args: func validColorArgs(cmd *cobra.Command, args []string) error { + if err := cli.RequiresMinArgs(1)(cmd, args); err != nil { + return err + } + if myapp.IsValidColor(args[0]) { + return nil + } + return fmt.Errorf("Invalid color specified: %s", args[0]) +} + ``` ### Bind Flags with Config @@ -517,6 +515,7 @@ when the `--author` flag is not provided by user. More in [viper documentation](https://github.com/spf13/viper#working-with-flags). + ## Example In the example below, we have defined three commands. Two are at the top level diff --git a/args.go b/args.go new file mode 100644 index 00000000..94a6ca27 --- /dev/null +++ b/args.go @@ -0,0 +1,98 @@ +package cobra + +import ( + "fmt" +) + +type PositionalArgs func(cmd *Command, args []string) error + +// Legacy arg validation has the following behaviour: +// - root commands with no subcommands can take arbitrary arguments +// - root commands with subcommands will do subcommand validity checking +// - subcommands will always accept arbitrary arguments +func legacyArgs(cmd *Command, args []string) error { + // no subcommand, always take args + if !cmd.HasSubCommands() { + return nil + } + + // root command with subcommands, do subcommand checking + if !cmd.HasParent() && len(args) > 0 { + return fmt.Errorf("unknown command %q for %q%s", args[0], cmd.CommandPath(), cmd.findSuggestions(args[0])) + } + return nil +} + +// NoArgs returns an error if any args are included +func NoArgs(cmd *Command, args []string) error { + if len(args) > 0 { + return fmt.Errorf("unknown command %q for %q", args[0], cmd.CommandPath()) + } + return nil +} + +// OnlyValidArgs returns an error if any args are not in the list of ValidArgs +func OnlyValidArgs(cmd *Command, args []string) error { + if len(cmd.ValidArgs) > 0 { + for _, v := range args { + if !stringInSlice(v, cmd.ValidArgs) { + return fmt.Errorf("invalid argument %q for %q%s", v, cmd.CommandPath(), cmd.findSuggestions(args[0])) + } + } + } + return nil +} + +func stringInSlice(a string, list []string) bool { + for _, b := range list { + if b == a { + return true + } + } + return false +} + +// ArbitraryArgs never returns an error +func ArbitraryArgs(cmd *Command, args []string) error { + return nil +} + +// MinimumNArgs returns an error if there is not at least N args +func MinimumNArgs(n int) PositionalArgs { + return func(cmd *Command, args []string) error { + if len(args) < n { + return fmt.Errorf("requires at least %d arg(s), only received %d", n, len(args)) + } + return nil + } +} + +// MaximumNArgs returns an error if there are more than N args +func MaximumNArgs(n int) PositionalArgs { + return func(cmd *Command, args []string) error { + if len(args) > n { + return fmt.Errorf("accepts at most %d arg(s), received %d", n, len(args)) + } + return nil + } +} + +// ExactArgs returns an error if there are not exactly n args +func ExactArgs(n int) PositionalArgs { + return func(cmd *Command, args []string) error { + if len(args) != n { + return fmt.Errorf("accepts %d arg(s), received %d", n, len(args)) + } + return nil + } +} + +// RangeArgs returns an error if the number of args is not within the expected range +func RangeArgs(min int, max int) PositionalArgs { + return func(cmd *Command, args []string) error { + if len(args) < min || len(args) > max { + return fmt.Errorf("accepts between %d and %d arg(s), received %d", min, max, len(args)) + } + return nil + } +} diff --git a/cobra_test.go b/cobra_test.go index 89dfb3c5..1706eae2 100644 --- a/cobra_test.go +++ b/cobra_test.go @@ -36,6 +36,7 @@ var cmdHidden = &Command{ var cmdPrint = &Command{ Use: "print [string to print]", + Args: MinimumNArgs(1), Short: "Print anything to the screen", Long: `an absolutely utterly useless command for testing.`, Run: func(cmd *Command, args []string) { @@ -75,7 +76,7 @@ var cmdDeprecated = &Command{ Deprecated: "Please use echo instead", Run: func(cmd *Command, args []string) { }, - TakesArgs: None, + Args: NoArgs, } var cmdTimes = &Command{ @@ -89,7 +90,7 @@ var cmdTimes = &Command{ Run: func(cmd *Command, args []string) { tt = args }, - TakesArgs: ValidOnly, + Args: OnlyValidArgs, ValidArgs: []string{"one", "two", "three", "four"}, } @@ -103,10 +104,9 @@ var cmdRootNoRun = &Command{ } var cmdRootSameName = &Command{ - Use: "print", - Short: "Root with the same name as a subcommand", - Long: "The root description for help", - TakesArgs: None, + Use: "print", + Short: "Root with the same name as a subcommand", + Long: "The root description for help", } var cmdRootTakesArgs = &Command{ @@ -116,7 +116,7 @@ var cmdRootTakesArgs = &Command{ Run: func(cmd *Command, args []string) { tr = args }, - TakesArgs: Arbitrary, + Args: ArbitraryArgs, } var cmdRootWithRun = &Command{ @@ -477,6 +477,10 @@ func TestRootTakesNoArgs(t *testing.T) { c.AddCommand(cmdPrint, cmdEcho) result := simpleTester(c, "illegal") + if result.Error == nil { + t.Fatal("Expected an error") + } + expectedError := `unknown command "illegal" for "print"` if !strings.Contains(result.Error.Error(), expectedError) { t.Errorf("exptected %v, got %v", expectedError, result.Error.Error()) @@ -493,7 +497,11 @@ func TestRootTakesArgs(t *testing.T) { } func TestSubCmdTakesNoArgs(t *testing.T) { - result := fullSetupTest("deprecated illegal") + result := fullSetupTest("deprecated", "illegal") + + if result.Error == nil { + t.Fatal("Expected an error") + } expectedError := `unknown command "illegal" for "cobra-test deprecated"` if !strings.Contains(result.Error.Error(), expectedError) { @@ -502,14 +510,18 @@ func TestSubCmdTakesNoArgs(t *testing.T) { } func TestSubCmdTakesArgs(t *testing.T) { - noRRSetupTest("echo times one two") + noRRSetupTest("echo", "times", "one", "two") if strings.Join(tt, " ") != "one two" { t.Error("Command didn't parse correctly") } } func TestCmdOnlyValidArgs(t *testing.T) { - result := noRRSetupTest("echo times one two five") + result := noRRSetupTest("echo", "times", "one", "two", "five") + + if result.Error == nil { + t.Fatal("Expected an error") + } expectedError := `invalid argument "five"` if !strings.Contains(result.Error.Error(), expectedError) { diff --git a/command.go b/command.go index 131f01c4..4f65d770 100644 --- a/command.go +++ b/command.go @@ -27,15 +27,6 @@ import ( flag "github.com/spf13/pflag" ) -type Args int - -const ( - Legacy Args = iota - Arbitrary - ValidOnly - None -) - // Command is just that, a command for your application. // E.g. 'go run ...' - 'run' is the command. Cobra requires // you to define the usage and description as part of your command @@ -68,8 +59,8 @@ type Command struct { // but accepted if entered manually. ArgAliases []string - // Does this command take arbitrary arguments - TakesArgs Args + // Expected arguments + Args PositionalArgs // BashCompletionFunction is custom functions used by the bash autocompletion generator. BashCompletionFunction string @@ -483,15 +474,6 @@ func argsMinusFirstX(args []string, x string) []string { return args } -func stringInSlice(a string, list []string) bool { - for _, b := range list { - if b == a { - return true - } - } - return false -} - // Find the target command given the args and command tree // Meant to be run on the highest node. Only searches down. func (c *Command) Find(args []string) (*Command, []string, error) { @@ -533,39 +515,13 @@ func (c *Command) Find(args []string) (*Command, []string, error) { } commandFound, a := innerfind(c, args) - argsWOflags := stripFlags(a, commandFound) - - // "Legacy" has some 'odd' characteristics. - // - root commands with no subcommands can take arbitrary arguments - // - root commands with subcommands will do subcommand validity checking - // - subcommands will always accept arbitrary arguments - if commandFound.TakesArgs == Legacy { - // no subcommand, always take args - if !commandFound.HasSubCommands() { - return commandFound, a, nil - } - // root command with subcommands, do subcommand checking - if commandFound == c && len(argsWOflags) > 0 { - return commandFound, a, fmt.Errorf("unknown command %q for %q%s", argsWOflags[0], commandFound.CommandPath(), c.findSuggestions(argsWOflags)) - } - return commandFound, a, nil - } - - if commandFound.TakesArgs == None && len(argsWOflags) > 0 { - return commandFound, a, fmt.Errorf("unknown command %q for %q", argsWOflags[0], commandFound.CommandPath()) - } - - if commandFound.TakesArgs == ValidOnly && len(commandFound.ValidArgs) > 0 { - for _, v := range argsWOflags { - if !stringInSlice(v, commandFound.ValidArgs) { - return commandFound, a, fmt.Errorf("invalid argument %q for %q%s", v, commandFound.CommandPath(), c.findSuggestions(argsWOflags)) - } - } + if commandFound.Args == nil { + return commandFound, a, legacyArgs(commandFound, stripFlags(a, commandFound)) } return commandFound, a, nil } -func (c *Command) findSuggestions(argsWOflags []string) string { +func (c *Command) findSuggestions(arg string) string { if c.DisableSuggestions { return "" } @@ -573,7 +529,7 @@ func (c *Command) findSuggestions(argsWOflags []string) string { c.SuggestionsMinimumDistance = 2 } suggestionsString := "" - if suggestions := c.SuggestionsFor(argsWOflags[0]); len(suggestions) > 0 { + if suggestions := c.SuggestionsFor(arg); len(suggestions) > 0 { suggestionsString += "\n\nDid you mean this?\n" for _, s := range suggestions { suggestionsString += fmt.Sprintf("\t%v\n", s) @@ -666,6 +622,10 @@ func (c *Command) execute(a []string) (err error) { argWoFlags = a } + if err := c.ValidateArgs(argWoFlags); err != nil { + return err + } + for p := c; p != nil; p = p.Parent() { if p.PersistentPreRunE != nil { if err := p.PersistentPreRunE(c, argWoFlags); err != nil { @@ -789,6 +749,13 @@ func (c *Command) ExecuteC() (cmd *Command, err error) { return cmd, err } +func (c *Command) ValidateArgs(args []string) error { + if c.Args == nil { + return nil + } + return c.Args(c, args) +} + // InitDefaultHelpFlag adds default help flag to c. // It is called automatically by executing the c or by calling help and usage. // If c already has help flag, it will do nothing.