From b19c5d4ab1d1b912c0a7be1995871b37db84355b Mon Sep 17 00:00:00 2001
From: Sean Pino <seanpino@gmail.com>
Date: Tue, 20 Feb 2024 10:38:54 -0500
Subject: [PATCH 1/3] Multiple Required Flags

---
 completions_test.go        | 34 +++++++++++++++++++++++++++-------
 shell_completions.go       | 26 ++++++++++++++++++++++++++
 site/content/user_guide.md | 18 ++++++++++++++++++
 3 files changed, 71 insertions(+), 7 deletions(-)

diff --git a/completions_test.go b/completions_test.go
index df153fcf..28cdcdf5 100644
--- a/completions_test.go
+++ b/completions_test.go
@@ -828,13 +828,17 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
 	requiredFlag := rootCmd.Flags().Lookup("requiredFlag")
 
 	rootCmd.PersistentFlags().IntP("requiredPersistent", "p", -1, "required persistent")
-	assertNoErr(t, rootCmd.MarkPersistentFlagRequired("requiredPersistent"))
+	rootCmd.PersistentFlags().Float64P("requiredPersistentFloat", "f", -1, "required persistent float")
+	assertNoErr(t, rootCmd.MarkPersistentFlagsRequired("requiredPersistent", "requiredPersistentFloat"))
 	requiredPersistent := rootCmd.PersistentFlags().Lookup("requiredPersistent")
+	requiredPersistentFloat := rootCmd.PersistentFlags().Lookup("requiredPersistentFloat")
 
 	rootCmd.Flags().StringP("release", "R", "", "Release name")
 
-	childCmd.Flags().BoolP("subRequired", "s", false, "sub required flag")
-	assertNoErr(t, childCmd.MarkFlagRequired("subRequired"))
+	childCmd.Flags().BoolP("subRequiredOne", "s", false, "first sub required flag")
+	childCmd.Flags().BoolP("subRequiredTwo", "z", false, "second sub required flag")
+	assertNoErr(t, childCmd.MarkFlagsRequired("subRequiredOne", "subRequiredTwo"))
+
 	childCmd.Flags().BoolP("subNotRequired", "n", false, "sub not required flag")
 
 	// Test that a required flag is suggested even without the - prefix
@@ -851,6 +855,8 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
 		"-r",
 		"--requiredPersistent",
 		"-p",
+		"--requiredPersistentFloat",
+		"-f",
 		"realArg",
 		":4",
 		"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")
@@ -870,6 +876,8 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
 		"-r",
 		"--requiredPersistent",
 		"-p",
+		"--requiredPersistentFloat",
+		"-f",
 		":4",
 		"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")
 
