From 8ba575e6947ac701ca0acbcb12a8f468d298d411 Mon Sep 17 00:00:00 2001
From: Ionut Nicula <nicula.iccc@gmail.com>
Date: Sat, 14 Sep 2024 17:08:27 +0300
Subject: [PATCH 1/4] Fix shorthand combination edge case in c.Find() code path

Fixes: #2188
---
 command.go      | 68 +++++++++++++++++++++++++++++++++++++------------
 command_test.go | 60 ++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 111 insertions(+), 17 deletions(-)

diff --git a/command.go b/command.go
index 2df6975f..cc20afed 100644
--- a/command.go
+++ b/command.go
@@ -643,13 +643,14 @@ func shortHasNoOptDefVal(name string, fs *flag.FlagSet) bool {
 	return flag.NoOptDefVal != ""
 }
 
-func stripFlags(args []string, c *Command) []string {
+func stripFlags(args []string, c *Command) ([]string, []string) {
 	if len(args) == 0 {
-		return args
+		return args, nil
 	}
 	c.mergePersistentFlags()
 
 	commands := []string{}
+	flagsThatConsumeNextArg := []string{} // We use this to avoid repeating the same lengthy logic for parsing shorthand combinations in argsMinusFirstX
 	flags := c.Flags()
 
 Loop:
@@ -665,31 +666,70 @@ Loop:
 			// delete arg from args.
 			fallthrough // (do the same as below)
 		case strings.HasPrefix(s, "-") && !strings.Contains(s, "=") && len(s) == 2 && !shortHasNoOptDefVal(s[1:], flags):
+			flagsThatConsumeNextArg = append(flagsThatConsumeNextArg, s)
 			// If '-f arg' then
 			// delete 'arg' from args or break the loop if len(args) <= 1.
 			if len(args) <= 1 {
 				break Loop
 			} else {
 				args = args[1:]
-				continue
+			}
+		case strings.HasPrefix(s, "-") && !strings.HasPrefix(s, "--") && !strings.Contains(s, "=") && len(s) > 2:
+			shorthandCombination := s[1:] // Skip the leading "-"
+			lastPos := len(shorthandCombination) - 1
+			for i, shorthand := range shorthandCombination {
+				if shortHasNoOptDefVal(string(shorthand), flags) {
+					continue
+				}
+
+				// We found a shorthand that needs a value.
+
+				if i == lastPos {
+					// Since we're at the end of the shorthand combination, this means that the
+					// value for the shorthand is given in the next argument. (e.g. '-xyzf arg',
+					// where -x, -y, -z are boolean flags, and -f is a flag that needs a value).
+
+					// The whole combination will take a value.
+					flagsThatConsumeNextArg = append(flagsThatConsumeNextArg, s)
+
+					if len(args) <= 1 {
+						break Loop
+					} else {
+						args = args[1:]
+					}
+				} else {
+					// Since the shorthand combination doesn't end here, this means that the
+					// value for the shorthand is given in the same argument, meaning we don't
+					// have to consume the next one. (e.g. '-xyzfarg', where -x, -y, -z are
+					// boolean flags, and -f is a flag that needs a value).
+				}
+
+				break
 			}
 		case s != "" && !strings.HasPrefix(s, "-"):
 			commands = append(commands, s)
 		}
 	}
 
-	return commands
+	return commands, flagsThatConsumeNextArg
 }
 
 // argsMinusFirstX removes only the first x from args.  Otherwise, commands that look like
 // openshift admin policy add-role-to-user admin my-user, lose the admin argument (arg[4]).
 // Special care needs to be taken not to remove a flag value.
