diff --git a/command.go b/command.go index 64f1d5f4..f4832147 100644 --- a/command.go +++ b/command.go @@ -146,6 +146,11 @@ type Command struct { // that we can use on every pflag set and children commands globNormFunc func(f *flag.FlagSet, name string) flag.NormalizedName + // flagGroups is the list of groups that contain grouped names of flags. + // Groups are like "relationships" between flags that allow to validate + // flags and adjust completions taking into account these "relationships". + flagGroups []flagGroup + // usageFunc is usage func defined by user. usageFunc func(*Command) error // usageTemplate is usage template defined by user. diff --git a/completions.go b/completions.go index d8fa1f77..011745e6 100644 --- a/completions.go +++ b/completions.go @@ -351,8 +351,8 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi var completions []string var directive ShellCompDirective - // Enforce flag groups before doing flag completions - finalCmd.enforceFlagGroupsForCompletion() + // Allow flagGroups to update the command to improve completions + finalCmd.adjustByFlagGroupsForCompletions() // Note that we want to perform flagname completion even if finalCmd.DisableFlagParsing==true; // doing this allows for completion of persistent flag names even for commands that disable flag parsing. diff --git a/flag_groups.go b/flag_groups.go index 9c377aaf..52a77fc2 100644 --- a/flag_groups.go +++ b/flag_groups.go @@ -16,209 +16,181 @@ package cobra import ( "fmt" - "sort" - "strings" flag "github.com/spf13/pflag" ) -const ( - requiredAsGroup = "cobra_annotation_required_if_others_set" - mutuallyExclusive = "cobra_annotation_mutually_exclusive" -) - -// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors -// if the command is invoked with a subset (but not all) of the given flags. +// MarkFlagsRequiredTogether creates a relationship between flags, which ensures +// that if any of flags with names from flagNames is set, other flags must be set too. func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) { - c.mergePersistentFlags() - for _, v := range flagNames { - f := c.Flags().Lookup(v) - if f == nil { - panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v)) - } - if err := c.Flags().SetAnnotation(v, requiredAsGroup, append(f.Annotations[requiredAsGroup], strings.Join(flagNames, " "))); err != nil { - // Only errs if the flag isn't found. - panic(err) - } - } -} - -// MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors -// if the command is invoked with more than one flag from the given set of flags. -func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { - c.mergePersistentFlags() - for _, v := range flagNames { - f := c.Flags().Lookup(v) - if f == nil { - panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v)) - } - // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. - if err := c.Flags().SetAnnotation(v, mutuallyExclusive, append(f.Annotations[mutuallyExclusive], strings.Join(flagNames, " "))); err != nil { - panic(err) - } - } -} - -// ValidateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic and returns the -// first error encountered. -func (c *Command) ValidateFlagGroups() error { - if c.DisableFlagParsing { - return nil - } - - flags := c.Flags() - - // groupStatus format is the list of flags as a unique ID, - // then a map of each flag name and whether it is set or not. - groupStatus := map[string]map[string]bool{} - mutuallyExclusiveGroupStatus := map[string]map[string]bool{} - flags.VisitAll(func(pflag *flag.Flag) { - processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) - processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) + c.addFlagGroup(&requiredTogetherFlagGroup{ + flagNames: flagNames, }) +} - if err := validateRequiredFlagGroups(groupStatus); err != nil { - return err +// MarkFlagsMutuallyExclusive creates a relationship between flags, which ensures +// that if any of flags with names from flagNames is set, other flags must not be set. +func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { + c.addFlagGroup(&mutuallyExclusiveFlagGroup{ + flagNames: flagNames, + }) +} + +// addFlagGroup merges persistent flags of the command and adds flagGroup into command's flagGroups list. +// Panics, if flagGroup g contains the name of the flag, which is not defined in the Command c. +func (c *Command) addFlagGroup(g flagGroup) { + c.mergePersistentFlags() + + for _, flagName := range g.AssignedFlagNames() { + if c.Flags().Lookup(flagName) == nil { + panic(fmt.Sprintf("flag %q is not defined", flagName)) + } } - if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil { - return err + + c.flagGroups = append(c.flagGroups, g) +} + +// ValidateFlagGroups runs validation for each group from command's flagGroups list, +// and returns the first error encountered, or nil, if there were no validation errors. +func (c *Command) ValidateFlagGroups() error { + setFlags := makeSetFlagsSet(c.Flags()) + for _, group := range c.flagGroups { + if err := group.ValidateSetFlags(setFlags); err != nil { + return err + } } return nil } -func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool { - for _, fname := range flagnames { - f := fs.Lookup(fname) - if f == nil { - return false - } - } - return true -} - -func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) { - groupInfo, found := pflag.Annotations[annotation] - if found { - for _, group := range groupInfo { - if groupStatus[group] == nil { - flagnames := strings.Split(group, " ") - - // Only consider this flag group at all if all the flags are defined. - if !hasAllFlags(flags, flagnames...) { - continue - } - - groupStatus[group] = map[string]bool{} - for _, name := range flagnames { - groupStatus[group][name] = false - } - } - - groupStatus[group][pflag.Name] = pflag.Changed - } - } -} - -func validateRequiredFlagGroups(data map[string]map[string]bool) error { - keys := sortedKeys(data) - for _, flagList := range keys { - flagnameAndStatus := data[flagList] - - unset := []string{} - for flagname, isSet := range flagnameAndStatus { - if !isSet { - unset = append(unset, flagname) - } - } - if len(unset) == len(flagnameAndStatus) || len(unset) == 0 { - continue - } - - // Sort values, so they can be tested/scripted against consistently. - sort.Strings(unset) - return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset) - } - - return nil -} - -func validateExclusiveFlagGroups(data map[string]map[string]bool) error { - keys := sortedKeys(data) - for _, flagList := range keys { - flagnameAndStatus := data[flagList] - var set []string - for flagname, isSet := range flagnameAndStatus { - if isSet { - set = append(set, flagname) - } - } - if len(set) == 0 || len(set) == 1 { - continue - } - - // Sort values, so they can be tested/scripted against consistently. - sort.Strings(set) - return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set) - } - return nil -} - -func sortedKeys(m map[string]map[string]bool) []string { - keys := make([]string, len(m)) - i := 0 - for k := range m { - keys[i] = k - i++ - } - sort.Strings(keys) - return keys -} - -// enforceFlagGroupsForCompletion will do the following: -// - when a flag in a group is present, other flags in the group will be marked required -// - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden -// This allows the standard completion logic to behave appropriately for flag groups -func (c *Command) enforceFlagGroupsForCompletion() { +// adjustByFlagGroupsForCompletions changes the command by each flagGroup from command's flagGroups list +// to make the further command completions generation more convenient. +// Does nothing, if Command.DisableFlagParsing is true. +func (c *Command) adjustByFlagGroupsForCompletions() { if c.DisableFlagParsing { return } - flags := c.Flags() - groupStatus := map[string]map[string]bool{} - mutuallyExclusiveGroupStatus := map[string]map[string]bool{} - c.Flags().VisitAll(func(pflag *flag.Flag) { - processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) - processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) - }) + for _, group := range c.flagGroups { + group.AdjustCommandForCompletions(c) + } +} - // If a flag that is part of a group is present, we make all the other flags - // of that group required so that the shell completion suggests them automatically - for flagList, flagnameAndStatus := range groupStatus { - for _, isSet := range flagnameAndStatus { - if isSet { - // One of the flags of the group is set, mark the other ones as required - for _, fName := range strings.Split(flagList, " ") { - _ = c.MarkFlagRequired(fName) - } - } - } +type flagGroup interface { + // ValidateSetFlags checks whether the combination of flags that have been set is valid. + // If not, an error is returned. + ValidateSetFlags(setFlags setFlagsSet) error + + // AssignedFlagNames returns a full list of flag names that have been assigned to the group. + AssignedFlagNames() []string + + // AdjustCommandForCompletions updates the command to generate more convenient for this group completions. + AdjustCommandForCompletions(c *Command) +} + +// requiredTogetherFlagGroup groups flags that are required together and +// must all be set, if any of flags from this group is set. +type requiredTogetherFlagGroup struct { + flagNames []string +} + +func (g *requiredTogetherFlagGroup) AssignedFlagNames() []string { + return g.flagNames +} +func (g *requiredTogetherFlagGroup) ValidateSetFlags(setFlags setFlagsSet) error { + unset := setFlags.selectUnsetFlagNamesFrom(g.flagNames) + + if unsetCount := len(unset); unsetCount != 0 && unsetCount != len(g.flagNames) { + return fmt.Errorf("flags %v must be set together, but %v were not set", g.flagNames, unset) } - // If a flag that is mutually exclusive to others is present, we hide the other - // flags of that group so the shell completion does not suggest them - for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus { - for flagName, isSet := range flagnameAndStatus { - if isSet { - // One of the flags of the mutually exclusive group is set, mark the other ones as hidden - // Don't mark the flag that is already set as hidden because it may be an - // array or slice flag and therefore must continue being suggested - for _, fName := range strings.Split(flagList, " ") { - if fName != flagName { - flag := c.Flags().Lookup(fName) - flag.Hidden = true - } - } + return nil +} +func (g *requiredTogetherFlagGroup) AdjustCommandForCompletions(c *Command) { + setFlags := makeSetFlagsSet(c.Flags()) + if setFlags.hasAnyFrom(g.flagNames) { + for _, requiredFlagName := range g.flagNames { + _ = c.MarkFlagRequired(requiredFlagName) + } + } +} + +// mutuallyExclusiveFlagGroup groups flags that are mutually exclusive +// and must not be set together, if any of flags from this group is set. +type mutuallyExclusiveFlagGroup struct { + flagNames []string +} + +func (g *mutuallyExclusiveFlagGroup) AssignedFlagNames() []string { + return g.flagNames +} +func (g *mutuallyExclusiveFlagGroup) ValidateSetFlags(setFlags setFlagsSet) error { + set := setFlags.selectSetFlagNamesFrom(g.flagNames) + + if len(set) > 1 { + return fmt.Errorf("exactly one of the flags %v can be set, but %v were set", g.flagNames, set) + } + return nil +} +func (g *mutuallyExclusiveFlagGroup) AdjustCommandForCompletions(c *Command) { + setFlags := makeSetFlagsSet(c.Flags()) + firstSetFlagName, hasAny := setFlags.selectFirstSetFlagNameFrom(g.flagNames) + if hasAny { + for _, exclusiveFlagName := range g.flagNames { + if exclusiveFlagName != firstSetFlagName { + c.Flags().Lookup(exclusiveFlagName).Hidden = true } } } } + +// setFlagsSet is a helper set type that is intended to be used to store names of the flags +// that have been set in flag.FlagSet and to perform some lookups and checks on those flags. +type setFlagsSet map[string]struct{} + +// makeSetFlagsSet creates setFlagsSet of names of the flags that have been set in the given flag.FlagSet. +func makeSetFlagsSet(fs *flag.FlagSet) setFlagsSet { + s := make(setFlagsSet) + + // Visit flags that have been set and add them to the set + fs.Visit(func(f *flag.Flag) { + s[f.Name] = struct{}{} + }) + + return s +} +func (s setFlagsSet) has(flagName string) bool { + _, ok := s[flagName] + return ok +} +func (s setFlagsSet) hasAnyFrom(flagNames []string) bool { + for _, flagName := range flagNames { + if s.has(flagName) { + return true + } + } + return false +} +func (s setFlagsSet) selectFirstSetFlagNameFrom(flagNames []string) (string, bool) { + for _, flagName := range flagNames { + if s.has(flagName) { + return flagName, true + } + } + return "", false +} +func (s setFlagsSet) selectSetFlagNamesFrom(flagNames []string) (setFlagNames []string) { + for _, flagName := range flagNames { + if s.has(flagName) { + setFlagNames = append(setFlagNames, flagName) + } + } + return +} +func (s setFlagsSet) selectUnsetFlagNamesFrom(flagNames []string) (unsetFlagNames []string) { + for _, flagName := range flagNames { + if !s.has(flagName) { + unsetFlagNames = append(unsetFlagNames, flagName) + } + } + return +} diff --git a/flag_groups_test.go b/flag_groups_test.go index b4b65ac0..ff9a85d6 100644 --- a/flag_groups_test.go +++ b/flag_groups_test.go @@ -21,126 +21,137 @@ import ( func TestValidateFlagGroups(t *testing.T) { getCmd := func() *Command { - c := &Command{ + cmd := &Command{ Use: "testcmd", - Run: func(cmd *Command, args []string) { - }} - // Define lots of flags to utilize for testing. - for _, v := range []string{"a", "b", "c", "d"} { - c.Flags().String(v, "", "") + Run: func(cmd *Command, args []string) {}, } - for _, v := range []string{"e", "f", "g"} { - c.PersistentFlags().String(v, "", "") - } - subC := &Command{ + + cmd.Flags().String("a", "", "") + cmd.Flags().String("b", "", "") + cmd.Flags().String("c", "", "") + cmd.Flags().String("d", "", "") + cmd.PersistentFlags().String("p-a", "", "") + cmd.PersistentFlags().String("p-b", "", "") + cmd.PersistentFlags().String("p-c", "", "") + + subCmd := &Command{ Use: "subcmd", - Run: func(cmd *Command, args []string) { - }} - subC.Flags().String("subonly", "", "") - c.AddCommand(subC) - return c + Run: func(cmd *Command, args []string) {}, + } + subCmd.Flags().String("sub-a", "", "") + + cmd.AddCommand(subCmd) + + return cmd } // Each test case uses a unique command from the function above. testcases := []struct { - desc string - flagGroupsRequired []string - flagGroupsExclusive []string - subCmdFlagGroupsRequired []string - subCmdFlagGroupsExclusive []string - args []string - expectErr string + desc string + requiredTogether []string + mutuallyExclusive []string + subRequiredTogether []string + subMutuallyExclusive []string + args []string + expectErr string }{ { - desc: "No flags no problem", + desc: "No flags no problems", }, { - desc: "No flags no problem even with conflicting groups", - flagGroupsRequired: []string{"a b"}, - flagGroupsExclusive: []string{"a b"}, + desc: "No flags no problems even with conflicting groups", + requiredTogether: []string{"a b"}, + mutuallyExclusive: []string{"a b"}, }, { - desc: "Required flag group not satisfied", - flagGroupsRequired: []string{"a b c"}, - args: []string{"--a=foo"}, - expectErr: "if any flags in the group [a b c] are set they must all be set; missing [b c]", + desc: "Required together flag group validation fails", + requiredTogether: []string{"a b c"}, + args: []string{"--a=foo"}, + expectErr: `flags [a b c] must be set together, but [b c] were not set`, }, { - desc: "Exclusive flag group not satisfied", - flagGroupsExclusive: []string{"a b c"}, - args: []string{"--a=foo", "--b=foo"}, - expectErr: "if any flags in the group [a b c] are set none of the others can be; [a b] were all set", + desc: "Required together flag group validation passes", + requiredTogether: []string{"a b c"}, + args: []string{"--c=bar", "--a=foo", "--b=baz"}, }, { - desc: "Multiple required flag group not satisfied returns first error", - flagGroupsRequired: []string{"a b c", "a d"}, - args: []string{"--c=foo", "--d=foo"}, - expectErr: `if any flags in the group [a b c] are set they must all be set; missing [a b]`, + desc: "Mutually exclusive flag group validation fails", + mutuallyExclusive: []string{"a b c"}, + args: []string{"--b=foo", "--c=bar"}, + expectErr: `exactly one of the flags [a b c] can be set, but [b c] were set`, }, { - desc: "Multiple exclusive flag group not satisfied returns first error", - flagGroupsExclusive: []string{"a b c", "a d"}, - args: []string{"--a=foo", "--c=foo", "--d=foo"}, - expectErr: `if any flags in the group [a b c] are set none of the others can be; [a c] were all set`, + desc: "Mutually exclusive flag group validation passes", + mutuallyExclusive: []string{"a b c"}, + args: []string{"--b=foo"}, }, { - desc: "Validation of required groups occurs on groups in sorted order", - flagGroupsRequired: []string{"a d", "a b", "a c"}, - args: []string{"--a=foo"}, - expectErr: `if any flags in the group [a b] are set they must all be set; missing [b]`, + desc: "Multiple required together flag groups failed validation returns first error", + requiredTogether: []string{"a b c", "a d"}, + args: []string{"--d=foo", "--c=foo"}, + expectErr: `flags [a b c] must be set together, but [a b] were not set`, }, { - desc: "Validation of exclusive groups occurs on groups in sorted order", - flagGroupsExclusive: []string{"a d", "a b", "a c"}, - args: []string{"--a=foo", "--b=foo", "--c=foo"}, - expectErr: `if any flags in the group [a b] are set none of the others can be; [a b] were all set`, + desc: "Multiple mutually exclusive flag groups failed validation returns first error", + mutuallyExclusive: []string{"a b c", "a d"}, + args: []string{"--a=foo", "--c=foo", "--d=foo"}, + expectErr: `exactly one of the flags [a b c] can be set, but [a c] were set`, }, { - desc: "Persistent flags utilize both features and can fail required groups", - flagGroupsRequired: []string{"a e", "e f"}, - flagGroupsExclusive: []string{"f g"}, - args: []string{"--a=foo", "--f=foo", "--g=foo"}, - expectErr: `if any flags in the group [a e] are set they must all be set; missing [e]`, + desc: "Flag and persistent flags being in multiple groups fail required together group", + requiredTogether: []string{"a p-a", "p-a p-b"}, + mutuallyExclusive: []string{"p-b p-c"}, + args: []string{"--a=foo", "--p-b=foo", "--p-c=foo"}, + expectErr: `flags [a p-a] must be set together, but [p-a] were not set`, }, { - desc: "Persistent flags utilize both features and can fail mutually exclusive groups", - flagGroupsRequired: []string{"a e", "e f"}, - flagGroupsExclusive: []string{"f g"}, - args: []string{"--a=foo", "--e=foo", "--f=foo", "--g=foo"}, - expectErr: `if any flags in the group [f g] are set none of the others can be; [f g] were all set`, + desc: "Flag and persistent flags being in multiple groups fail mutually exclusive group", + requiredTogether: []string{"a p-a", "p-a p-b"}, + mutuallyExclusive: []string{"p-b p-c"}, + args: []string{"--a=foo", "--p-a=foo", "--p-b=foo", "--p-c=foo"}, + expectErr: `exactly one of the flags [p-b p-c] can be set, but [p-b p-c] were set`, }, { - desc: "Persistent flags utilize both features and can pass", - flagGroupsRequired: []string{"a e", "e f"}, - flagGroupsExclusive: []string{"f g"}, - args: []string{"--a=foo", "--e=foo", "--f=foo"}, + desc: "Flag and persistent flags pass required together and mutually exclusive groups", + requiredTogether: []string{"a p-a", "p-a p-b"}, + mutuallyExclusive: []string{"p-b p-c"}, + args: []string{"--a=foo", "--p-a=foo", "--p-b=foo"}, }, { - desc: "Subcmds can use required groups using inherited flags", - subCmdFlagGroupsRequired: []string{"e subonly"}, - args: []string{"subcmd", "--e=foo", "--subonly=foo"}, + desc: "Required together flag group validation fails on subcommand with inherited flag", + subRequiredTogether: []string{"p-a sub-a"}, + args: []string{"subcmd", "--sub-a=foo"}, + expectErr: `flags [p-a sub-a] must be set together, but [p-a] were not set`, }, { - desc: "Subcmds can use exclusive groups using inherited flags", - subCmdFlagGroupsExclusive: []string{"e subonly"}, - args: []string{"subcmd", "--e=foo", "--subonly=foo"}, - expectErr: "if any flags in the group [e subonly] are set none of the others can be; [e subonly] were all set", + desc: "Required together flag group validation passes on subcommand with inherited flag", + subRequiredTogether: []string{"p-a sub-a"}, + args: []string{"subcmd", "--p-a=foo", "--sub-a=foo"}, }, { - desc: "Subcmds can use exclusive groups using inherited flags and pass", - subCmdFlagGroupsExclusive: []string{"e subonly"}, - args: []string{"subcmd", "--e=foo"}, + desc: "Mutually exclusive flag group validation fails on subcommand with inherited flag", + subMutuallyExclusive: []string{"p-a sub-a"}, + args: []string{"subcmd", "--p-a=foo", "--sub-a=foo"}, + expectErr: `exactly one of the flags [p-a sub-a] can be set, but [p-a sub-a] were set`, }, { - desc: "Flag groups not applied if not found on invoked command", - subCmdFlagGroupsRequired: []string{"e subonly"}, - args: []string{"--e=foo"}, + desc: "Mutually exclusive flag group validation passes on subcommand with inherited flag", + subMutuallyExclusive: []string{"p-a sub-a"}, + args: []string{"subcmd", "--p-a=foo"}, + }, { + desc: "Required together flag group validation is not applied on other command", + subRequiredTogether: []string{"p-a sub-a"}, + args: []string{"--p-a=foo"}, }, } + for _, tc := range testcases { t.Run(tc.desc, func(t *testing.T) { - c := getCmd() - sub := c.Commands()[0] - for _, flagGroup := range tc.flagGroupsRequired { - c.MarkFlagsRequiredTogether(strings.Split(flagGroup, " ")...) + cmd := getCmd() + subCmd := cmd.Commands()[0] + + for _, group := range tc.requiredTogether { + cmd.MarkFlagsRequiredTogether(strings.Split(group, " ")...) } - for _, flagGroup := range tc.flagGroupsExclusive { - c.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...) + for _, group := range tc.mutuallyExclusive { + cmd.MarkFlagsMutuallyExclusive(strings.Split(group, " ")...) } - for _, flagGroup := range tc.subCmdFlagGroupsRequired { - sub.MarkFlagsRequiredTogether(strings.Split(flagGroup, " ")...) + for _, group := range tc.subRequiredTogether { + subCmd.MarkFlagsRequiredTogether(strings.Split(group, " ")...) } - for _, flagGroup := range tc.subCmdFlagGroupsExclusive { - sub.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...) + for _, group := range tc.subMutuallyExclusive { + subCmd.MarkFlagsMutuallyExclusive(strings.Split(group, " ")...) } - c.SetArgs(tc.args) - err := c.Execute() + + cmd.SetArgs(tc.args) + err := cmd.Execute() + switch { case err == nil && len(tc.expectErr) > 0: t.Errorf("Expected error %q but got nil", tc.expectErr)