Refactor TakesArgs to use an interface for arg validation.

Fix some typos in README and comments.
Move arg validation to after flag validation so that the help flag is run first.
Pass the same args to ValidateArgs as the Run methods receive.
Update README.

Signed-off-by: Daniel Nephin <dnephin@gmail.com>
This commit is contained in:
Daniel Nephin 2016-05-07 13:24:05 -04:00 committed by Albert Nigmatzianov
parent d89c499964
commit f20b4e9c32
4 changed files with 160 additions and 84 deletions

View file

@ -467,36 +467,34 @@ A flag can also be assigned locally which will only apply to that specific comma
RootCmd.Flags().StringVarP(&Source, "source", "s", "", "Source directory to read from") RootCmd.Flags().StringVarP(&Source, "source", "s", "", "Source directory to read from")
``` ```
### Specify if you command takes arguments ## Positional and Custom Arguments
There are multiple options for how a command can handle unknown arguments which can be set in `TakesArgs` Validation of positional arguments can be specified using the `Args` field.
- `Legacy`
- `None`
- `Arbitrary`
- `ValidOnly`
`Legacy` (or default) the rules are as follows: The follow validators are built in:
- root commands with no subcommands can take arbitrary arguments
- root commands with subcommands will do subcommand validity checking
- subcommands will always accept arbitrary arguments and do no subsubcommand validity checking
`None` the command will be rejected if there are any left over arguments after parsing flags. - `NoArgs` - the command will report an error if there are any positional args.
- `ArbitraryArgs` - the command will accept any args.
- `OnlyValidArgs` - the command will report an error if there are any positional args that are not in the ValidArgs list.
- `MinimumNArgs(int)` - the command will report an error if there are not at least N positional args.
- `MaximumNArgs(int)` - the command will report an error if there are more than N positional args.
- `ExactArgs(int)` - the command will report an error if there are not exactly N positional args.
- `RangeArgs(min, max)` - the command will report an error if the number of args is not between the minimum and maximum number of expected args.
`Arbitrary` any additional values left after parsing flags will be passed in to your `Run` function. A custom validator can be provided like this:
`ValidOnly` you must define all valid (non-subcommand) arguments to your command. These are defined in a slice name ValidArgs. For example a command which only takes the argument "one" or "two" would be defined as:
```go ```go
var HugoCmd = &cobra.Command{
Use: "hugo", Args: func validColorArgs(cmd *cobra.Command, args []string) error {
Short: "Hugo is a very fast static site generator", if err := cli.RequiresMinArgs(1)(cmd, args); err != nil {
ValidArgs: []string{"one", "two", "three", "four"} return err
TakesArgs: cobra.ValidOnly }
Run: func(cmd *cobra.Command, args []string) { if myapp.IsValidColor(args[0]) {
// args will only have the values one, two, three, four return nil
// or the cmd.Execute() will fail. }
}, return fmt.Errorf("Invalid color specified: %s", args[0])
} }
``` ```
### Bind Flags with Config ### Bind Flags with Config
@ -517,6 +515,7 @@ when the `--author` flag is not provided by user.
More in [viper documentation](https://github.com/spf13/viper#working-with-flags). More in [viper documentation](https://github.com/spf13/viper#working-with-flags).
## Example ## Example
In the example below, we have defined three commands. Two are at the top level In the example below, we have defined three commands. Two are at the top level

98
args.go Normal file
View file

@ -0,0 +1,98 @@
package cobra
import (
"fmt"
)
type PositionalArgs func(cmd *Command, args []string) error
// Legacy arg validation has the following behaviour:
// - root commands with no subcommands can take arbitrary arguments
// - root commands with subcommands will do subcommand validity checking
// - subcommands will always accept arbitrary arguments
func legacyArgs(cmd *Command, args []string) error {
// no subcommand, always take args
if !cmd.HasSubCommands() {
return nil
}
// root command with subcommands, do subcommand checking
if !cmd.HasParent() && len(args) > 0 {
return fmt.Errorf("unknown command %q for %q%s", args[0], cmd.CommandPath(), cmd.findSuggestions(args[0]))
}
return nil
}
// NoArgs returns an error if any args are included
func NoArgs(cmd *Command, args []string) error {
if len(args) > 0 {
return fmt.Errorf("unknown command %q for %q", args[0], cmd.CommandPath())
}
return nil
}
// OnlyValidArgs returns an error if any args are not in the list of ValidArgs
func OnlyValidArgs(cmd *Command, args []string) error {
if len(cmd.ValidArgs) > 0 {
for _, v := range args {
if !stringInSlice(v, cmd.ValidArgs) {
return fmt.Errorf("invalid argument %q for %q%s", v, cmd.CommandPath(), cmd.findSuggestions(args[0]))
}
}
}
return nil
}
func stringInSlice(a string, list []string) bool {
for _, b := range list {
if b == a {
return true
}
}
return false
}
// ArbitraryArgs never returns an error
func ArbitraryArgs(cmd *Command, args []string) error {
return nil
}
// MinimumNArgs returns an error if there is not at least N args
func MinimumNArgs(n int) PositionalArgs {
return func(cmd *Command, args []string) error {
if len(args) < n {
return fmt.Errorf("requires at least %d arg(s), only received %d", n, len(args))
}
return nil
}
}
// MaximumNArgs returns an error if there are more than N args
func MaximumNArgs(n int) PositionalArgs {
return func(cmd *Command, args []string) error {
if len(args) > n {
return fmt.Errorf("accepts at most %d arg(s), received %d", n, len(args))
}
return nil
}
}
// ExactArgs returns an error if there are not exactly n args
func ExactArgs(n int) PositionalArgs {
return func(cmd *Command, args []string) error {
if len(args) != n {
return fmt.Errorf("accepts %d arg(s), received %d", n, len(args))
}
return nil
}
}
// RangeArgs returns an error if the number of args is not within the expected range
func RangeArgs(min int, max int) PositionalArgs {
return func(cmd *Command, args []string) error {
if len(args) < min || len(args) > max {
return fmt.Errorf("accepts between %d and %d arg(s), received %d", min, max, len(args))
}
return nil
}
}

View file

@ -36,6 +36,7 @@ var cmdHidden = &Command{
var cmdPrint = &Command{ var cmdPrint = &Command{
Use: "print [string to print]", Use: "print [string to print]",
Args: MinimumNArgs(1),
Short: "Print anything to the screen", Short: "Print anything to the screen",
Long: `an absolutely utterly useless command for testing.`, Long: `an absolutely utterly useless command for testing.`,
Run: func(cmd *Command, args []string) { Run: func(cmd *Command, args []string) {
@ -75,7 +76,7 @@ var cmdDeprecated = &Command{
Deprecated: "Please use echo instead", Deprecated: "Please use echo instead",
Run: func(cmd *Command, args []string) { Run: func(cmd *Command, args []string) {
}, },
TakesArgs: None, Args: NoArgs,
} }
var cmdTimes = &Command{ var cmdTimes = &Command{
@ -89,7 +90,7 @@ var cmdTimes = &Command{
Run: func(cmd *Command, args []string) { Run: func(cmd *Command, args []string) {
tt = args tt = args
}, },
TakesArgs: ValidOnly, Args: OnlyValidArgs,
ValidArgs: []string{"one", "two", "three", "four"}, ValidArgs: []string{"one", "two", "three", "four"},
} }
@ -103,10 +104,9 @@ var cmdRootNoRun = &Command{
} }
var cmdRootSameName = &Command{ var cmdRootSameName = &Command{
Use: "print", Use: "print",
Short: "Root with the same name as a subcommand", Short: "Root with the same name as a subcommand",
Long: "The root description for help", Long: "The root description for help",
TakesArgs: None,
} }
var cmdRootTakesArgs = &Command{ var cmdRootTakesArgs = &Command{
@ -116,7 +116,7 @@ var cmdRootTakesArgs = &Command{
Run: func(cmd *Command, args []string) { Run: func(cmd *Command, args []string) {
tr = args tr = args
}, },
TakesArgs: Arbitrary, Args: ArbitraryArgs,
} }
var cmdRootWithRun = &Command{ var cmdRootWithRun = &Command{
@ -477,6 +477,10 @@ func TestRootTakesNoArgs(t *testing.T) {
c.AddCommand(cmdPrint, cmdEcho) c.AddCommand(cmdPrint, cmdEcho)
result := simpleTester(c, "illegal") result := simpleTester(c, "illegal")
if result.Error == nil {
t.Fatal("Expected an error")
}
expectedError := `unknown command "illegal" for "print"` expectedError := `unknown command "illegal" for "print"`
if !strings.Contains(result.Error.Error(), expectedError) { if !strings.Contains(result.Error.Error(), expectedError) {
t.Errorf("exptected %v, got %v", expectedError, result.Error.Error()) t.Errorf("exptected %v, got %v", expectedError, result.Error.Error())
@ -493,7 +497,11 @@ func TestRootTakesArgs(t *testing.T) {
} }
func TestSubCmdTakesNoArgs(t *testing.T) { func TestSubCmdTakesNoArgs(t *testing.T) {
result := fullSetupTest("deprecated illegal") result := fullSetupTest("deprecated", "illegal")
if result.Error == nil {
t.Fatal("Expected an error")
}
expectedError := `unknown command "illegal" for "cobra-test deprecated"` expectedError := `unknown command "illegal" for "cobra-test deprecated"`
if !strings.Contains(result.Error.Error(), expectedError) { if !strings.Contains(result.Error.Error(), expectedError) {
@ -502,14 +510,18 @@ func TestSubCmdTakesNoArgs(t *testing.T) {
} }
func TestSubCmdTakesArgs(t *testing.T) { func TestSubCmdTakesArgs(t *testing.T) {
noRRSetupTest("echo times one two") noRRSetupTest("echo", "times", "one", "two")
if strings.Join(tt, " ") != "one two" { if strings.Join(tt, " ") != "one two" {
t.Error("Command didn't parse correctly") t.Error("Command didn't parse correctly")
} }
} }
func TestCmdOnlyValidArgs(t *testing.T) { func TestCmdOnlyValidArgs(t *testing.T) {
result := noRRSetupTest("echo times one two five") result := noRRSetupTest("echo", "times", "one", "two", "five")
if result.Error == nil {
t.Fatal("Expected an error")
}
expectedError := `invalid argument "five"` expectedError := `invalid argument "five"`
if !strings.Contains(result.Error.Error(), expectedError) { if !strings.Contains(result.Error.Error(), expectedError) {

View file

@ -27,15 +27,6 @@ import (
flag "github.com/spf13/pflag" flag "github.com/spf13/pflag"
) )
type Args int
const (
Legacy Args = iota
Arbitrary
ValidOnly
None
)
// Command is just that, a command for your application. // Command is just that, a command for your application.
// E.g. 'go run ...' - 'run' is the command. Cobra requires // E.g. 'go run ...' - 'run' is the command. Cobra requires
// you to define the usage and description as part of your command // you to define the usage and description as part of your command
@ -68,8 +59,8 @@ type Command struct {
// but accepted if entered manually. // but accepted if entered manually.
ArgAliases []string ArgAliases []string
// Does this command take arbitrary arguments // Expected arguments
TakesArgs Args Args PositionalArgs
// BashCompletionFunction is custom functions used by the bash autocompletion generator. // BashCompletionFunction is custom functions used by the bash autocompletion generator.
BashCompletionFunction string BashCompletionFunction string
@ -483,15 +474,6 @@ func argsMinusFirstX(args []string, x string) []string {
return args return args
} }
func stringInSlice(a string, list []string) bool {
for _, b := range list {
if b == a {
return true
}
}
return false
}
// Find the target command given the args and command tree // Find the target command given the args and command tree
// Meant to be run on the highest node. Only searches down. // Meant to be run on the highest node. Only searches down.
func (c *Command) Find(args []string) (*Command, []string, error) { func (c *Command) Find(args []string) (*Command, []string, error) {
@ -533,39 +515,13 @@ func (c *Command) Find(args []string) (*Command, []string, error) {
} }
commandFound, a := innerfind(c, args) commandFound, a := innerfind(c, args)
argsWOflags := stripFlags(a, commandFound) if commandFound.Args == nil {
return commandFound, a, legacyArgs(commandFound, stripFlags(a, commandFound))
// "Legacy" has some 'odd' characteristics.
// - root commands with no subcommands can take arbitrary arguments
// - root commands with subcommands will do subcommand validity checking
// - subcommands will always accept arbitrary arguments
if commandFound.TakesArgs == Legacy {
// no subcommand, always take args
if !commandFound.HasSubCommands() {
return commandFound, a, nil
}
// root command with subcommands, do subcommand checking
if commandFound == c && len(argsWOflags) > 0 {
return commandFound, a, fmt.Errorf("unknown command %q for %q%s", argsWOflags[0], commandFound.CommandPath(), c.findSuggestions(argsWOflags))
}
return commandFound, a, nil
}
if commandFound.TakesArgs == None && len(argsWOflags) > 0 {
return commandFound, a, fmt.Errorf("unknown command %q for %q", argsWOflags[0], commandFound.CommandPath())
}
if commandFound.TakesArgs == ValidOnly && len(commandFound.ValidArgs) > 0 {
for _, v := range argsWOflags {
if !stringInSlice(v, commandFound.ValidArgs) {
return commandFound, a, fmt.Errorf("invalid argument %q for %q%s", v, commandFound.CommandPath(), c.findSuggestions(argsWOflags))
}
}
} }
return commandFound, a, nil return commandFound, a, nil
} }
func (c *Command) findSuggestions(argsWOflags []string) string { func (c *Command) findSuggestions(arg string) string {
if c.DisableSuggestions { if c.DisableSuggestions {
return "" return ""
} }
@ -573,7 +529,7 @@ func (c *Command) findSuggestions(argsWOflags []string) string {
c.SuggestionsMinimumDistance = 2 c.SuggestionsMinimumDistance = 2
} }
suggestionsString := "" suggestionsString := ""
if suggestions := c.SuggestionsFor(argsWOflags[0]); len(suggestions) > 0 { if suggestions := c.SuggestionsFor(arg); len(suggestions) > 0 {
suggestionsString += "\n\nDid you mean this?\n" suggestionsString += "\n\nDid you mean this?\n"
for _, s := range suggestions { for _, s := range suggestions {
suggestionsString += fmt.Sprintf("\t%v\n", s) suggestionsString += fmt.Sprintf("\t%v\n", s)
@ -666,6 +622,10 @@ func (c *Command) execute(a []string) (err error) {
argWoFlags = a argWoFlags = a
} }
if err := c.ValidateArgs(argWoFlags); err != nil {
return err
}
for p := c; p != nil; p = p.Parent() { for p := c; p != nil; p = p.Parent() {
if p.PersistentPreRunE != nil { if p.PersistentPreRunE != nil {
if err := p.PersistentPreRunE(c, argWoFlags); err != nil { if err := p.PersistentPreRunE(c, argWoFlags); err != nil {
@ -789,6 +749,13 @@ func (c *Command) ExecuteC() (cmd *Command, err error) {
return cmd, err return cmd, err
} }
func (c *Command) ValidateArgs(args []string) error {
if c.Args == nil {
return nil
}
return c.Args(c, args)
}
// InitDefaultHelpFlag adds default help flag to c. // InitDefaultHelpFlag adds default help flag to c.
// It is called automatically by executing the c or by calling help and usage. // It is called automatically by executing the c or by calling help and usage.
// If c already has help flag, it will do nothing. // If c already has help flag, it will do nothing.