diff --git a/cobra_test.go b/cobra_test.go index 7cb4917d..625ed28d 100644 --- a/cobra_test.go +++ b/cobra_test.go @@ -68,6 +68,7 @@ var cmdDeprecated = &Command{ Deprecated: "Please use echo instead", Run: func(cmd *Command, args []string) { }, + TakesArgs: None, } var cmdTimes = &Command{ @@ -80,6 +81,7 @@ var cmdTimes = &Command{ Run: func(cmd *Command, args []string) { tt = args }, + TakesArgs: Arbitrary, } var cmdRootNoRun = &Command{ @@ -92,9 +94,20 @@ var cmdRootNoRun = &Command{ } var cmdRootSameName = &Command{ - Use: "print", - Short: "Root with the same name as a subcommand", - Long: "The root description for help", + Use: "print", + Short: "Root with the same name as a subcommand", + Long: "The root description for help", + TakesArgs: None, +} + +var cmdRootTakesArgs = &Command{ + Use: "root-with-args [random args]", + Short: "The root can run it's own function and takes args!", + Long: "The root description for help, and some args", + Run: func(cmd *Command, args []string) { + tr = args + }, + TakesArgs: Arbitrary, } var cmdRootWithRun = &Command{ @@ -396,6 +409,42 @@ func TestGrandChildSameName(t *testing.T) { } } +func TestRootTakesNoArgs(t *testing.T) { + c := initializeWithSameName() + c.AddCommand(cmdPrint, cmdEcho) + result := simpleTester(c, "illegal") + + expectedError := `unknown command "illegal" for "print"` + if !strings.Contains(result.Error.Error(), expectedError) { + t.Errorf("exptected %v, got %v", expectedError, result.Error.Error()) + } +} + +func TestRootTakesArgs(t *testing.T) { + c := cmdRootTakesArgs + result := simpleTester(c, "legal") + + if result.Error != nil { + t.Errorf("expected no error, but got %v", result.Error) + } +} + +func TestSubCmdTakesNoArgs(t *testing.T) { + result := fullSetupTest("deprecated illegal") + + expectedError := `unknown command "illegal" for "cobra-test deprecated"` + if !strings.Contains(result.Error.Error(), expectedError) { + t.Errorf("exptected %v, got %v", expectedError, result.Error.Error()) + } +} + +func TestSubCmdTakesArgs(t *testing.T) { + noRRSetupTest("echo times one two") + if strings.Join(tt, " ") != "one two" { + t.Error("Command didn't parse correctly") + } +} + func TestFlagLong(t *testing.T) { noRRSetupTest("echo --intone=13 something here") diff --git a/command.go b/command.go index 74565c2b..fff21435 100644 --- a/command.go +++ b/command.go @@ -28,6 +28,14 @@ import ( flag "github.com/spf13/pflag" ) +type Args int + +const ( + Legacy Args = iota + Arbitrary + None +) + // Command is just that, a command for your application. // eg. 'go run' ... 'run' is the command. Cobra requires // you to define the usage and description as part of your command @@ -47,6 +55,8 @@ type Command struct { Example string // List of all valid non-flag arguments, used for bash completions *TODO* actually validate these ValidArgs []string + // Does this command take arbitrary arguments + TakesArgs Args // Custom functions used by the bash autocompletion generator BashCompletionFunction string // Is this command deprecated and should print this string when used? @@ -416,12 +426,23 @@ func (c *Command) Find(args []string) (*Command, []string, error) { commandFound, a := innerfind(c, args) argsWOflags := stripFlags(a, commandFound) - // no subcommand, always take args - if !commandFound.HasSubCommands() { + // "Legacy" has some 'odd' characteristics. + // - root commands with no subcommands can take arbitrary arguments + // - root commands with subcommands will do subcommand validity checking + // - subcommands will always accept arbitrary arguments + if commandFound.TakesArgs == Legacy { + // no subcommand, always take args + if !commandFound.HasSubCommands() { + return commandFound, a, nil + } + // root command with subcommands, do subcommand checking + if commandFound == c && len(argsWOflags) > 0 { + return commandFound, a, fmt.Errorf("unknown command %q for %q", argsWOflags[0], commandFound.CommandPath()) + } return commandFound, a, nil } - // root command with subcommands, do subcommand checking - if commandFound == c && len(argsWOflags) > 0 { + + if commandFound.TakesArgs == None && len(argsWOflags) > 0 { return commandFound, a, fmt.Errorf("unknown command %q for %q", argsWOflags[0], commandFound.CommandPath()) }