From b07c5cdd6f0b505dd2c9d6ab27b5db6e09929be2 Mon Sep 17 00:00:00 2001 From: Ionut Nicula Date: Sat, 14 Sep 2024 20:31:57 +0300 Subject: [PATCH] Fix shorthand combination edge case in c.Traverse() code path --- command.go | 29 ++++++++++++++++++++++++ command_test.go | 60 +++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 87 insertions(+), 2 deletions(-) diff --git a/command.go b/command.go index cc20afed..1d2df277 100644 --- a/command.go +++ b/command.go @@ -851,6 +851,35 @@ func (c *Command) Traverse(args []string) (*Command, []string, error) { // A flag without a value, or with an `=` separated value case isFlagArg(arg): flags = append(flags, arg) + + if !strings.HasPrefix(arg, "-") || strings.HasPrefix(arg, "--") || strings.Contains(arg, "=") || len(arg) <= 2 { + continue // Not a shorthand combination, so nothing more to do. + } + + shorthandCombination := arg[1:] // Skip leading "-" + lastPos := len(shorthandCombination) - 1 + for i, shorthand := range shorthandCombination { + if shortHasNoOptDefVal(string(shorthand), c.Flags()) { + continue + } + + // We found a shorthand that needs a value. + + if i == lastPos { + // Since we're at the end of the shorthand combination, this means that the + // value for the shorthand is given in the next argument. (e.g. '-xyzf arg', + // where -x, -y, -z are boolean flags, and -f is a flag that needs a value). + inFlag = true + } else { + // Since the shorthand combination doesn't end here, this means that the + // value for the shorthand is given in the same argument, meaning we don't + // have to consume the next one. (e.g. '-xyzfarg', where -x, -y, -z are + // boolean flags, and -f is a flag that needs a value). + } + + break + } + continue } diff --git a/command_test.go b/command_test.go index 2f0f4b16..9fb0bdfc 100644 --- a/command_test.go +++ b/command_test.go @@ -2253,7 +2253,7 @@ func TestTraverseWithParentFlags(t *testing.T) { if err != nil { t.Errorf("Unexpected error: %v", err) } - if len(args) != 1 && args[0] != "--add" { + if len(args) != 1 || args[0] != "--int" { t.Errorf("Wrong args: %v", args) } if c.Name() != childCmd.Name() { @@ -2261,6 +2261,62 @@ func TestTraverseWithParentFlags(t *testing.T) { } } +func TestTraverseWithShorthandCombinationInParentFlags(t *testing.T) { + rootCmd := &Command{Use: "root", TraverseChildren: true} + stringVal := rootCmd.Flags().StringP("str", "s", "", "") + boolVal := rootCmd.Flags().BoolP("bool", "b", false, "") + + childCmd := &Command{Use: "child"} + childCmd.Flags().Int("int", -1, "") + + rootCmd.AddCommand(childCmd) + + c, args, err := rootCmd.Traverse([]string{"-bs", "ok", "child", "--int"}) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if len(args) != 1 || args[0] != "--int" { + t.Errorf("Wrong args: %v", args) + } + if c.Name() != childCmd.Name() { + t.Errorf("Expected command: %q, got: %q", childCmd.Name(), c.Name()) + } + if *stringVal != "ok" { + t.Errorf("Expected -s to be set to: %s, got: %s", "ok", *stringVal) + } + if !*boolVal { + t.Errorf("Expected -b to be set") + } +} + +func TestTraverseWithArgumentIdenticalToCommandName(t *testing.T) { + rootCmd := &Command{Use: "root", TraverseChildren: true} + stringVal := rootCmd.Flags().StringP("str", "s", "", "") + boolVal := rootCmd.Flags().BoolP("bool", "b", false, "") + + childCmd := &Command{Use: "child"} + childCmd.Flags().Int("int", -1, "") + + rootCmd.AddCommand(childCmd) + + c, args, err := rootCmd.Traverse([]string{"-bs", "child", "child", "--int"}) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if len(args) != 1 || args[0] != "--int" { + t.Errorf("Wrong args: %v", args) + } + if c.Name() != childCmd.Name() { + t.Errorf("Expected command: %q, got: %q", childCmd.Name(), c.Name()) + } + if *stringVal != "child" { + t.Errorf("Expected -s to be set to: %s, got: %s", "child", *stringVal) + } + if !*boolVal { + t.Errorf("Expected -b to be set") + } +} + func TestTraverseNoParentFlags(t *testing.T) { rootCmd := &Command{Use: "root", TraverseChildren: true} rootCmd.Flags().String("foo", "", "foo things") @@ -2312,7 +2368,7 @@ func TestTraverseWithBadChildFlag(t *testing.T) { if err != nil { t.Errorf("Unexpected error: %v", err) } - if len(args) != 1 && args[0] != "--str" { + if len(args) != 1 || args[0] != "--str" { t.Errorf("Wrong args: %v", args) } if c.Name() != childCmd.Name() {