Fix shorthand combination edge case in c.Traverse() code path

This commit is contained in:
Ionut Nicula 2024-09-14 20:31:57 +03:00
parent 8ba575e694
commit b07c5cdd6f
2 changed files with 87 additions and 2 deletions

View file

@ -851,6 +851,35 @@ func (c *Command) Traverse(args []string) (*Command, []string, error) {
// A flag without a value, or with an `=` separated value
case isFlagArg(arg):
flags = append(flags, arg)
if !strings.HasPrefix(arg, "-") || strings.HasPrefix(arg, "--") || strings.Contains(arg, "=") || len(arg) <= 2 {
continue // Not a shorthand combination, so nothing more to do.
}
shorthandCombination := arg[1:] // Skip leading "-"
lastPos := len(shorthandCombination) - 1
for i, shorthand := range shorthandCombination {
if shortHasNoOptDefVal(string(shorthand), c.Flags()) {
continue
}
// We found a shorthand that needs a value.
if i == lastPos {
// Since we're at the end of the shorthand combination, this means that the
// value for the shorthand is given in the next argument. (e.g. '-xyzf arg',
// where -x, -y, -z are boolean flags, and -f is a flag that needs a value).
inFlag = true
} else {
// Since the shorthand combination doesn't end here, this means that the
// value for the shorthand is given in the same argument, meaning we don't
// have to consume the next one. (e.g. '-xyzfarg', where -x, -y, -z are
// boolean flags, and -f is a flag that needs a value).
}
break
}
continue
}

View file

@ -2253,7 +2253,7 @@ func TestTraverseWithParentFlags(t *testing.T) {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(args) != 1 && args[0] != "--add" {
if len(args) != 1 || args[0] != "--int" {
t.Errorf("Wrong args: %v", args)
}
if c.Name() != childCmd.Name() {
@ -2261,6 +2261,62 @@ func TestTraverseWithParentFlags(t *testing.T) {
}
}
func TestTraverseWithShorthandCombinationInParentFlags(t *testing.T) {
rootCmd := &Command{Use: "root", TraverseChildren: true}
stringVal := rootCmd.Flags().StringP("str", "s", "", "")
boolVal := rootCmd.Flags().BoolP("bool", "b", false, "")
childCmd := &Command{Use: "child"}
childCmd.Flags().Int("int", -1, "")
rootCmd.AddCommand(childCmd)
c, args, err := rootCmd.Traverse([]string{"-bs", "ok", "child", "--int"})
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(args) != 1 || args[0] != "--int" {
t.Errorf("Wrong args: %v", args)
}
if c.Name() != childCmd.Name() {
t.Errorf("Expected command: %q, got: %q", childCmd.Name(), c.Name())
}
if *stringVal != "ok" {
t.Errorf("Expected -s to be set to: %s, got: %s", "ok", *stringVal)
}
if !*boolVal {
t.Errorf("Expected -b to be set")
}
}
func TestTraverseWithArgumentIdenticalToCommandName(t *testing.T) {
rootCmd := &Command{Use: "root", TraverseChildren: true}
stringVal := rootCmd.Flags().StringP("str", "s", "", "")
boolVal := rootCmd.Flags().BoolP("bool", "b", false, "")
childCmd := &Command{Use: "child"}
childCmd.Flags().Int("int", -1, "")
rootCmd.AddCommand(childCmd)
c, args, err := rootCmd.Traverse([]string{"-bs", "child", "child", "--int"})
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(args) != 1 || args[0] != "--int" {
t.Errorf("Wrong args: %v", args)
}
if c.Name() != childCmd.Name() {
t.Errorf("Expected command: %q, got: %q", childCmd.Name(), c.Name())
}
if *stringVal != "child" {
t.Errorf("Expected -s to be set to: %s, got: %s", "child", *stringVal)
}
if !*boolVal {
t.Errorf("Expected -b to be set")
}
}
func TestTraverseNoParentFlags(t *testing.T) {
rootCmd := &Command{Use: "root", TraverseChildren: true}
rootCmd.Flags().String("foo", "", "foo things")
@ -2312,7 +2368,7 @@ func TestTraverseWithBadChildFlag(t *testing.T) {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(args) != 1 && args[0] != "--str" {
if len(args) != 1 || args[0] != "--str" {
t.Errorf("Wrong args: %v", args)
}
if c.Name() != childCmd.Name() {