From c81c46a015b48d9f79c73cc6f528f90eaabb0c7e Mon Sep 17 00:00:00 2001 From: Martijn Evers <94963229+marevers@users.noreply.github.com> Date: Sun, 16 Jul 2023 18:38:22 +0200 Subject: [PATCH] Add 'one required flag' group (#1952) --- completions_test.go | 98 ++++++++++++++++++++++++++++++++++++++ flag_groups.go | 68 +++++++++++++++++++++++++- flag_groups_test.go | 63 ++++++++++++++++++++---- site/content/user_guide.md | 11 ++++- 4 files changed, 228 insertions(+), 12 deletions(-) diff --git a/completions_test.go b/completions_test.go index 0588da0f..01724660 100644 --- a/completions_test.go +++ b/completions_test.go @@ -2830,6 +2830,104 @@ func TestCompletionForGroupedFlags(t *testing.T) { } } +func TestCompletionForOneRequiredGroupFlags(t *testing.T) { + getCmd := func() *Command { + rootCmd := &Command{ + Use: "root", + Run: emptyRun, + } + childCmd := &Command{ + Use: "child", + ValidArgsFunction: func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) { + return []string{"subArg"}, ShellCompDirectiveNoFileComp + }, + Run: emptyRun, + } + rootCmd.AddCommand(childCmd) + + rootCmd.PersistentFlags().Int("ingroup1", -1, "ingroup1") + rootCmd.PersistentFlags().String("ingroup2", "", "ingroup2") + + childCmd.Flags().Bool("ingroup3", false, "ingroup3") + childCmd.Flags().Bool("nogroup", false, "nogroup") + + // Add flags to a group + childCmd.MarkFlagsOneRequired("ingroup1", "ingroup2", "ingroup3") + + return rootCmd + } + + // Each test case uses a unique command from the function above. + testcases := []struct { + desc string + args []string + expectedOutput string + }{ + { + desc: "flags in group suggested without - prefix", + args: []string{"child", ""}, + expectedOutput: strings.Join([]string{ + "--ingroup1", + "--ingroup2", + "--ingroup3", + "subArg", + ":4", + "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"), + }, + { + desc: "flags in group suggested with - prefix", + args: []string{"child", "-"}, + expectedOutput: strings.Join([]string{ + "--ingroup1", + "--ingroup2", + "--ingroup3", + ":4", + "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"), + }, + { + desc: "when any flag in group present, other flags in group not suggested without - prefix", + args: []string{"child", "--ingroup2", "value", ""}, + expectedOutput: strings.Join([]string{ + "subArg", + ":4", + "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"), + }, + { + desc: "when all flags in group present, flags not suggested without - prefix", + args: []string{"child", "--ingroup1", "8", "--ingroup2", "value2", "--ingroup3", ""}, + expectedOutput: strings.Join([]string{ + "subArg", + ":4", + "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"), + }, + { + desc: "group ignored if some flags not applicable", + args: []string{"--ingroup2", "value", ""}, + expectedOutput: strings.Join([]string{ + "child", + "completion", + "help", + ":4", + "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"), + }, + } + + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + c := getCmd() + args := []string{ShellCompNoDescRequestCmd} + args = append(args, tc.args...) + output, err := executeCommand(c, args...) + switch { + case err == nil && output != tc.expectedOutput: + t.Errorf("expected: %q, got: %q", tc.expectedOutput, output) + case err != nil: + t.Errorf("Unexpected error %q", err) + } + }) + } +} + func TestCompletionForMutuallyExclusiveFlags(t *testing.T) { getCmd := func() *Command { rootCmd := &Command{ diff --git a/flag_groups.go b/flag_groups.go index b35fde15..0671ec5f 100644 --- a/flag_groups.go +++ b/flag_groups.go @@ -24,6 +24,7 @@ import ( const ( requiredAsGroup = "cobra_annotation_required_if_others_set" + oneRequired = "cobra_annotation_one_required" mutuallyExclusive = "cobra_annotation_mutually_exclusive" ) @@ -43,6 +44,22 @@ func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) { } } +// MarkFlagsOneRequired marks the given flags with annotations so that Cobra errors +// if the command is invoked without at least one flag from the given set of flags. +func (c *Command) MarkFlagsOneRequired(flagNames ...string) { + c.mergePersistentFlags() + for _, v := range flagNames { + f := c.Flags().Lookup(v) + if f == nil { + panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v)) + } + if err := c.Flags().SetAnnotation(v, oneRequired, append(f.Annotations[oneRequired], strings.Join(flagNames, " "))); err != nil { + // Only errs if the flag isn't found. + panic(err) + } + } +} + // MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors // if the command is invoked with more than one flag from the given set of flags. func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { @@ -59,7 +76,7 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { } } -// ValidateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic and returns the +// ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the // first error encountered. func (c *Command) ValidateFlagGroups() error { if c.DisableFlagParsing { @@ -71,15 +88,20 @@ func (c *Command) ValidateFlagGroups() error { // groupStatus format is the list of flags as a unique ID, // then a map of each flag name and whether it is set or not. groupStatus := map[string]map[string]bool{} + oneRequiredGroupStatus := map[string]map[string]bool{} mutuallyExclusiveGroupStatus := map[string]map[string]bool{} flags.VisitAll(func(pflag *flag.Flag) { processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) + processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus) processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) }) if err := validateRequiredFlagGroups(groupStatus); err != nil { return err } + if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil { + return err + } if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil { return err } @@ -142,6 +164,27 @@ func validateRequiredFlagGroups(data map[string]map[string]bool) error { return nil } +func validateOneRequiredFlagGroups(data map[string]map[string]bool) error { + keys := sortedKeys(data) + for _, flagList := range keys { + flagnameAndStatus := data[flagList] + var set []string + for flagname, isSet := range flagnameAndStatus { + if isSet { + set = append(set, flagname) + } + } + if len(set) >= 1 { + continue + } + + // Sort values, so they can be tested/scripted against consistently. + sort.Strings(set) + return fmt.Errorf("at least one of the flags in the group [%v] is required", flagList) + } + return nil +} + func validateExclusiveFlagGroups(data map[string]map[string]bool) error { keys := sortedKeys(data) for _, flagList := range keys { @@ -176,6 +219,7 @@ func sortedKeys(m map[string]map[string]bool) []string { // enforceFlagGroupsForCompletion will do the following: // - when a flag in a group is present, other flags in the group will be marked required +// - when none of the flags in a one-required group are present, all flags in the group will be marked required // - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden // This allows the standard completion logic to behave appropriately for flag groups func (c *Command) enforceFlagGroupsForCompletion() { @@ -185,9 +229,11 @@ func (c *Command) enforceFlagGroupsForCompletion() { flags := c.Flags() groupStatus := map[string]map[string]bool{} + oneRequiredGroupStatus := map[string]map[string]bool{} mutuallyExclusiveGroupStatus := map[string]map[string]bool{} c.Flags().VisitAll(func(pflag *flag.Flag) { processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) + processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus) processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) }) @@ -204,6 +250,26 @@ func (c *Command) enforceFlagGroupsForCompletion() { } } + // If none of the flags of a one-required group are present, we make all the flags + // of that group required so that the shell completion suggests them automatically + for flagList, flagnameAndStatus := range oneRequiredGroupStatus { + set := 0 + + for _, isSet := range flagnameAndStatus { + if isSet { + set++ + } + } + + // None of the flags of the group are set, mark all flags in the group + // as required + if set == 0 { + for _, fName := range strings.Split(flagList, " ") { + _ = c.MarkFlagRequired(fName) + } + } + } + // If a flag that is mutually exclusive to others is present, we hide the other // flags of that group so the shell completion does not suggest them for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus { diff --git a/flag_groups_test.go b/flag_groups_test.go index bf988d73..cffa8552 100644 --- a/flag_groups_test.go +++ b/flag_groups_test.go @@ -43,13 +43,15 @@ func TestValidateFlagGroups(t *testing.T) { // Each test case uses a unique command from the function above. testcases := []struct { - desc string - flagGroupsRequired []string - flagGroupsExclusive []string - subCmdFlagGroupsRequired []string - subCmdFlagGroupsExclusive []string - args []string - expectErr string + desc string + flagGroupsRequired []string + flagGroupsOneRequired []string + flagGroupsExclusive []string + subCmdFlagGroupsRequired []string + subCmdFlagGroupsOneRequired []string + subCmdFlagGroupsExclusive []string + args []string + expectErr string }{ { desc: "No flags no problem", @@ -62,6 +64,11 @@ func TestValidateFlagGroups(t *testing.T) { flagGroupsRequired: []string{"a b c"}, args: []string{"--a=foo"}, expectErr: "if any flags in the group [a b c] are set they must all be set; missing [b c]", + }, { + desc: "One-required flag group not satisfied", + flagGroupsOneRequired: []string{"a b"}, + args: []string{"--c=foo"}, + expectErr: "at least one of the flags in the group [a b] is required", }, { desc: "Exclusive flag group not satisfied", flagGroupsExclusive: []string{"a b c"}, @@ -72,6 +79,11 @@ func TestValidateFlagGroups(t *testing.T) { flagGroupsRequired: []string{"a b c", "a d"}, args: []string{"--c=foo", "--d=foo"}, expectErr: `if any flags in the group [a b c] are set they must all be set; missing [a b]`, + }, { + desc: "Multiple one-required flag group not satisfied returns first error", + flagGroupsOneRequired: []string{"a b", "d e"}, + args: []string{"--c=foo", "--f=foo"}, + expectErr: `at least one of the flags in the group [a b] is required`, }, { desc: "Multiple exclusive flag group not satisfied returns first error", flagGroupsExclusive: []string{"a b c", "a d"}, @@ -82,32 +94,57 @@ func TestValidateFlagGroups(t *testing.T) { flagGroupsRequired: []string{"a d", "a b", "a c"}, args: []string{"--a=foo"}, expectErr: `if any flags in the group [a b] are set they must all be set; missing [b]`, + }, { + desc: "Validation of one-required groups occurs on groups in sorted order", + flagGroupsOneRequired: []string{"d e", "a b", "f g"}, + args: []string{"--c=foo"}, + expectErr: `at least one of the flags in the group [a b] is required`, }, { desc: "Validation of exclusive groups occurs on groups in sorted order", flagGroupsExclusive: []string{"a d", "a b", "a c"}, args: []string{"--a=foo", "--b=foo", "--c=foo"}, expectErr: `if any flags in the group [a b] are set none of the others can be; [a b] were all set`, }, { - desc: "Persistent flags utilize both features and can fail required groups", + desc: "Persistent flags utilize required and exclusive groups and can fail required groups", flagGroupsRequired: []string{"a e", "e f"}, flagGroupsExclusive: []string{"f g"}, args: []string{"--a=foo", "--f=foo", "--g=foo"}, expectErr: `if any flags in the group [a e] are set they must all be set; missing [e]`, }, { - desc: "Persistent flags utilize both features and can fail mutually exclusive groups", + desc: "Persistent flags utilize one-required and exclusive groups and can fail one-required groups", + flagGroupsOneRequired: []string{"a b", "e f"}, + flagGroupsExclusive: []string{"e f"}, + args: []string{"--e=foo"}, + expectErr: `at least one of the flags in the group [a b] is required`, + }, { + desc: "Persistent flags utilize required and exclusive groups and can fail mutually exclusive groups", flagGroupsRequired: []string{"a e", "e f"}, flagGroupsExclusive: []string{"f g"}, args: []string{"--a=foo", "--e=foo", "--f=foo", "--g=foo"}, expectErr: `if any flags in the group [f g] are set none of the others can be; [f g] were all set`, }, { - desc: "Persistent flags utilize both features and can pass", + desc: "Persistent flags utilize required and exclusive groups and can pass", flagGroupsRequired: []string{"a e", "e f"}, flagGroupsExclusive: []string{"f g"}, args: []string{"--a=foo", "--e=foo", "--f=foo"}, + }, { + desc: "Persistent flags utilize one-required and exclusive groups and can pass", + flagGroupsOneRequired: []string{"a e", "e f"}, + flagGroupsExclusive: []string{"f g"}, + args: []string{"--a=foo", "--e=foo", "--f=foo"}, }, { desc: "Subcmds can use required groups using inherited flags", subCmdFlagGroupsRequired: []string{"e subonly"}, args: []string{"subcmd", "--e=foo", "--subonly=foo"}, + }, { + desc: "Subcmds can use one-required groups using inherited flags", + subCmdFlagGroupsOneRequired: []string{"e subonly"}, + args: []string{"subcmd", "--e=foo", "--subonly=foo"}, + }, { + desc: "Subcmds can use one-required groups using inherited flags and fail one-required groups", + subCmdFlagGroupsOneRequired: []string{"e subonly"}, + args: []string{"subcmd"}, + expectErr: "at least one of the flags in the group [e subonly] is required", }, { desc: "Subcmds can use exclusive groups using inherited flags", subCmdFlagGroupsExclusive: []string{"e subonly"}, @@ -130,12 +167,18 @@ func TestValidateFlagGroups(t *testing.T) { for _, flagGroup := range tc.flagGroupsRequired { c.MarkFlagsRequiredTogether(strings.Split(flagGroup, " ")...) } + for _, flagGroup := range tc.flagGroupsOneRequired { + c.MarkFlagsOneRequired(strings.Split(flagGroup, " ")...) + } for _, flagGroup := range tc.flagGroupsExclusive { c.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...) } for _, flagGroup := range tc.subCmdFlagGroupsRequired { sub.MarkFlagsRequiredTogether(strings.Split(flagGroup, " ")...) } + for _, flagGroup := range tc.subCmdFlagGroupsOneRequired { + sub.MarkFlagsOneRequired(strings.Split(flagGroup, " ")...) + } for _, flagGroup := range tc.subCmdFlagGroupsExclusive { sub.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...) } diff --git a/site/content/user_guide.md b/site/content/user_guide.md index 04971431..56e8e44a 100644 --- a/site/content/user_guide.md +++ b/site/content/user_guide.md @@ -349,7 +349,16 @@ rootCmd.Flags().BoolVar(&ofYaml, "yaml", false, "Output in YAML") rootCmd.MarkFlagsMutuallyExclusive("json", "yaml") ``` -In both of these cases: +If you want to require at least one flag from a group to be present, you can use `MarkFlagsOneRequired`. +This can be combined with `MarkFlagsMutuallyExclusive` to enforce exactly one flag from a given group: +```go +rootCmd.Flags().BoolVar(&ofJson, "json", false, "Output in JSON") +rootCmd.Flags().BoolVar(&ofYaml, "yaml", false, "Output in YAML") +rootCmd.MarkFlagsOneRequired("json", "yaml") +rootCmd.MarkFlagsMutuallyExclusive("json", "yaml") +``` + +In these cases: - both local and persistent flags can be used - **NOTE:** the group is only enforced on commands where every flag is defined - a flag may appear in multiple groups