diff --git a/command.go b/command.go index b6f8f4b1..f38c6c40 100644 --- a/command.go +++ b/command.go @@ -1058,7 +1058,7 @@ func (c *Command) ExecuteC() (cmd *Command, err error) { // Regardless of what command execute is called on, run on Root only if c.HasParent() { - return c.Root().ExecuteC() + return c.Root().ExecuteContextC(c.ctx) } // windows hook @@ -1108,11 +1108,8 @@ func (c *Command) ExecuteC() (cmd *Command, err error) { cmd.commandCalledAs.name = cmd.Name() } - // We have to pass global context to children command - // if context is present on the parent command. - if cmd.ctx == nil { - cmd.ctx = c.ctx - } + // Pass context of root command to child command. + cmd.ctx = c.ctx err = cmd.execute(flags) if err != nil { diff --git a/command_test.go b/command_test.go index db336922..5a04e00c 100644 --- a/command_test.go +++ b/command_test.go @@ -232,6 +232,96 @@ func TestExecuteContextC(t *testing.T) { } } +// This tests that the context passed to the root command propagates to its children +// not only on the first execution but also subsequent calls. +// Calling the same command multiple times is common when testing cobra applications. +func TestExecuteContextMultiple(t *testing.T) { + var key string + + // Define unique contexts so we can tell them apart below. + ctxs := []context.Context{ + context.WithValue(context.Background(), &key, "1"), + context.WithValue(context.Background(), &key, "2"), + } + + // Shared reference to the context in the current iteration. + var currentCtx context.Context + + ctxRun := func(cmd *Command, args []string) { + if cmd.Context() != currentCtx { + t.Errorf("Command %q must have context with value %s", cmd.Use, currentCtx.Value(&key)) + } + } + + rootCmd := &Command{Use: "root", Run: ctxRun, PreRun: ctxRun} + childCmd := &Command{Use: "child", Run: ctxRun, PreRun: ctxRun} + granchildCmd := &Command{Use: "grandchild", Run: ctxRun, PreRun: ctxRun} + + childCmd.AddCommand(granchildCmd) + rootCmd.AddCommand(childCmd) + + for i := 0; i < 2; i++ { + currentCtx = ctxs[i] + + if _, err := executeCommandWithContext(currentCtx, rootCmd, ""); err != nil { + t.Errorf("Root command must not fail: %+v", err) + } + + if _, err := executeCommandWithContext(currentCtx, rootCmd, "child"); err != nil { + t.Errorf("Subcommand must not fail: %+v", err) + } + + if _, err := executeCommandWithContext(currentCtx, rootCmd, "child", "grandchild"); err != nil { + t.Errorf("Command child must not fail: %+v", err) + } + } +} + +// This tests that the context passed to a subcommand propagates to the root. +// If the entry point happens to be different from the root command, the +// context should still propagate throughout the execution. +func TestExecuteContextOnSubcommand(t *testing.T) { + var key string + + // Define unique contexts so we can tell them apart below. + ctxs := []context.Context{ + context.WithValue(context.Background(), &key, "1"), + context.WithValue(context.Background(), &key, "2"), + context.WithValue(context.Background(), &key, "3"), + } + + // Shared reference to the context in the current iteration. + var currentCtx context.Context + + ctxRun := func(cmd *Command, args []string) { + if cmd.Context() != currentCtx { + t.Errorf("Command %q must have context with value %s", cmd.Use, currentCtx.Value(&key)) + } + } + + rootCmd := &Command{Use: "root", Run: ctxRun, PreRun: ctxRun} + childCmd := &Command{Use: "child", Run: ctxRun, PreRun: ctxRun} + granchildCmd := &Command{Use: "grandchild", Run: ctxRun, PreRun: ctxRun} + + childCmd.AddCommand(granchildCmd) + rootCmd.AddCommand(childCmd) + + currentCtx = ctxs[0] + if _, err := executeCommandWithContext(currentCtx, rootCmd, ""); err != nil { + t.Errorf("Root command must not fail: %+v", err) + } + + currentCtx = ctxs[1] + if _, err := executeCommandWithContext(currentCtx, childCmd, "child"); err != nil { + t.Errorf("Subcommand must not fail: %+v", err) + } + + currentCtx = ctxs[2] + if _, err := executeCommandWithContext(currentCtx, granchildCmd, "child", "grandchild"); err != nil { + t.Errorf("Command child must not fail: %+v", err) + } +} + func TestExecute_NoContext(t *testing.T) { run := func(cmd *Command, args []string) { if cmd.Context() != context.Background() {