mirror of
https://github.com/spf13/cobra
synced 2024-12-26 22:37:08 +00:00
fix: don't remove flag value that matches subcommand name (#1781)
When the command searches args to find the arg matching a particular subcommand name, it needs to ignore flag values, as it is possible that the value for a flag might match the name of the sub command. This change improves argsMinusFirstX() to ignore flag values when it searches for the X to exclude from the result.
This commit is contained in:
parent
cc7e235fc2
commit
6b0bd3076c
2 changed files with 121 additions and 8 deletions
40
command.go
40
command.go
|
@ -655,13 +655,37 @@ Loop:
|
|||
|
||||
// argsMinusFirstX removes only the first x from args. Otherwise, commands that look like
|
||||
// openshift admin policy add-role-to-user admin my-user, lose the admin argument (arg[4]).
|
||||
func argsMinusFirstX(args []string, x string) []string {
|
||||
for i, y := range args {
|
||||
if x == y {
|
||||
ret := []string{}
|
||||
ret = append(ret, args[:i]...)
|
||||
ret = append(ret, args[i+1:]...)
|
||||
return ret
|
||||
// Special care needs to be taken not to remove a flag value.
|
||||
func (c *Command) argsMinusFirstX(args []string, x string) []string {
|
||||
if len(args) == 0 {
|
||||
return args
|
||||
}
|
||||
c.mergePersistentFlags()
|
||||
flags := c.Flags()
|
||||
|
||||
Loop:
|
||||
for pos := 0; pos < len(args); pos++ {
|
||||
s := args[pos]
|
||||
switch {
|
||||
case s == "--":
|
||||
// -- means we have reached the end of the parseable args. Break out of the loop now.
|
||||
break Loop
|
||||
case strings.HasPrefix(s, "--") && !strings.Contains(s, "=") && !hasNoOptDefVal(s[2:], flags):
|
||||
fallthrough
|
||||
case strings.HasPrefix(s, "-") && !strings.Contains(s, "=") && len(s) == 2 && !shortHasNoOptDefVal(s[1:], flags):
|
||||
// This is a flag without a default value, and an equal sign is not used. Increment pos in order to skip
|
||||
// over the next arg, because that is the value of this flag.
|
||||
pos++
|
||||
continue
|
||||
case !strings.HasPrefix(s, "-"):
|
||||
// This is not a flag or a flag value. Check to see if it matches what we're looking for, and if so,
|
||||
// return the args, excluding the one at this position.
|
||||
if s == x {
|
||||
ret := []string{}
|
||||
ret = append(ret, args[:pos]...)
|
||||
ret = append(ret, args[pos+1:]...)
|
||||
return ret
|
||||
}
|
||||
}
|
||||
}
|
||||
return args
|
||||
|
@ -686,7 +710,7 @@ func (c *Command) Find(args []string) (*Command, []string, error) {
|
|||
|
||||
cmd := c.findNext(nextSubCmd)
|
||||
if cmd != nil {
|
||||
return innerfind(cmd, argsMinusFirstX(innerArgs, nextSubCmd))
|
||||
return innerfind(cmd, c.argsMinusFirstX(innerArgs, nextSubCmd))
|
||||
}
|
||||
return c, innerArgs
|
||||
}
|
||||
|
|
|
@ -2603,3 +2603,92 @@ func TestHelpflagCommandExecutedWithoutVersionSet(t *testing.T) {
|
|||
checkStringContains(t, output, HelpFlag)
|
||||
checkStringOmits(t, output, VersionFlag)
|
||||
}
|
||||
|
||||
func TestFind(t *testing.T) {
|
||||
var foo, bar string
|
||||
root := &Command{
|
||||
Use: "root",
|
||||
}
|
||||
root.PersistentFlags().StringVarP(&foo, "foo", "f", "", "")
|
||||
root.PersistentFlags().StringVarP(&bar, "bar", "b", "something", "")
|
||||
|
||||
child := &Command{
|
||||
Use: "child",
|
||||
}
|
||||
root.AddCommand(child)
|
||||
|
||||
testCases := []struct {
|
||||
args []string
|
||||
expectedFoundArgs []string
|
||||
}{
|
||||
{
|
||||
[]string{"child"},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
[]string{"child", "child"},
|
||||
[]string{"child"},
|
||||
},
|
||||
{
|
||||
[]string{"child", "foo", "child", "bar", "child", "baz", "child"},
|
||||
[]string{"foo", "child", "bar", "child", "baz", "child"},
|
||||
},
|
||||
{
|
||||
[]string{"-f", "child", "child"},
|
||||
[]string{"-f", "child"},
|
||||
},
|
||||
{
|
||||
[]string{"child", "-f", "child"},
|
||||
[]string{"-f", "child"},
|
||||
},
|
||||
{
|
||||
[]string{"-b", "child", "child"},
|
||||
[]string{"-b", "child"},
|
||||
},
|
||||
{
|
||||
[]string{"child", "-b", "child"},
|
||||
[]string{"-b", "child"},
|
||||
},
|
||||
{
|
||||
[]string{"child", "-b"},
|
||||
[]string{"-b"},
|
||||
},
|
||||
{
|
||||
[]string{"-b", "-f", "child", "child"},
|
||||
[]string{"-b", "-f", "child"},
|
||||
},
|
||||
{
|
||||
[]string{"-f", "child", "-b", "something", "child"},
|
||||
[]string{"-f", "child", "-b", "something"},
|
||||
},
|
||||
{
|
||||
[]string{"-f", "child", "child", "-b"},
|
||||
[]string{"-f", "child", "-b"},
|
||||
},
|
||||
{
|
||||
[]string{"-f=child", "-b=something", "child"},
|
||||
[]string{"-f=child", "-b=something"},
|
||||
},
|
||||
{
|
||||
[]string{"--foo", "child", "--bar", "something", "child"},
|
||||
[]string{"--foo", "child", "--bar", "something"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("%v", tc.args), func(t *testing.T) {
|
||||
cmd, foundArgs, err := root.Find(tc.args)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if cmd != child {
|
||||
t.Fatal("Expected cmd to be child, but it was not")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.expectedFoundArgs, foundArgs) {
|
||||
t.Fatalf("Wrong args\nExpected: %v\nGot: %v", tc.expectedFoundArgs, foundArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue