From 8ba575e6947ac701ca0acbcb12a8f468d298d411 Mon Sep 17 00:00:00 2001 From: Ionut Nicula <nicula.iccc@gmail.com> Date: Sat, 14 Sep 2024 17:08:27 +0300 Subject: [PATCH 1/4] Fix shorthand combination edge case in c.Find() code path Fixes: #2188 --- command.go | 68 +++++++++++++++++++++++++++++++++++++------------ command_test.go | 60 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 111 insertions(+), 17 deletions(-) diff --git a/command.go b/command.go index 2df6975f..cc20afed 100644 --- a/command.go +++ b/command.go @@ -643,13 +643,14 @@ func shortHasNoOptDefVal(name string, fs *flag.FlagSet) bool { return flag.NoOptDefVal != "" } -func stripFlags(args []string, c *Command) []string { +func stripFlags(args []string, c *Command) ([]string, []string) { if len(args) == 0 { - return args + return args, nil } c.mergePersistentFlags() commands := []string{} + flagsThatConsumeNextArg := []string{} // We use this to avoid repeating the same lengthy logic for parsing shorthand combinations in argsMinusFirstX flags := c.Flags() Loop: @@ -665,31 +666,70 @@ Loop: // delete arg from args. fallthrough // (do the same as below) case strings.HasPrefix(s, "-") && !strings.Contains(s, "=") && len(s) == 2 && !shortHasNoOptDefVal(s[1:], flags): + flagsThatConsumeNextArg = append(flagsThatConsumeNextArg, s) // If '-f arg' then // delete 'arg' from args or break the loop if len(args) <= 1. if len(args) <= 1 { break Loop } else { args = args[1:] - continue + } + case strings.HasPrefix(s, "-") && !strings.HasPrefix(s, "--") && !strings.Contains(s, "=") && len(s) > 2: + shorthandCombination := s[1:] // Skip the leading "-" + lastPos := len(shorthandCombination) - 1 + for i, shorthand := range shorthandCombination { + if shortHasNoOptDefVal(string(shorthand), flags) { + continue + } + + // We found a shorthand that needs a value. + + if i == lastPos { + // Since we're at the end of the shorthand combination, this means that the + // value for the shorthand is given in the next argument. (e.g. '-xyzf arg', + // where -x, -y, -z are boolean flags, and -f is a flag that needs a value). + + // The whole combination will take a value. + flagsThatConsumeNextArg = append(flagsThatConsumeNextArg, s) + + if len(args) <= 1 { + break Loop + } else { + args = args[1:] + } + } else { + // Since the shorthand combination doesn't end here, this means that the + // value for the shorthand is given in the same argument, meaning we don't + // have to consume the next one. (e.g. '-xyzfarg', where -x, -y, -z are + // boolean flags, and -f is a flag that needs a value). + } + + break } case s != "" && !strings.HasPrefix(s, "-"): commands = append(commands, s) } } - return commands + return commands, flagsThatConsumeNextArg } // argsMinusFirstX removes only the first x from args. Otherwise, commands that look like // openshift admin policy add-role-to-user admin my-user, lose the admin argument (arg[4]). // Special care needs to be taken not to remove a flag value. -func (c *Command) argsMinusFirstX(args []string, x string) []string { +func (c *Command) argsMinusFirstX(args, flagsThatConsumeNextArg []string, x string) []string { if len(args) == 0 { return args } - c.mergePersistentFlags() - flags := c.Flags() + + consumesNextArg := func(flag string) bool { + for _, f := range flagsThatConsumeNextArg { + if flag == f { + return true + } + } + return false + } Loop: for pos := 0; pos < len(args); pos++ { @@ -698,13 +738,8 @@ Loop: case s == "--": // -- means we have reached the end of the parseable args. Break out of the loop now. break Loop - case strings.HasPrefix(s, "--") && !strings.Contains(s, "=") && !hasNoOptDefVal(s[2:], flags): - fallthrough - case strings.HasPrefix(s, "-") && !strings.Contains(s, "=") && len(s) == 2 && !shortHasNoOptDefVal(s[1:], flags): - // This is a flag without a default value, and an equal sign is not used. Increment pos in order to skip - // over the next arg, because that is the value of this flag. + case consumesNextArg(s): pos++ - continue case !strings.HasPrefix(s, "-"): // This is not a flag or a flag value. Check to see if it matches what we're looking for, and if so, // return the args, excluding the one at this position. @@ -730,7 +765,7 @@ func (c *Command) Find(args []string) (*Command, []string, error) { var innerfind func(*Command, []string) (*Command, []string) innerfind = func(c *Command, innerArgs []string) (*Command, []string) { - argsWOflags := stripFlags(innerArgs, c) + argsWOflags, flagsThatConsumeNextArg := stripFlags(innerArgs, c) if len(argsWOflags) == 0 { return c, innerArgs } @@ -738,14 +773,15 @@ func (c *Command) Find(args []string) (*Command, []string, error) { cmd := c.findNext(nextSubCmd) if cmd != nil { - return innerfind(cmd, c.argsMinusFirstX(innerArgs, nextSubCmd)) + return innerfind(cmd, c.argsMinusFirstX(innerArgs, flagsThatConsumeNextArg, nextSubCmd)) } return c, innerArgs } commandFound, a := innerfind(c, args) if commandFound.Args == nil { - return commandFound, a, legacyArgs(commandFound, stripFlags(a, commandFound)) + argsWOflags, _ := stripFlags(a, commandFound) + return commandFound, a, legacyArgs(commandFound, argsWOflags) } return commandFound, a, nil } diff --git a/command_test.go b/command_test.go index 9ce7a529..2f0f4b16 100644 --- a/command_test.go +++ b/command_test.go @@ -693,6 +693,30 @@ func TestStripFlags(t *testing.T) { []string{"-p", "bar"}, []string{"bar"}, }, + { + []string{"-s", "value", "bar"}, + []string{"bar"}, + }, + { + []string{"-s=value", "bar"}, + []string{"bar"}, + }, + { + []string{"-svalue", "bar"}, + []string{"bar"}, + }, + { + []string{"-ps", "value", "bar"}, + []string{"bar"}, + }, + { + []string{"-ps=value", "bar"}, + []string{"bar"}, + }, + { + []string{"-psvalue", "bar"}, + []string{"bar"}, + }, } c := &Command{Use: "c", Run: emptyRun} @@ -702,7 +726,7 @@ func TestStripFlags(t *testing.T) { c.Flags().BoolP("bool", "b", false, "") for i, test := range tests { - got := stripFlags(test.input, c) + got, _ := stripFlags(test.input, c) if !reflect.DeepEqual(test.output, got) { t.Errorf("(%v) Expected: %v, got: %v", i, test.output, got) } @@ -2688,11 +2712,13 @@ func TestHelpflagCommandExecutedWithoutVersionSet(t *testing.T) { func TestFind(t *testing.T) { var foo, bar string + var persist bool root := &Command{ Use: "root", } root.PersistentFlags().StringVarP(&foo, "foo", "f", "", "") root.PersistentFlags().StringVarP(&bar, "bar", "b", "something", "") + root.PersistentFlags().BoolVarP(&persist, "persist", "p", false, "") child := &Command{ Use: "child", @@ -2755,6 +2781,38 @@ func TestFind(t *testing.T) { []string{"--foo", "child", "--bar", "something", "child"}, []string{"--foo", "child", "--bar", "something"}, }, + { + []string{"-f", "value", "child"}, + []string{"-f", "value"}, + }, + { + []string{"-f=value", "child"}, + []string{"-f=value"}, + }, + { + []string{"-fvalue", "child"}, + []string{"-fvalue"}, + }, + { + []string{"-pf", "value", "child"}, + []string{"-pf", "value"}, + }, + { + []string{"-pf=value", "child"}, + []string{"-pf=value"}, + }, + { + []string{"-pfvalue", "child"}, + []string{"-pfvalue"}, + }, + { + []string{"-pf", "child", "child"}, + []string{"-pf", "child"}, + }, + { + []string{"-pf", "child", "-pb", "something", "child"}, + []string{"-pf", "child", "-pb", "something"}, + }, } for _, tc := range testCases { From b07c5cdd6f0b505dd2c9d6ab27b5db6e09929be2 Mon Sep 17 00:00:00 2001 From: Ionut Nicula <nicula.iccc@gmail.com> Date: Sat, 14 Sep 2024 20:31:57 +0300 Subject: [PATCH 2/4] Fix shorthand combination edge case in c.Traverse() code path --- command.go | 29 ++++++++++++++++++++++++ command_test.go | 60 +++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 87 insertions(+), 2 deletions(-) diff --git a/command.go b/command.go index cc20afed..1d2df277 100644 --- a/command.go +++ b/command.go @@ -851,6 +851,35 @@ func (c *Command) Traverse(args []string) (*Command, []string, error) { // A flag without a value, or with an `=` separated value case isFlagArg(arg): flags = append(flags, arg) + + if !strings.HasPrefix(arg, "-") || strings.HasPrefix(arg, "--") || strings.Contains(arg, "=") || len(arg) <= 2 { + continue // Not a shorthand combination, so nothing more to do. + } + + shorthandCombination := arg[1:] // Skip leading "-" + lastPos := len(shorthandCombination) - 1 + for i, shorthand := range shorthandCombination { + if shortHasNoOptDefVal(string(shorthand), c.Flags()) { + continue + } + + // We found a shorthand that needs a value. + + if i == lastPos { + // Since we're at the end of the shorthand combination, this means that the + // value for the shorthand is given in the next argument. (e.g. '-xyzf arg', + // where -x, -y, -z are boolean flags, and -f is a flag that needs a value). + inFlag = true + } else { + // Since the shorthand combination doesn't end here, this means that the + // value for the shorthand is given in the same argument, meaning we don't + // have to consume the next one. (e.g. '-xyzfarg', where -x, -y, -z are + // boolean flags, and -f is a flag that needs a value). + } + + break + } + continue } diff --git a/command_test.go b/command_test.go index 2f0f4b16..9fb0bdfc 100644 --- a/command_test.go +++ b/command_test.go @@ -2253,7 +2253,7 @@ func TestTraverseWithParentFlags(t *testing.T) { if err != nil { t.Errorf("Unexpected error: %v", err) } - if len(args) != 1 && args[0] != "--add" { + if len(args) != 1 || args[0] != "--int" { t.Errorf("Wrong args: %v", args) } if c.Name() != childCmd.Name() { @@ -2261,6 +2261,62 @@ func TestTraverseWithParentFlags(t *testing.T) { } } +func TestTraverseWithShorthandCombinationInParentFlags(t *testing.T) { + rootCmd := &Command{Use: "root", TraverseChildren: true} + stringVal := rootCmd.Flags().StringP("str", "s", "", "") + boolVal := rootCmd.Flags().BoolP("bool", "b", false, "") + + childCmd := &Command{Use: "child"} + childCmd.Flags().Int("int", -1, "") + + rootCmd.AddCommand(childCmd) + + c, args, err := rootCmd.Traverse([]string{"-bs", "ok", "child", "--int"}) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if len(args) != 1 || args[0] != "--int" { + t.Errorf("Wrong args: %v", args) + } + if c.Name() != childCmd.Name() { + t.Errorf("Expected command: %q, got: %q", childCmd.Name(), c.Name()) + } + if *stringVal != "ok" { + t.Errorf("Expected -s to be set to: %s, got: %s", "ok", *stringVal) + } + if !*boolVal { + t.Errorf("Expected -b to be set") + } +} + +func TestTraverseWithArgumentIdenticalToCommandName(t *testing.T) { + rootCmd := &Command{Use: "root", TraverseChildren: true} + stringVal := rootCmd.Flags().StringP("str", "s", "", "") + boolVal := rootCmd.Flags().BoolP("bool", "b", false, "") + + childCmd := &Command{Use: "child"} + childCmd.Flags().Int("int", -1, "") + + rootCmd.AddCommand(childCmd) + + c, args, err := rootCmd.Traverse([]string{"-bs", "child", "child", "--int"}) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if len(args) != 1 || args[0] != "--int" { + t.Errorf("Wrong args: %v", args) + } + if c.Name() != childCmd.Name() { + t.Errorf("Expected command: %q, got: %q", childCmd.Name(), c.Name()) + } + if *stringVal != "child" { + t.Errorf("Expected -s to be set to: %s, got: %s", "child", *stringVal) + } + if !*boolVal { + t.Errorf("Expected -b to be set") + } +} + func TestTraverseNoParentFlags(t *testing.T) { rootCmd := &Command{Use: "root", TraverseChildren: true} rootCmd.Flags().String("foo", "", "foo things") @@ -2312,7 +2368,7 @@ func TestTraverseWithBadChildFlag(t *testing.T) { if err != nil { t.Errorf("Unexpected error: %v", err) } - if len(args) != 1 && args[0] != "--str" { + if len(args) != 1 || args[0] != "--str" { t.Errorf("Wrong args: %v", args) } if c.Name() != childCmd.Name() { From cc1f750da2cf0c534443d65e5386a87f12cbdeb3 Mon Sep 17 00:00:00 2001 From: Ionut Nicula <nicula.iccc@gmail.com> Date: Thu, 19 Sep 2024 19:51:37 +0300 Subject: [PATCH 3/4] Simplify code Use a helper function for both code paths instead of duplicating the logic. --- command.go | 83 ++++++++++++++++++++---------------------------------- 1 file changed, 30 insertions(+), 53 deletions(-) diff --git a/command.go b/command.go index 1d2df277..2d4d4bdb 100644 --- a/command.go +++ b/command.go @@ -643,6 +643,27 @@ func shortHasNoOptDefVal(name string, fs *flag.FlagSet) bool { return flag.NoOptDefVal != "" } +func shorthandCombinationNeedsNextArg(combination string, flags *flag.FlagSet) bool { + lastPos := len(combination) - 1 + for i, shorthand := range combination { + if !shortHasNoOptDefVal(string(shorthand), flags) { + // This shorthand needs a value. + // + // If we're at the end of the shorthand combination, this means that the + // value for the shorthand is given in the next argument. (e.g. '-xyzf arg', + // where -x, -y, -z are boolean flags, and -f is a flag that needs a value). + // + // Otherwise, if the shorthand combination doesn't end here, this means that the value + // for the shorthand is given in the same argument, meaning we don't have to consume the + // next one. (e.g. '-xyzfarg', where -x, -y, -z are boolean flags, and -f is a flag that + // needs a value). + return i == lastPos + } + } + + return false +} + func stripFlags(args []string, c *Command) ([]string, []string) { if len(args) == 0 { return args, nil @@ -675,36 +696,14 @@ Loop: args = args[1:] } case strings.HasPrefix(s, "-") && !strings.HasPrefix(s, "--") && !strings.Contains(s, "=") && len(s) > 2: - shorthandCombination := s[1:] // Skip the leading "-" - lastPos := len(shorthandCombination) - 1 - for i, shorthand := range shorthandCombination { - if shortHasNoOptDefVal(string(shorthand), flags) { - continue - } - - // We found a shorthand that needs a value. - - if i == lastPos { - // Since we're at the end of the shorthand combination, this means that the - // value for the shorthand is given in the next argument. (e.g. '-xyzf arg', - // where -x, -y, -z are boolean flags, and -f is a flag that needs a value). - - // The whole combination will take a value. - flagsThatConsumeNextArg = append(flagsThatConsumeNextArg, s) - - if len(args) <= 1 { - break Loop - } else { - args = args[1:] - } + shorthandCombination := s[1:] // Skip leading "-" + if shorthandCombinationNeedsNextArg(shorthandCombination, flags) { + flagsThatConsumeNextArg = append(flagsThatConsumeNextArg, s) + if len(args) <= 1 { + break Loop } else { - // Since the shorthand combination doesn't end here, this means that the - // value for the shorthand is given in the same argument, meaning we don't - // have to consume the next one. (e.g. '-xyzfarg', where -x, -y, -z are - // boolean flags, and -f is a flag that needs a value). + args = args[1:] } - - break } case s != "" && !strings.HasPrefix(s, "-"): commands = append(commands, s) @@ -848,38 +847,16 @@ func (c *Command) Traverse(args []string) (*Command, []string, error) { inFlag = false flags = append(flags, arg) continue - // A flag without a value, or with an `=` separated value + // A flag with an `=` separated value, or a shorthand combination, possibly with a value case isFlagArg(arg): flags = append(flags, arg) - if !strings.HasPrefix(arg, "-") || strings.HasPrefix(arg, "--") || strings.Contains(arg, "=") || len(arg) <= 2 { + if strings.HasPrefix(arg, "--") || strings.Contains(arg, "=") || len(arg) <= 2 { continue // Not a shorthand combination, so nothing more to do. } shorthandCombination := arg[1:] // Skip leading "-" - lastPos := len(shorthandCombination) - 1 - for i, shorthand := range shorthandCombination { - if shortHasNoOptDefVal(string(shorthand), c.Flags()) { - continue - } - - // We found a shorthand that needs a value. - - if i == lastPos { - // Since we're at the end of the shorthand combination, this means that the - // value for the shorthand is given in the next argument. (e.g. '-xyzf arg', - // where -x, -y, -z are boolean flags, and -f is a flag that needs a value). - inFlag = true - } else { - // Since the shorthand combination doesn't end here, this means that the - // value for the shorthand is given in the same argument, meaning we don't - // have to consume the next one. (e.g. '-xyzfarg', where -x, -y, -z are - // boolean flags, and -f is a flag that needs a value). - } - - break - } - + inFlag = shorthandCombinationNeedsNextArg(shorthandCombination, c.Flags()) continue } From e8dfaa897c794304098c620ff53130b3f68d28ea Mon Sep 17 00:00:00 2001 From: Ionut Nicula <nicula.iccc@gmail.com> Date: Sat, 12 Oct 2024 12:23:52 +0300 Subject: [PATCH 4/4] Fix linter errors --- command_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/command_test.go b/command_test.go index 9fb0bdfc..f82c8873 100644 --- a/command_test.go +++ b/command_test.go @@ -2239,6 +2239,7 @@ func TestUseDeprecatedFlags(t *testing.T) { checkStringContains(t, output, "This flag is deprecated") } +//nolint:goconst,nolintlint // Disable check for string literal occurrences func TestTraverseWithParentFlags(t *testing.T) { rootCmd := &Command{Use: "root", TraverseChildren: true} rootCmd.Flags().String("str", "", "") @@ -2261,6 +2262,7 @@ func TestTraverseWithParentFlags(t *testing.T) { } } +//nolint:goconst,nolintlint // Disable check for string literal occurrences func TestTraverseWithShorthandCombinationInParentFlags(t *testing.T) { rootCmd := &Command{Use: "root", TraverseChildren: true} stringVal := rootCmd.Flags().StringP("str", "s", "", "") @@ -2289,6 +2291,7 @@ func TestTraverseWithShorthandCombinationInParentFlags(t *testing.T) { } } +//nolint:goconst,nolintlint // Disable check for string literal occurrences func TestTraverseWithArgumentIdenticalToCommandName(t *testing.T) { rootCmd := &Command{Use: "root", TraverseChildren: true} stringVal := rootCmd.Flags().StringP("str", "s", "", "")