mirror of
https://github.com/spf13/cobra
synced 2024-11-24 14:47:12 +00:00
Add Command.SetContext (#1551)
Increases flexibility in how Contexts can be used with Cobra.
This commit is contained in:
parent
5d066b77b5
commit
f848943afd
2 changed files with 109 additions and 0 deletions
|
@ -230,6 +230,12 @@ func (c *Command) Context() context.Context {
|
||||||
return c.ctx
|
return c.ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetContext sets context for the command. It is set to context.Background by default and will be overwritten by
|
||||||
|
// Command.ExecuteContext or Command.ExecuteContextC
|
||||||
|
func (c *Command) SetContext(ctx context.Context) {
|
||||||
|
c.ctx = ctx
|
||||||
|
}
|
||||||
|
|
||||||
// SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden
|
// SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden
|
||||||
// particularly useful when testing.
|
// particularly useful when testing.
|
||||||
func (c *Command) SetArgs(a []string) {
|
func (c *Command) SetArgs(a []string) {
|
||||||
|
|
103
command_test.go
103
command_test.go
|
@ -2058,3 +2058,106 @@ func TestFParseErrWhitelistSiblingCommand(t *testing.T) {
|
||||||
}
|
}
|
||||||
checkStringContains(t, output, "unknown flag: --unknown")
|
checkStringContains(t, output, "unknown flag: --unknown")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSetContext(t *testing.T) {
|
||||||
|
type key struct{}
|
||||||
|
val := "foobar"
|
||||||
|
root := &Command{
|
||||||
|
Use: "root",
|
||||||
|
Run: func(cmd *Command, args []string) {
|
||||||
|
key := cmd.Context().Value(key{})
|
||||||
|
got, ok := key.(string)
|
||||||
|
if !ok {
|
||||||
|
t.Error("key not found in context")
|
||||||
|
}
|
||||||
|
if got != val {
|
||||||
|
t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(context.Background(), key{}, val)
|
||||||
|
root.SetContext(ctx)
|
||||||
|
err := root.Execute()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetContextPreRun(t *testing.T) {
|
||||||
|
type key struct{}
|
||||||
|
val := "barr"
|
||||||
|
root := &Command{
|
||||||
|
Use: "root",
|
||||||
|
PreRun: func(cmd *Command, args []string) {
|
||||||
|
ctx := context.WithValue(cmd.Context(), key{}, val)
|
||||||
|
cmd.SetContext(ctx)
|
||||||
|
},
|
||||||
|
Run: func(cmd *Command, args []string) {
|
||||||
|
val := cmd.Context().Value(key{})
|
||||||
|
got, ok := val.(string)
|
||||||
|
if !ok {
|
||||||
|
t.Error("key not found in context")
|
||||||
|
}
|
||||||
|
if got != val {
|
||||||
|
t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := root.Execute()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetContextPreRunOverwrite(t *testing.T) {
|
||||||
|
type key struct{}
|
||||||
|
val := "blah"
|
||||||
|
root := &Command{
|
||||||
|
Use: "root",
|
||||||
|
Run: func(cmd *Command, args []string) {
|
||||||
|
key := cmd.Context().Value(key{})
|
||||||
|
_, ok := key.(string)
|
||||||
|
if ok {
|
||||||
|
t.Error("key found in context when not expected")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.WithValue(context.Background(), key{}, val)
|
||||||
|
root.SetContext(ctx)
|
||||||
|
err := root.ExecuteContext(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetContextPersistentPreRun(t *testing.T) {
|
||||||
|
type key struct{}
|
||||||
|
val := "barbar"
|
||||||
|
root := &Command{
|
||||||
|
Use: "root",
|
||||||
|
PersistentPreRun: func(cmd *Command, args []string) {
|
||||||
|
ctx := context.WithValue(cmd.Context(), key{}, val)
|
||||||
|
cmd.SetContext(ctx)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
child := &Command{
|
||||||
|
Use: "child",
|
||||||
|
Run: func(cmd *Command, args []string) {
|
||||||
|
key := cmd.Context().Value(key{})
|
||||||
|
got, ok := key.(string)
|
||||||
|
if !ok {
|
||||||
|
t.Error("key not found in context")
|
||||||
|
}
|
||||||
|
if got != val {
|
||||||
|
t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
root.AddCommand(child)
|
||||||
|
root.SetArgs([]string{"child"})
|
||||||
|
err := root.Execute()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue