mirror of
https://github.com/spf13/cobra
synced 2024-11-24 14:47:12 +00:00
Add support for context.Context
This commit is contained in:
parent
21cab29ef9
commit
0da0687426
2 changed files with 95 additions and 2 deletions
26
command.go
26
command.go
|
@ -17,6 +17,7 @@ package cobra
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -146,6 +147,8 @@ type Command struct {
|
||||||
// FParseErrWhitelist flag parse errors to be ignored
|
// FParseErrWhitelist flag parse errors to be ignored
|
||||||
FParseErrWhitelist FParseErrWhitelist
|
FParseErrWhitelist FParseErrWhitelist
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
|
||||||
// commands is the list of commands supported by this program.
|
// commands is the list of commands supported by this program.
|
||||||
commands []*Command
|
commands []*Command
|
||||||
// parent is a parent command for this command.
|
// parent is a parent command for this command.
|
||||||
|
@ -205,6 +208,12 @@ type Command struct {
|
||||||
errWriter io.Writer
|
errWriter io.Writer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Context returns underlying command context. If command wasn't
|
||||||
|
// executed with ExecuteContext Context returns Background context.
|
||||||
|
func (c *Command) Context() context.Context {
|
||||||
|
return c.ctx
|
||||||
|
}
|
||||||
|
|
||||||
// SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden
|
// SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden
|
||||||
// particularly useful when testing.
|
// particularly useful when testing.
|
||||||
func (c *Command) SetArgs(a []string) {
|
func (c *Command) SetArgs(a []string) {
|
||||||
|
@ -862,6 +871,13 @@ func (c *Command) preRun() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExecuteContext is the same as Execute(), but sets the ctx on the command.
|
||||||
|
// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle functions.
|
||||||
|
func (c *Command) ExecuteContext(ctx context.Context) error {
|
||||||
|
c.ctx = ctx
|
||||||
|
return c.Execute()
|
||||||
|
}
|
||||||
|
|
||||||
// Execute uses the args (os.Args[1:] by default)
|
// Execute uses the args (os.Args[1:] by default)
|
||||||
// and run through the command tree finding appropriate matches
|
// and run through the command tree finding appropriate matches
|
||||||
// for commands and then corresponding flags.
|
// for commands and then corresponding flags.
|
||||||
|
@ -872,6 +888,10 @@ func (c *Command) Execute() error {
|
||||||
|
|
||||||
// ExecuteC executes the command.
|
// ExecuteC executes the command.
|
||||||
func (c *Command) ExecuteC() (cmd *Command, err error) {
|
func (c *Command) ExecuteC() (cmd *Command, err error) {
|
||||||
|
if c.ctx == nil {
|
||||||
|
c.ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
// Regardless of what command execute is called on, run on Root only
|
// Regardless of what command execute is called on, run on Root only
|
||||||
if c.HasParent() {
|
if c.HasParent() {
|
||||||
return c.Root().ExecuteC()
|
return c.Root().ExecuteC()
|
||||||
|
@ -916,6 +936,12 @@ func (c *Command) ExecuteC() (cmd *Command, err error) {
|
||||||
cmd.commandCalledAs.name = cmd.Name()
|
cmd.commandCalledAs.name = cmd.Name()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We have to pass global context to children command
|
||||||
|
// if context is present on the parent command.
|
||||||
|
if cmd.ctx == nil {
|
||||||
|
cmd.ctx = c.ctx
|
||||||
|
}
|
||||||
|
|
||||||
err = cmd.execute(flags)
|
err = cmd.execute(flags)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Always show help if requested, even if SilenceErrors is in
|
// Always show help if requested, even if SilenceErrors is in
|
||||||
|
|
|
@ -2,6 +2,7 @@ package cobra
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
@ -18,6 +19,16 @@ func executeCommand(root *Command, args ...string) (output string, err error) {
|
||||||
return output, err
|
return output, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func executeCommandWithContext(ctx context.Context, root *Command, args ...string) (output string, err error) {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
root.SetOutput(buf)
|
||||||
|
root.SetArgs(args)
|
||||||
|
|
||||||
|
err = root.ExecuteContext(ctx)
|
||||||
|
|
||||||
|
return buf.String(), err
|
||||||
|
}
|
||||||
|
|
||||||
func executeCommandC(root *Command, args ...string) (c *Command, output string, err error) {
|
func executeCommandC(root *Command, args ...string) (c *Command, output string, err error) {
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
root.SetOutput(buf)
|
root.SetOutput(buf)
|
||||||
|
@ -135,6 +146,62 @@ func TestSubcommandExecuteC(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExecuteContext(t *testing.T) {
|
||||||
|
ctx := context.TODO()
|
||||||
|
|
||||||
|
ctxRun := func(cmd *Command, args []string) {
|
||||||
|
if cmd.Context() != ctx {
|
||||||
|
t.Errorf("Command %q must have context when called with ExecuteContext", cmd.Use)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rootCmd := &Command{Use: "root", Run: ctxRun, PreRun: ctxRun}
|
||||||
|
childCmd := &Command{Use: "child", Run: ctxRun, PreRun: ctxRun}
|
||||||
|
granchildCmd := &Command{Use: "grandchild", Run: ctxRun, PreRun: ctxRun}
|
||||||
|
|
||||||
|
childCmd.AddCommand(granchildCmd)
|
||||||
|
rootCmd.AddCommand(childCmd)
|
||||||
|
|
||||||
|
if _, err := executeCommandWithContext(ctx, rootCmd, ""); err != nil {
|
||||||
|
t.Errorf("Root command must not fail: %+v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := executeCommandWithContext(ctx, rootCmd, "child"); err != nil {
|
||||||
|
t.Errorf("Subcommand must not fail: %+v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := executeCommandWithContext(ctx, rootCmd, "child", "grandchild"); err != nil {
|
||||||
|
t.Errorf("Command child must not fail: %+v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecute_NoContext(t *testing.T) {
|
||||||
|
run := func(cmd *Command, args []string) {
|
||||||
|
if cmd.Context() != context.Background() {
|
||||||
|
t.Errorf("Command %s must have background context", cmd.Use)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rootCmd := &Command{Use: "root", Run: run, PreRun: run}
|
||||||
|
childCmd := &Command{Use: "child", Run: run, PreRun: run}
|
||||||
|
granchildCmd := &Command{Use: "grandchild", Run: run, PreRun: run}
|
||||||
|
|
||||||
|
childCmd.AddCommand(granchildCmd)
|
||||||
|
rootCmd.AddCommand(childCmd)
|
||||||
|
|
||||||
|
if _, err := executeCommand(rootCmd, ""); err != nil {
|
||||||
|
t.Errorf("Root command must not fail: %+v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := executeCommand(rootCmd, "child"); err != nil {
|
||||||
|
t.Errorf("Subcommand must not fail: %+v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := executeCommand(rootCmd, "child", "grandchild"); err != nil {
|
||||||
|
t.Errorf("Command child must not fail: %+v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRootUnknownCommandSilenced(t *testing.T) {
|
func TestRootUnknownCommandSilenced(t *testing.T) {
|
||||||
rootCmd := &Command{Use: "root", Run: emptyRun}
|
rootCmd := &Command{Use: "root", Run: emptyRun}
|
||||||
rootCmd.SilenceErrors = true
|
rootCmd.SilenceErrors = true
|
||||||
|
|
Loading…
Reference in a new issue