Pass context to completion (#1265)

This commit is contained in:
Lukas Malkmus 2021-05-03 18:33:57 +02:00 committed by GitHub
parent 7223a997c8
commit 6d00909120
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 94 additions and 1 deletions

View file

@ -887,7 +887,8 @@ func (c *Command) preRun() {
} }
// ExecuteContext is the same as Execute(), but sets the ctx on the command. // ExecuteContext is the same as Execute(), but sets the ctx on the command.
// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle functions. // Retrieve ctx by calling cmd.Context() inside your *Run lifecycle or ValidArgs
// functions.
func (c *Command) ExecuteContext(ctx context.Context) error { func (c *Command) ExecuteContext(ctx context.Context) error {
c.ctx = ctx c.ctx = ctx
return c.Execute() return c.Execute()
@ -901,6 +902,14 @@ func (c *Command) Execute() error {
return err return err
} }
// ExecuteContextC is the same as ExecuteC(), but sets the ctx on the command.
// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle or ValidArgs
// functions.
func (c *Command) ExecuteContextC(ctx context.Context) (*Command, error) {
c.ctx = ctx
return c.ExecuteC()
}
// ExecuteC executes the command. // ExecuteC executes the command.
func (c *Command) ExecuteC() (cmd *Command, err error) { func (c *Command) ExecuteC() (cmd *Command, err error) {
if c.ctx == nil { if c.ctx == nil {

View file

@ -42,6 +42,17 @@ func executeCommandC(root *Command, args ...string) (c *Command, output string,
return c, buf.String(), err return c, buf.String(), err
} }
func executeCommandWithContextC(ctx context.Context, root *Command, args ...string) (c *Command, output string, err error) {
buf := new(bytes.Buffer)
root.SetOut(buf)
root.SetErr(buf)
root.SetArgs(args)
c, err = root.ExecuteContextC(ctx)
return c, buf.String(), err
}
func resetCommandLineFlagSet() { func resetCommandLineFlagSet() {
pflag.CommandLine = pflag.NewFlagSet(os.Args[0], pflag.ExitOnError) pflag.CommandLine = pflag.NewFlagSet(os.Args[0], pflag.ExitOnError)
} }
@ -178,6 +189,35 @@ func TestExecuteContext(t *testing.T) {
} }
} }
func TestExecuteContextC(t *testing.T) {
ctx := context.TODO()
ctxRun := func(cmd *Command, args []string) {
if cmd.Context() != ctx {
t.Errorf("Command %q must have context when called with ExecuteContext", cmd.Use)
}
}
rootCmd := &Command{Use: "root", Run: ctxRun, PreRun: ctxRun}
childCmd := &Command{Use: "child", Run: ctxRun, PreRun: ctxRun}
granchildCmd := &Command{Use: "grandchild", Run: ctxRun, PreRun: ctxRun}
childCmd.AddCommand(granchildCmd)
rootCmd.AddCommand(childCmd)
if _, _, err := executeCommandWithContextC(ctx, rootCmd, ""); err != nil {
t.Errorf("Root command must not fail: %+v", err)
}
if _, _, err := executeCommandWithContextC(ctx, rootCmd, "child"); err != nil {
t.Errorf("Subcommand must not fail: %+v", err)
}
if _, _, err := executeCommandWithContextC(ctx, rootCmd, "child", "grandchild"); err != nil {
t.Errorf("Command child must not fail: %+v", err)
}
}
func TestExecute_NoContext(t *testing.T) { func TestExecute_NoContext(t *testing.T) {
run := func(cmd *Command, args []string) { run := func(cmd *Command, args []string) {
if cmd.Context() != context.Background() { if cmd.Context() != context.Background() {

View file

@ -221,6 +221,7 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi
// Unable to find the real command. E.g., <program> someInvalidCmd <TAB> // Unable to find the real command. E.g., <program> someInvalidCmd <TAB>
return c, []string{}, ShellCompDirectiveDefault, fmt.Errorf("Unable to find a command for arguments: %v", trimmedArgs) return c, []string{}, ShellCompDirectiveDefault, fmt.Errorf("Unable to find a command for arguments: %v", trimmedArgs)
} }
finalCmd.ctx = c.ctx
// Check if we are doing flag value completion before parsing the flags. // Check if we are doing flag value completion before parsing the flags.
// This is important because if we are completing a flag value, we need to also // This is important because if we are completing a flag value, we need to also

View file

@ -2,6 +2,7 @@ package cobra
import ( import (
"bytes" "bytes"
"context"
"strings" "strings"
"testing" "testing"
) )
@ -1203,6 +1204,48 @@ func TestFlagDirFilterCompletionInGo(t *testing.T) {
} }
} }
func TestValidArgsFuncCmdContext(t *testing.T) {
validArgsFunc := func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) {
ctx := cmd.Context()
if ctx == nil {
t.Error("Received nil context in completion func")
} else if ctx.Value("testKey") != "123" {
t.Error("Received invalid context")
}
return nil, ShellCompDirectiveDefault
}
rootCmd := &Command{
Use: "root",
Run: emptyRun,
}
childCmd := &Command{
Use: "childCmd",
ValidArgsFunction: validArgsFunc,
Run: emptyRun,
}
rootCmd.AddCommand(childCmd)
//nolint:golint,staticcheck // We can safely use a basic type as key in tests.
ctx := context.WithValue(context.Background(), "testKey", "123")
// Test completing an empty string on the childCmd
_, output, err := executeCommandWithContextC(ctx, rootCmd, ShellCompNoDescRequestCmd, "childCmd", "")
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
expected := strings.Join([]string{
":0",
"Completion ended with directive: ShellCompDirectiveDefault", ""}, "\n")
if output != expected {
t.Errorf("expected: %q, got: %q", expected, output)
}
}
func TestValidArgsFuncSingleCmd(t *testing.T) { func TestValidArgsFuncSingleCmd(t *testing.T) {
rootCmd := &Command{ rootCmd := &Command{
Use: "root", Use: "root",