mirror of
https://github.com/spf13/cobra
synced 2024-11-24 14:47:12 +00:00
Pass context to completion (#1265)
This commit is contained in:
parent
7223a997c8
commit
6d00909120
4 changed files with 94 additions and 1 deletions
11
command.go
11
command.go
|
@ -887,7 +887,8 @@ func (c *Command) preRun() {
|
|||
}
|
||||
|
||||
// ExecuteContext is the same as Execute(), but sets the ctx on the command.
|
||||
// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle functions.
|
||||
// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle or ValidArgs
|
||||
// functions.
|
||||
func (c *Command) ExecuteContext(ctx context.Context) error {
|
||||
c.ctx = ctx
|
||||
return c.Execute()
|
||||
|
@ -901,6 +902,14 @@ func (c *Command) Execute() error {
|
|||
return err
|
||||
}
|
||||
|
||||
// ExecuteContextC is the same as ExecuteC(), but sets the ctx on the command.
|
||||
// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle or ValidArgs
|
||||
// functions.
|
||||
func (c *Command) ExecuteContextC(ctx context.Context) (*Command, error) {
|
||||
c.ctx = ctx
|
||||
return c.ExecuteC()
|
||||
}
|
||||
|
||||
// ExecuteC executes the command.
|
||||
func (c *Command) ExecuteC() (cmd *Command, err error) {
|
||||
if c.ctx == nil {
|
||||
|
|
|
@ -42,6 +42,17 @@ func executeCommandC(root *Command, args ...string) (c *Command, output string,
|
|||
return c, buf.String(), err
|
||||
}
|
||||
|
||||
func executeCommandWithContextC(ctx context.Context, root *Command, args ...string) (c *Command, output string, err error) {
|
||||
buf := new(bytes.Buffer)
|
||||
root.SetOut(buf)
|
||||
root.SetErr(buf)
|
||||
root.SetArgs(args)
|
||||
|
||||
c, err = root.ExecuteContextC(ctx)
|
||||
|
||||
return c, buf.String(), err
|
||||
}
|
||||
|
||||
func resetCommandLineFlagSet() {
|
||||
pflag.CommandLine = pflag.NewFlagSet(os.Args[0], pflag.ExitOnError)
|
||||
}
|
||||
|
@ -178,6 +189,35 @@ func TestExecuteContext(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestExecuteContextC(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
|
||||
ctxRun := func(cmd *Command, args []string) {
|
||||
if cmd.Context() != ctx {
|
||||
t.Errorf("Command %q must have context when called with ExecuteContext", cmd.Use)
|
||||
}
|
||||
}
|
||||
|
||||
rootCmd := &Command{Use: "root", Run: ctxRun, PreRun: ctxRun}
|
||||
childCmd := &Command{Use: "child", Run: ctxRun, PreRun: ctxRun}
|
||||
granchildCmd := &Command{Use: "grandchild", Run: ctxRun, PreRun: ctxRun}
|
||||
|
||||
childCmd.AddCommand(granchildCmd)
|
||||
rootCmd.AddCommand(childCmd)
|
||||
|
||||
if _, _, err := executeCommandWithContextC(ctx, rootCmd, ""); err != nil {
|
||||
t.Errorf("Root command must not fail: %+v", err)
|
||||
}
|
||||
|
||||
if _, _, err := executeCommandWithContextC(ctx, rootCmd, "child"); err != nil {
|
||||
t.Errorf("Subcommand must not fail: %+v", err)
|
||||
}
|
||||
|
||||
if _, _, err := executeCommandWithContextC(ctx, rootCmd, "child", "grandchild"); err != nil {
|
||||
t.Errorf("Command child must not fail: %+v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecute_NoContext(t *testing.T) {
|
||||
run := func(cmd *Command, args []string) {
|
||||
if cmd.Context() != context.Background() {
|
||||
|
|
|
@ -221,6 +221,7 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi
|
|||
// Unable to find the real command. E.g., <program> someInvalidCmd <TAB>
|
||||
return c, []string{}, ShellCompDirectiveDefault, fmt.Errorf("Unable to find a command for arguments: %v", trimmedArgs)
|
||||
}
|
||||
finalCmd.ctx = c.ctx
|
||||
|
||||
// Check if we are doing flag value completion before parsing the flags.
|
||||
// This is important because if we are completing a flag value, we need to also
|
||||
|
|
|
@ -2,6 +2,7 @@ package cobra
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
@ -1203,6 +1204,48 @@ func TestFlagDirFilterCompletionInGo(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestValidArgsFuncCmdContext(t *testing.T) {
|
||||
validArgsFunc := func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) {
|
||||
ctx := cmd.Context()
|
||||
|
||||
if ctx == nil {
|
||||
t.Error("Received nil context in completion func")
|
||||
} else if ctx.Value("testKey") != "123" {
|
||||
t.Error("Received invalid context")
|
||||
}
|
||||
|
||||
return nil, ShellCompDirectiveDefault
|
||||
}
|
||||
|
||||
rootCmd := &Command{
|
||||
Use: "root",
|
||||
Run: emptyRun,
|
||||
}
|
||||
childCmd := &Command{
|
||||
Use: "childCmd",
|
||||
ValidArgsFunction: validArgsFunc,
|
||||
Run: emptyRun,
|
||||
}
|
||||
rootCmd.AddCommand(childCmd)
|
||||
|
||||
//nolint:golint,staticcheck // We can safely use a basic type as key in tests.
|
||||
ctx := context.WithValue(context.Background(), "testKey", "123")
|
||||
|
||||
// Test completing an empty string on the childCmd
|
||||
_, output, err := executeCommandWithContextC(ctx, rootCmd, ShellCompNoDescRequestCmd, "childCmd", "")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expected := strings.Join([]string{
|
||||
":0",
|
||||
"Completion ended with directive: ShellCompDirectiveDefault", ""}, "\n")
|
||||
|
||||
if output != expected {
|
||||
t.Errorf("expected: %q, got: %q", expected, output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidArgsFuncSingleCmd(t *testing.T) {
|
||||
rootCmd := &Command{
|
||||
Use: "root",
|
||||
|
|
Loading…
Reference in a new issue