@@ -901,8 +909,12 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
 	expected = strings.Join([]string{
 		"--requiredPersistent",
 		"-p",
-		"--subRequired",
+		"--requiredPersistentFloat",
+		"-f",
+		"--subRequiredOne",
 		"-s",
+		"--subRequiredTwo",
+		"-z",
 		"subArg",
 		":4",
 		"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")
@@ -919,8 +931,12 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
 	expected = strings.Join([]string{
 		"--requiredPersistent",
 		"-p",
-		"--subRequired",
+		"--requiredPersistentFloat",
+		"-f",
+		"--subRequiredOne",
 		"-s",
+		"--subRequiredTwo",
+		"-z",
 		":4",
 		"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")
 
@@ -953,6 +969,8 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
 	expected = strings.Join([]string{
 		"--requiredPersistent",
 		"-p",
+		"--requiredPersistentFloat",
+		"-f",
 		"realArg",
 		":4",
 		"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")
@@ -962,12 +980,13 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
 	}
 
 	// Test that when a persistent required flag is present, it is not suggested anymore
-	output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--requiredPersistent", "1", "")
+	output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--requiredPersistent", "1", "--requiredPersistentFloat", "1.0", "")
 	if err != nil {
 		t.Errorf("Unexpected error: %v", err)
 	}
 	// Reset the flag for the next command
 	requiredPersistent.Changed = false
+	requiredPersistentFloat.Changed = false
 
 	expected = strings.Join([]string{
 		"childCmd",
@@ -984,13 +1003,14 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
 	}
 
 	// Test that when all required flags are present, normal completion is done
-	output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--requiredFlag", "1", "--requiredPersistent", "1", "")
+	output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--requiredFlag", "1", "--requiredPersistent", "1", "--requiredPersistentFloat", "1.0", "")
 	if err != nil {
 		t.Errorf("Unexpected error: %v", err)
 	}
 	// Reset the flags for the next command
 	requiredFlag.Changed = false
 	requiredPersistent.Changed = false
+	requiredPersistentFloat.Changed = false
 
 	expected = strings.Join([]string{
 		"realArg",
diff --git a/shell_completions.go b/shell_completions.go
index b035742d..594c1060 100644
--- a/shell_completions.go
+++ b/shell_completions.go
@@ -25,6 +25,13 @@ func (c *Command) MarkFlagRequired(name string) error {
 	return MarkFlagRequired(c.Flags(), name)
 }
 
+// MarkFlagsRequired instructs the various shell completion implementations to
+// prioritize the named flags when performing completion,
+// and causes your command to report an error if invoked without any of the flags.
+func (c *Command) MarkFlagsRequired(names ...string) error {
+	return MarkFlagsRequired(c.Flags(), names...)
+}
+
 // MarkPersistentFlagRequired instructs the various shell completion implementations to
 // prioritize the named persistent flag when performing completion,
 // and causes your command to report an error if invoked without the flag.
@@ -32,6 +39,13 @@ func (c *Command) MarkPersistentFlagRequired(name string) error {
 	return MarkFlagRequired(c.PersistentFlags(), name)
 }
 
+// MarkPersistentFlagsRequired instructs the various shell completion implementations to
+// prioritize the named persistent flags when performing completion,
+// and causes your command to report an error if invoked without any of the flags.
+func (c *Command) MarkPersistentFlagsRequired(names ...string) error {
+	return MarkFlagsRequired(c.PersistentFlags(), names...)
+}
+
 // MarkFlagRequired instructs the various shell completion implementations to
 // prioritize the named flag when performing completion,
 // and causes your command to report an error if invoked without the flag.
@@ -39,6 +53,18 @@ func MarkFlagRequired(flags *pflag.FlagSet, name string) error {
 	return flags.SetAnnotation(name, BashCompOneRequiredFlag, []string{"true"})
 }
 
+// MarkFlagsRequired instructs the various shell completion implementations to
+// prioritize the named flags when performing completion,
+// and causes your command to report an error if invoked without any of the flags.
+func MarkFlagsRequired(flags *pflag.FlagSet, names ...string) error {
+	for _, name := range names {
+		if err := flags.SetAnnotation(name, BashCompOneRequiredFlag, []string{"true"}); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 // MarkFlagFilename instructs the various shell completion implementations to
 // limit completions for the named flag to the specified file extensions.
 func (c *Command) MarkFlagFilename(name string, extensions ...string) error {
diff --git a/site/content/user_guide.md b/site/content/user_guide.md
index 3b42ef04..1b7a58d2 100644
--- a/site/content/user_guide.md
+++ b/site/content/user_guide.md
@@ -331,6 +331,24 @@ rootCmd.PersistentFlags().StringVarP(&Region, "region", "r", "", "AWS region (re
 rootCmd.MarkPersistentFlagRequired("region")
 ```
 
+### Multiple Required flags
+
+If your command has multiple required flags that are not [grouped](#flag-groups) to report an error
+when one or more flags have not been set, mark them as required:
+```go
+rootCmd.Flags().StringVarP(&Region, "region", "r", "", "AWS region (required)")
+rootCmd.Flags().StringVarP(&Region, "failoverRegion", "f", "", "AWS failover region (required)")
+rootCmd.MarkFlagsRequired("region", "failoverRegion")
+```
+
+Or, for multiple persistent flags:
+```go
+rootCmd.PersistentFlags().StringVarP(&Region, "region", "r", "", "AWS region (required)")
+rootCmd.PersistentFlags().StringVarP(&Region, "failoverRegion", "f", "", "AWS failover region (required)")
+rootCmd.MarkPersistentFlagsRequired("region", "failoverRegion")
+```
+
+
 ### Flag Groups
 
 If you have different flags that must be provided together (e.g. if they provide the `--username` flag they MUST provide the `--password` flag as well) then

From 0967a36905aa53f6c2ea6ac92481f4d5a0666344 Mon Sep 17 00:00:00 2001
From: Sean Pino <seanpino@gmail.com>
Date: Tue, 20 Feb 2024 10:54:19 -0500
Subject: [PATCH 2/3] simplify mark flags

---
 shell_completions.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/shell_completions.go b/shell_completions.go
index 594c1060..376c60d7 100644
--- a/shell_completions.go
+++ b/shell_completions.go
@@ -58,7 +58,7 @@ func MarkFlagRequired(flags *pflag.FlagSet, name string) error {
 // and causes your command to report an error if invoked without any of the flags.
 func MarkFlagsRequired(flags *pflag.FlagSet, names ...string) error {
 	for _, name := range names {
-		if err := flags.SetAnnotation(name, BashCompOneRequiredFlag, []string{"true"}); err != nil {
+		if err := MarkFlagRequired(flags, name); err != nil {
 			return err
 		}
 	}

From 6390763171dfbbf6fd99fb96c7f089840ae8555f Mon Sep 17 00:00:00 2001
From: Sean Pino <seanpino@gmail.com>
Date: Tue, 20 Feb 2024 13:05:52 -0500
Subject: [PATCH 3/3] Fixed typo in user guide

---
 site/content/user_guide.md | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/site/content/user_guide.md b/site/content/user_guide.md
index 1b7a58d2..a1cf00a1 100644
--- a/site/content/user_guide.md
+++ b/site/content/user_guide.md
@@ -337,15 +337,15 @@ If your command has multiple required flags that are not [grouped](#flag-groups)
 when one or more flags have not been set, mark them as required:
 ```go
 rootCmd.Flags().StringVarP(&Region, "region", "r", "", "AWS region (required)")
-rootCmd.Flags().StringVarP(&Region, "failoverRegion", "f", "", "AWS failover region (required)")
-rootCmd.MarkFlagsRequired("region", "failoverRegion")
+rootCmd.Flags().StringVarP(&Failover, "failover", "f", "", "AWS failover region (required)")
+rootCmd.MarkFlagsRequired("region", "failover")
 ```
 
 Or, for multiple persistent flags:
 ```go
 rootCmd.PersistentFlags().StringVarP(&Region, "region", "r", "", "AWS region (required)")
-rootCmd.PersistentFlags().StringVarP(&Region, "failoverRegion", "f", "", "AWS failover region (required)")
-rootCmd.MarkPersistentFlagsRequired("region", "failoverRegion")
+rootCmd.PersistentFlags().StringVarP(&Failover, "failover", "f", "", "AWS failover region (required)")
+rootCmd.MarkPersistentFlagsRequired("region", "failover")
 ```