-func (c *Command) argsMinusFirstX(args []string, x string) []string {
+func (c *Command) argsMinusFirstX(args, flagsThatConsumeNextArg []string, x string) []string {
 	if len(args) == 0 {
 		return args
 	}
-	c.mergePersistentFlags()
-	flags := c.Flags()
+
+	consumesNextArg := func(flag string) bool {
+		for _, f := range flagsThatConsumeNextArg {
+			if flag == f {
+				return true
+			}
+		}
+		return false
+	}
 
 Loop:
 	for pos := 0; pos < len(args); pos++ {
@@ -698,13 +738,8 @@ Loop:
 		case s == "--":
 			// -- means we have reached the end of the parseable args. Break out of the loop now.
 			break Loop
-		case strings.HasPrefix(s, "--") && !strings.Contains(s, "=") && !hasNoOptDefVal(s[2:], flags):
-			fallthrough
-		case strings.HasPrefix(s, "-") && !strings.Contains(s, "=") && len(s) == 2 && !shortHasNoOptDefVal(s[1:], flags):
-			// This is a flag without a default value, and an equal sign is not used. Increment pos in order to skip
-			// over the next arg, because that is the value of this flag.
+		case consumesNextArg(s):
 			pos++
-			continue
 		case !strings.HasPrefix(s, "-"):
 			// This is not a flag or a flag value. Check to see if it matches what we're looking for, and if so,
 			// return the args, excluding the one at this position.
@@ -730,7 +765,7 @@ func (c *Command) Find(args []string) (*Command, []string, error) {
 	var innerfind func(*Command, []string) (*Command, []string)
 
 	innerfind = func(c *Command, innerArgs []string) (*Command, []string) {
-		argsWOflags := stripFlags(innerArgs, c)
+		argsWOflags, flagsThatConsumeNextArg := stripFlags(innerArgs, c)
 		if len(argsWOflags) == 0 {
 			return c, innerArgs
 		}
@@ -738,14 +773,15 @@ func (c *Command) Find(args []string) (*Command, []string, error) {
 
 		cmd := c.findNext(nextSubCmd)
 		if cmd != nil {
-			return innerfind(cmd, c.argsMinusFirstX(innerArgs, nextSubCmd))
+			return innerfind(cmd, c.argsMinusFirstX(innerArgs, flagsThatConsumeNextArg, nextSubCmd))
 		}
 		return c, innerArgs
 	}
 
 	commandFound, a := innerfind(c, args)
 	if commandFound.Args == nil {
-		return commandFound, a, legacyArgs(commandFound, stripFlags(a, commandFound))
+		argsWOflags, _ := stripFlags(a, commandFound)
+		return commandFound, a, legacyArgs(commandFound, argsWOflags)
 	}
 	return commandFound, a, nil
 }
diff --git a/command_test.go b/command_test.go
index 9ce7a529..2f0f4b16 100644
--- a/command_test.go
+++ b/command_test.go
@@ -693,6 +693,30 @@ func TestStripFlags(t *testing.T) {
 			[]string{"-p", "bar"},
 			[]string{"bar"},
 		},
+		{
+			[]string{"-s", "value", "bar"},
+			[]string{"bar"},
+		},
+		{
+			[]string{"-s=value", "bar"},
+			[]string{"bar"},
+		},
+		{
+			[]string{"-svalue", "bar"},
+			[]string{"bar"},
+		},
+		{
+			[]string{"-ps", "value", "bar"},
+			[]string{"bar"},
+		},
+		{
+			[]string{"-ps=value", "bar"},
+			[]string{"bar"},
+		},
+		{
+			[]string{"-psvalue", "bar"},
+			[]string{"bar"},
+		},
 	}
 
 	c := &Command{Use: "c", Run: emptyRun}
@@ -702,7 +726,7 @@ func TestStripFlags(t *testing.T) {
 	c.Flags().BoolP("bool", "b", false, "")
 
 	for i, test := range tests {
-		got := stripFlags(test.input, c)
+		got, _ := stripFlags(test.input, c)
 		if !reflect.DeepEqual(test.output, got) {
 			t.Errorf("(%v) Expected: %v, got: %v", i, test.output, got)
 		}
@@ -2688,11 +2712,13 @@ func TestHelpflagCommandExecutedWithoutVersionSet(t *testing.T) {
 
 func TestFind(t *testing.T) {
 	var foo, bar string
+	var persist bool
 	root := &Command{
 		Use: "root",
 	}
 	root.PersistentFlags().StringVarP(&foo, "foo", "f", "", "")
 	root.PersistentFlags().StringVarP(&bar, "bar", "b", "something", "")
+	root.PersistentFlags().BoolVarP(&persist, "persist", "p", false, "")
 
 	child := &Command{
 		Use: "child",
@@ -2755,6 +2781,38 @@ func TestFind(t *testing.T) {
 			[]string{"--foo", "child", "--bar", "something", "child"},
 			[]string{"--foo", "child", "--bar", "something"},
 		},
+		{
+			[]string{"-f", "value", "child"},
+			[]string{"-f", "value"},
+		},
+		{
+			[]string{"-f=value", "child"},
+			[]string{"-f=value"},
+		},
+		{
+			[]string{"-fvalue", "child"},
+			[]string{"-fvalue"},
+		},
+		{
+			[]string{"-pf", "value", "child"},
+			[]string{"-pf", "value"},
+		},
+		{
+			[]string{"-pf=value", "child"},
+			[]string{"-pf=value"},
+		},
+		{
+			[]string{"-pfvalue", "child"},
+			[]string{"-pfvalue"},
+		},
+		{
+			[]string{"-pf", "child", "child"},
+			[]string{"-pf", "child"},
+		},
+		{
+			[]string{"-pf", "child", "-pb", "something", "child"},
+			[]string{"-pf", "child", "-pb", "something"},
+		},
 	}
 
 	for _, tc := range testCases {

From b07c5cdd6f0b505dd2c9d6ab27b5db6e09929be2 Mon Sep 17 00:00:00 2001
From: Ionut Nicula <nicula.iccc@gmail.com>
Date: Sat, 14 Sep 2024 20:31:57 +0300
Subject: [PATCH 2/4] Fix shorthand combination edge case in c.Traverse() code
 path

---
 command.go      | 29 ++++++++++++++++++++++++
 command_test.go | 60 +++++++++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 87 insertions(+), 2 deletions(-)

diff --git a/command.go b/command.go
index cc20afed..1d2df277 100644
--- a/command.go
+++ b/command.go
@@ -851,6 +851,35 @@ func (c *Command) Traverse(args []string) (*Command, []string, error) {
 		// A flag without a value, or with an `=` separated value
 		case isFlagArg(arg):
 			flags = append(flags, arg)
+
+			if !strings.HasPrefix(arg, "-") || strings.HasPrefix(arg, "--") || strings.Contains(arg, "=") || len(arg) <= 2 {
+				continue // Not a shorthand combination, so nothing more to do.
+			}
+
+			shorthandCombination := arg[1:] // Skip leading "-"
+			lastPos := len(shorthandCombination) - 1
+			for i, shorthand := range shorthandCombination {
+				if shortHasNoOptDefVal(string(shorthand), c.Flags()) {
+					continue
+				}
+
+				// We found a shorthand that needs a value.
+
+				if i == lastPos {
+					// Since we're at the end of the shorthand combination, this means that the
+					// value for the shorthand is given in the next argument. (e.g. '-xyzf arg',
+					// where -x, -y, -z are boolean flags, and -f is a flag that needs a value).
+					inFlag = true
+				} else {
+					// Since the shorthand combination doesn't end here, this means that the
+					// value for the shorthand is given in the same argument, meaning we don't
+					// have to consume the next one. (e.g. '-xyzfarg', where -x, -y, -z are
+					// boolean flags, and -f is a flag that needs a value).
+				}
+
+				break
+			}
+
 			continue
 		}
 
diff --git a/command_test.go b/command_test.go
index 2f0f4b16..9fb0bdfc 100644
--- a/command_test.go
+++ b/command_test.go
@@ -2253,7 +2253,7 @@ func TestTraverseWithParentFlags(t *testing.T) {
 	if err != nil {
 		t.Errorf("Unexpected error: %v", err)
 	}
-	if len(args) != 1 && args[0] != "--add" {
+	if len(args) != 1 || args[0] != "--int" {
 		t.Errorf("Wrong args: %v", args)
 	}
 	if c.Name() != childCmd.Name() {
@@ -2261,6 +2261,62 @@ func TestTraverseWithParentFlags(t *testing.T) {
 	}
 }
 
+func TestTraverseWithShorthandCombinationInParentFlags(t *testing.T) {
+	rootCmd := &Command{Use: "root", TraverseChildren: true}
+	stringVal := rootCmd.Flags().StringP("str", "s", "", "")
+	boolVal := rootCmd.Flags().BoolP("bool", "b", false, "")
+
+	childCmd := &Command{Use: "child"}
+	childCmd.Flags().Int("int", -1, "")
+
+	rootCmd.AddCommand(childCmd)
+
+	c, args, err := rootCmd.Traverse([]string{"-bs", "ok", "child", "--int"})
+	if err != nil {
+		t.Errorf("Unexpected error: %v", err)
+	}
+	if len(args) != 1 || args[0] != "--int" {
+		t.Errorf("Wrong args: %v", args)
+	}
+	if c.Name() != childCmd.Name() {
+		t.Errorf("Expected command: %q, got: %q", childCmd.Name(), c.Name())
+	}
+	if *stringVal != "ok" {
+		t.Errorf("Expected -s to be set to: %s, got: %s", "ok", *stringVal)
+	}
+	if !*boolVal {
+		t.Errorf("Expected -b to be set")
+	}
+}
+
+func TestTraverseWithArgumentIdenticalToCommandName(t *testing.T) {
+	rootCmd := &Command{Use: "root", TraverseChildren: true}
+	stringVal := rootCmd.Flags().StringP("str", "s", "", "")
+	boolVal := rootCmd.Flags().BoolP("bool", "b", false, "")
+
+	childCmd := &Command{Use: "child"}
+	childCmd.Flags().Int("int", -1, "")
+
+	rootCmd.AddCommand(childCmd)
+
+	c, args, err := rootCmd.Traverse([]string{"-bs", "child", "child", "--int"})
+	if err != nil {
+		t.Errorf("Unexpected error: %v", err)
+	}
+	if len(args) != 1 || args[0] != "--int" {
+		t.Errorf("Wrong args: %v", args)
+	}
+	if c.Name() != childCmd.Name() {
+		t.Errorf("Expected command: %q, got: %q", childCmd.Name(), c.Name())
+	}
+	if *stringVal != "child" {
+		t.Errorf("Expected -s to be set to: %s, got: %s", "child", *stringVal)
+	}
+	if !*boolVal {
+		t.Errorf("Expected -b to be set")
+	}
+}
+
 func TestTraverseNoParentFlags(t *testing.T) {
 	rootCmd := &Command{Use: "root", TraverseChildren: true}
 	rootCmd.Flags().String("foo", "", "foo things")
@@ -2312,7 +2368,7 @@ func TestTraverseWithBadChildFlag(t *testing.T) {
 	if err != nil {
 		t.Errorf("Unexpected error: %v", err)
 	}
-	if len(args) != 1 && args[0] != "--str" {
+	if len(args) != 1 || args[0] != "--str" {
 		t.Errorf("Wrong args: %v", args)
 	}
 	if c.Name() != childCmd.Name() {

From cc1f750da2cf0c534443d65e5386a87f12cbdeb3 Mon Sep 17 00:00:00 2001
From: Ionut Nicula <nicula.iccc@gmail.com>
Date: Thu, 19 Sep 2024 19:51:37 +0300
Subject: [PATCH 3/4] Simplify code

Use a helper function for both code paths instead of duplicating the
logic.
---
 command.go | 83 ++++++++++++++++++++----------------------------------
 1 file changed, 30 insertions(+), 53 deletions(-)

diff --git a/command.go b/command.go
index 1d2df277..2d4d4bdb 100644
--- a/command.go
+++ b/command.go
@@ -643,6 +643,27 @@ func shortHasNoOptDefVal(name string, fs *flag.FlagSet) bool {
 	return flag.NoOptDefVal != ""
 }
 
+func shorthandCombinationNeedsNextArg(combination string, flags *flag.FlagSet) bool {
+	lastPos := len(combination) - 1
+	for i, shorthand := range combination {
+		if !shortHasNoOptDefVal(string(shorthand), flags) {
+			// This shorthand needs a value.
+			//
+			// If we're at the end of the shorthand combination, this means that the
+			// value for the shorthand is given in the next argument. (e.g. '-xyzf arg',
+			// where -x, -y, -z are boolean flags, and -f is a flag that needs a value).
+			//
+			// Otherwise, if the shorthand combination doesn't end here, this means that the value
+			// for the shorthand is given in the same argument, meaning we don't have to consume the
+			// next one. (e.g. '-xyzfarg', where -x, -y, -z are boolean flags, and -f is a flag that
+			// needs a value).
+			return i == lastPos
+		}
+	}
+
+	return false
+}
+
 func stripFlags(args []string, c *Command) ([]string, []string) {
 	if len(args) == 0 {
 		return args, nil
@@ -675,36 +696,14 @@ Loop:
 				args = args[1:]
 			}
 		case strings.HasPrefix(s, "-") && !strings.HasPrefix(s, "--") && !strings.Contains(s, "=") && len(s) > 2:
-			shorthandCombination := s[1:] // Skip the leading "-"
-			lastPos := len(shorthandCombination) - 1
-			for i, shorthand := range shorthandCombination {
-				if shortHasNoOptDefVal(string(shorthand), flags) {
-					continue
-				}
-
-				// We found a shorthand that needs a value.
-
-				if i == lastPos {
-					// Since we're at the end of the shorthand combination, this means that the
-					// value for the shorthand is given in the next argument. (e.g. '-xyzf arg',
-					// where -x, -y, -z are boolean flags, and -f is a flag that needs a value).
-
-					// The whole combination will take a value.
-					flagsThatConsumeNextArg = append(flagsThatConsumeNextArg, s)
-
-					if len(args) <= 1 {
-						break Loop
-					} else {
-						args = args[1:]
-					}
+			shorthandCombination := s[1:] // Skip leading "-"
+			if shorthandCombinationNeedsNextArg(shorthandCombination, flags) {
+				flagsThatConsumeNextArg = append(flagsThatConsumeNextArg, s)
+				if len(args) <= 1 {
+					break Loop
 				} else {
-					// Since the shorthand combination doesn't end here, this means that the
-					// value for the shorthand is given in the same argument, meaning we don't
-					// have to consume the next one. (e.g. '-xyzfarg', where -x, -y, -z are
-					// boolean flags, and -f is a flag that needs a value).
+					args = args[1:]
 				}
-
-				break
 			}
 		case s != "" && !strings.HasPrefix(s, "-"):
 			commands = append(commands, s)
@@ -848,38 +847,16 @@ func (c *Command) Traverse(args []string) (*Command, []string, error) {
 			inFlag = false
 			flags = append(flags, arg)
 			continue
-		// A flag without a value, or with an `=` separated value
+		// A flag with an `=` separated value, or a shorthand combination, possibly with a value
 		case isFlagArg(arg):
 			flags = append(flags, arg)
 
-			if !strings.HasPrefix(arg, "-") || strings.HasPrefix(arg, "--") || strings.Contains(arg, "=") || len(arg) <= 2 {
+			if strings.HasPrefix(arg, "--") || strings.Contains(arg, "=") || len(arg) <= 2 {
 				continue // Not a shorthand combination, so nothing more to do.
 			}
 
 			shorthandCombination := arg[1:] // Skip leading "-"
-			lastPos := len(shorthandCombination) - 1
-			for i, shorthand := range shorthandCombination {
-				if shortHasNoOptDefVal(string(shorthand), c.Flags()) {
-					continue
-				}
-
-				// We found a shorthand that needs a value.
-
-				if i == lastPos {
-					// Since we're at the end of the shorthand combination, this means that the
-					// value for the shorthand is given in the next argument. (e.g. '-xyzf arg',
-					// where -x, -y, -z are boolean flags, and -f is a flag that needs a value).
-					inFlag = true
-				} else {
-					// Since the shorthand combination doesn't end here, this means that the
-					// value for the shorthand is given in the same argument, meaning we don't
-					// have to consume the next one. (e.g. '-xyzfarg', where -x, -y, -z are
-					// boolean flags, and -f is a flag that needs a value).
-				}
-
-				break
-			}
-
+			inFlag = shorthandCombinationNeedsNextArg(shorthandCombination, c.Flags())
 			continue
 		}
 

From e8dfaa897c794304098c620ff53130b3f68d28ea Mon Sep 17 00:00:00 2001
From: Ionut Nicula <nicula.iccc@gmail.com>
Date: Sat, 12 Oct 2024 12:23:52 +0300
Subject: [PATCH 4/4] Fix linter errors

---
 command_test.go | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/command_test.go b/command_test.go
index 9fb0bdfc..f82c8873 100644
--- a/command_test.go
+++ b/command_test.go
@@ -2239,6 +2239,7 @@ func TestUseDeprecatedFlags(t *testing.T) {
 	checkStringContains(t, output, "This flag is deprecated")
 }
 
+//nolint:goconst,nolintlint // Disable check for string literal occurrences
 func TestTraverseWithParentFlags(t *testing.T) {
 	rootCmd := &Command{Use: "root", TraverseChildren: true}
 	rootCmd.Flags().String("str", "", "")
@@ -2261,6 +2262,7 @@ func TestTraverseWithParentFlags(t *testing.T) {
 	}
 }
 
+//nolint:goconst,nolintlint // Disable check for string literal occurrences
 func TestTraverseWithShorthandCombinationInParentFlags(t *testing.T) {
 	rootCmd := &Command{Use: "root", TraverseChildren: true}
 	stringVal := rootCmd.Flags().StringP("str", "s", "", "")
@@ -2289,6 +2291,7 @@ func TestTraverseWithShorthandCombinationInParentFlags(t *testing.T) {
 	}
 }
 
+//nolint:goconst,nolintlint // Disable check for string literal occurrences
 func TestTraverseWithArgumentIdenticalToCommandName(t *testing.T) {
 	rootCmd := &Command{Use: "root", TraverseChildren: true}
 	stringVal := rootCmd.Flags().StringP("str", "s", "", "")