Multiple Required Flags

This commit is contained in:
Sean Pino 2024-02-20 10:38:54 -05:00
parent bcfcff729e
commit b19c5d4ab1
3 changed files with 71 additions and 7 deletions

View file

@ -828,13 +828,17 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
requiredFlag := rootCmd.Flags().Lookup("requiredFlag") requiredFlag := rootCmd.Flags().Lookup("requiredFlag")
rootCmd.PersistentFlags().IntP("requiredPersistent", "p", -1, "required persistent") rootCmd.PersistentFlags().IntP("requiredPersistent", "p", -1, "required persistent")
assertNoErr(t, rootCmd.MarkPersistentFlagRequired("requiredPersistent")) rootCmd.PersistentFlags().Float64P("requiredPersistentFloat", "f", -1, "required persistent float")
assertNoErr(t, rootCmd.MarkPersistentFlagsRequired("requiredPersistent", "requiredPersistentFloat"))
requiredPersistent := rootCmd.PersistentFlags().Lookup("requiredPersistent") requiredPersistent := rootCmd.PersistentFlags().Lookup("requiredPersistent")
requiredPersistentFloat := rootCmd.PersistentFlags().Lookup("requiredPersistentFloat")
rootCmd.Flags().StringP("release", "R", "", "Release name") rootCmd.Flags().StringP("release", "R", "", "Release name")
childCmd.Flags().BoolP("subRequired", "s", false, "sub required flag") childCmd.Flags().BoolP("subRequiredOne", "s", false, "first sub required flag")
assertNoErr(t, childCmd.MarkFlagRequired("subRequired")) childCmd.Flags().BoolP("subRequiredTwo", "z", false, "second sub required flag")
assertNoErr(t, childCmd.MarkFlagsRequired("subRequiredOne", "subRequiredTwo"))
childCmd.Flags().BoolP("subNotRequired", "n", false, "sub not required flag") childCmd.Flags().BoolP("subNotRequired", "n", false, "sub not required flag")
// Test that a required flag is suggested even without the - prefix // Test that a required flag is suggested even without the - prefix
@ -851,6 +855,8 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
"-r", "-r",
"--requiredPersistent", "--requiredPersistent",
"-p", "-p",
"--requiredPersistentFloat",
"-f",
"realArg", "realArg",
":4", ":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n") "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")
@ -870,6 +876,8 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
"-r", "-r",
"--requiredPersistent", "--requiredPersistent",
"-p", "-p",
"--requiredPersistentFloat",
"-f",
":4", ":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n") "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")
@ -901,8 +909,12 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
expected = strings.Join([]string{ expected = strings.Join([]string{
"--requiredPersistent", "--requiredPersistent",
"-p", "-p",
"--subRequired", "--requiredPersistentFloat",
"-f",
"--subRequiredOne",
"-s", "-s",
"--subRequiredTwo",
"-z",
"subArg", "subArg",
":4", ":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n") "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")
@ -919,8 +931,12 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
expected = strings.Join([]string{ expected = strings.Join([]string{
"--requiredPersistent", "--requiredPersistent",
"-p", "-p",
"--subRequired", "--requiredPersistentFloat",
"-f",
"--subRequiredOne",
"-s", "-s",
"--subRequiredTwo",
"-z",
":4", ":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n") "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")
@ -953,6 +969,8 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
expected = strings.Join([]string{ expected = strings.Join([]string{
"--requiredPersistent", "--requiredPersistent",
"-p", "-p",
"--requiredPersistentFloat",
"-f",
"realArg", "realArg",
":4", ":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n") "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")
@ -962,12 +980,13 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
} }
// Test that when a persistent required flag is present, it is not suggested anymore // Test that when a persistent required flag is present, it is not suggested anymore
output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--requiredPersistent", "1", "") output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--requiredPersistent", "1", "--requiredPersistentFloat", "1.0", "")
if err != nil { if err != nil {
t.Errorf("Unexpected error: %v", err) t.Errorf("Unexpected error: %v", err)
} }
// Reset the flag for the next command // Reset the flag for the next command
requiredPersistent.Changed = false requiredPersistent.Changed = false
requiredPersistentFloat.Changed = false
expected = strings.Join([]string{ expected = strings.Join([]string{
"childCmd", "childCmd",
@ -984,13 +1003,14 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
} }
// Test that when all required flags are present, normal completion is done // Test that when all required flags are present, normal completion is done
output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--requiredFlag", "1", "--requiredPersistent", "1", "") output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--requiredFlag", "1", "--requiredPersistent", "1", "--requiredPersistentFloat", "1.0", "")
if err != nil { if err != nil {
t.Errorf("Unexpected error: %v", err) t.Errorf("Unexpected error: %v", err)
} }
// Reset the flags for the next command // Reset the flags for the next command
requiredFlag.Changed = false requiredFlag.Changed = false
requiredPersistent.Changed = false requiredPersistent.Changed = false
requiredPersistentFloat.Changed = false
expected = strings.Join([]string{ expected = strings.Join([]string{
"realArg", "realArg",

View file

@ -25,6 +25,13 @@ func (c *Command) MarkFlagRequired(name string) error {
return MarkFlagRequired(c.Flags(), name) return MarkFlagRequired(c.Flags(), name)
} }
// MarkFlagsRequired instructs the various shell completion implementations to
// prioritize the named flags when performing completion,
// and causes your command to report an error if invoked without any of the flags.
func (c *Command) MarkFlagsRequired(names ...string) error {
return MarkFlagsRequired(c.Flags(), names...)
}
// MarkPersistentFlagRequired instructs the various shell completion implementations to // MarkPersistentFlagRequired instructs the various shell completion implementations to
// prioritize the named persistent flag when performing completion, // prioritize the named persistent flag when performing completion,
// and causes your command to report an error if invoked without the flag. // and causes your command to report an error if invoked without the flag.
@ -32,6 +39,13 @@ func (c *Command) MarkPersistentFlagRequired(name string) error {
return MarkFlagRequired(c.PersistentFlags(), name) return MarkFlagRequired(c.PersistentFlags(), name)
} }
// MarkPersistentFlagsRequired instructs the various shell completion implementations to
// prioritize the named persistent flags when performing completion,
// and causes your command to report an error if invoked without any of the flags.
func (c *Command) MarkPersistentFlagsRequired(names ...string) error {
return MarkFlagsRequired(c.PersistentFlags(), names...)
}
// MarkFlagRequired instructs the various shell completion implementations to // MarkFlagRequired instructs the various shell completion implementations to
// prioritize the named flag when performing completion, // prioritize the named flag when performing completion,
// and causes your command to report an error if invoked without the flag. // and causes your command to report an error if invoked without the flag.
@ -39,6 +53,18 @@ func MarkFlagRequired(flags *pflag.FlagSet, name string) error {
return flags.SetAnnotation(name, BashCompOneRequiredFlag, []string{"true"}) return flags.SetAnnotation(name, BashCompOneRequiredFlag, []string{"true"})
} }
// MarkFlagsRequired instructs the various shell completion implementations to
// prioritize the named flags when performing completion,
// and causes your command to report an error if invoked without any of the flags.
func MarkFlagsRequired(flags *pflag.FlagSet, names ...string) error {
for _, name := range names {
if err := flags.SetAnnotation(name, BashCompOneRequiredFlag, []string{"true"}); err != nil {
return err
}
}
return nil
}
// MarkFlagFilename instructs the various shell completion implementations to // MarkFlagFilename instructs the various shell completion implementations to
// limit completions for the named flag to the specified file extensions. // limit completions for the named flag to the specified file extensions.
func (c *Command) MarkFlagFilename(name string, extensions ...string) error { func (c *Command) MarkFlagFilename(name string, extensions ...string) error {

View file

@ -331,6 +331,24 @@ rootCmd.PersistentFlags().StringVarP(&Region, "region", "r", "", "AWS region (re
rootCmd.MarkPersistentFlagRequired("region") rootCmd.MarkPersistentFlagRequired("region")
``` ```
### Multiple Required flags
If your command has multiple required flags that are not [grouped](#flag-groups) to report an error
when one or more flags have not been set, mark them as required:
```go
rootCmd.Flags().StringVarP(&Region, "region", "r", "", "AWS region (required)")
rootCmd.Flags().StringVarP(&Region, "failoverRegion", "f", "", "AWS failover region (required)")
rootCmd.MarkFlagsRequired("region", "failoverRegion")
```
Or, for multiple persistent flags:
```go
rootCmd.PersistentFlags().StringVarP(&Region, "region", "r", "", "AWS region (required)")
rootCmd.PersistentFlags().StringVarP(&Region, "failoverRegion", "f", "", "AWS failover region (required)")
rootCmd.MarkPersistentFlagsRequired("region", "failoverRegion")
```
### Flag Groups ### Flag Groups
If you have different flags that must be provided together (e.g. if they provide the `--username` flag they MUST provide the `--password` flag as well) then If you have different flags that must be provided together (e.g. if they provide the `--username` flag they MUST provide the `--password` flag as well) then