diff --git a/cobra_test.go b/cobra_test.go index 9a15fdb5..d89659b0 100644 --- a/cobra_test.go +++ b/cobra_test.go @@ -248,6 +248,18 @@ func simpleTester(c *Command, input string) resulter { return resulter{err, output, c} } +func simpleTesterC(c *Command, input string) resulter { + buf := new(bytes.Buffer) + // Testing flag with invalid input + c.SetOutput(buf) + c.SetArgs(strings.Split(input, " ")) + + cmd, err := c.ExecuteC() + output := buf.String() + + return resulter{err, output, cmd} +} + func fullTester(c *Command, input string) resulter { buf := new(bytes.Buffer) // Testing flag with invalid input @@ -561,6 +573,41 @@ func TestInvalidSubcommandFlags(t *testing.T) { } +func TestSubcommandExecuteC(t *testing.T) { + cmd := initializeWithRootCmd() + double := &Command{ + Use: "double message", + Run: func(c *Command, args []string) { + msg := strings.Join(args, " ") + c.Println(msg, msg) + }, + } + + echo := &Command{ + Use: "echo message", + Run: func(c *Command, args []string) { + msg := strings.Join(args, " ") + c.Println(msg, msg) + }, + } + + cmd.AddCommand(double, echo) + + result := simpleTesterC(cmd, "double hello world") + checkResultContains(t, result, "hello world hello world") + + if result.Command.Name() != "double" { + t.Errorf("invalid cmd returned from ExecuteC: should be 'double' but got %s", result.Command.Name()) + } + + result = simpleTesterC(cmd, "echo msg to be echoed") + checkResultContains(t, result, "msg to be echoed") + + if result.Command.Name() != "echo" { + t.Errorf("invalid cmd returned from ExecuteC: should be 'echo' but got %s", result.Command.Name()) + } +} + func TestSubcommandArgEvaluation(t *testing.T) { cmd := initializeWithRootCmd() diff --git a/command.go b/command.go index d39bfeb9..f5fa34e2 100644 --- a/command.go +++ b/command.go @@ -614,11 +614,16 @@ func (c *Command) errorMsgFromParse() string { // Call execute to use the args (os.Args[1:] by default) // and run through the command tree finding appropriate matches // for commands and then corresponding flags. -func (c *Command) Execute() (err error) { +func (c *Command) Execute() error { + _, err := c.ExecuteC() + return err +} + +func (c *Command) ExecuteC() (cmd *Command, err error) { // Regardless of what command execute is called on, run on Root only if c.HasParent() { - return c.Root().Execute() + return c.Root().ExecuteC() } if EnableWindowsMouseTrap && runtime.GOOS == "windows" { @@ -652,9 +657,8 @@ func (c *Command) Execute() (err error) { c.Println("Error:", err.Error()) c.Printf("Run '%v --help' for usage.\n", c.CommandPath()) } - return err + return c, err } - err = cmd.execute(flags) if err != nil { // If root command has SilentErrors flagged, @@ -662,7 +666,7 @@ func (c *Command) Execute() (err error) { if !cmd.SilenceErrors && !c.SilenceErrors { if err == flag.ErrHelp { cmd.HelpFunc()(cmd, args) - return nil + return cmd, nil } c.Println("Error:", err.Error()) } @@ -672,9 +676,9 @@ func (c *Command) Execute() (err error) { if !cmd.SilenceUsage && !c.SilenceUsage { c.Println(cmd.UsageString()) } - return err + return cmd, err } - return + return cmd, nil } func (c *Command) initHelpFlag() {