From 15e2ae2ceb80457eb0ce970f87e127cebdf1debf Mon Sep 17 00:00:00 2001 From: Himadri Bhattacharjee Date: Sat, 2 Jul 2022 03:34:23 +0000 Subject: [PATCH] Added flags groups which are mutually exclusive but required Fixed method naming --- flag_groups.go | 49 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/flag_groups.go b/flag_groups.go index dc784311..ae4b9f9c 100644 --- a/flag_groups.go +++ b/flag_groups.go @@ -22,8 +22,9 @@ import ( ) const ( - requiredAsGroup = "cobra_annotation_required_if_others_set" - mutuallyExclusive = "cobra_annotation_mutually_exclusive" + requiredAsGroup = "cobra_annotation_required_if_others_set" + mutuallyExclusive = "cobra_annotation_mutually_exclusive" + mutuallyExclusiveRequired = "cobra_annotation_mutually_exclusive_required" ) // MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors @@ -42,6 +43,22 @@ func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) { } } +// MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors +// if the command is invoked with more than one flag from the given set of flags. +func (c *Command) MarkFlagsMutuallyExclusiveRequired(flagNames ...string) { + c.mergePersistentFlags() + for _, v := range flagNames { + f := c.Flags().Lookup(v) + if f == nil { + panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive required flag group", v)) + } + // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. + if err := c.Flags().SetAnnotation(v, mutuallyExclusiveRequired, append(f.Annotations[mutuallyExclusiveRequired], strings.Join(flagNames, " "))); err != nil { + panic(err) + } + } +} + // MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors // if the command is invoked with more than one flag from the given set of flags. func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { @@ -71,9 +88,11 @@ func (c *Command) validateFlagGroups() error { // then a map of each flag name and whether it is set or not. groupStatus := map[string]map[string]bool{} mutuallyExclusiveGroupStatus := map[string]map[string]bool{} + mutuallyExclusiveRequiredGroupStatus := map[string]map[string]bool{} flags.VisitAll(func(pflag *flag.Flag) { processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) + processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveRequired, mutuallyExclusiveRequiredGroupStatus) }) if err := validateRequiredFlagGroups(groupStatus); err != nil { @@ -82,6 +101,9 @@ func (c *Command) validateFlagGroups() error { if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil { return err } + if err := validateExclusiveRequiredFlagGroups(mutuallyExclusiveRequiredGroupStatus); err != nil { + return err + } return nil } @@ -162,6 +184,27 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error { return nil } +func validateExclusiveRequiredFlagGroups(data map[string]map[string]bool) error { + keys := sortedKeys(data) + for _, flagList := range keys { + flagnameAndStatus := data[flagList] + var set []string + for flagname, isSet := range flagnameAndStatus { + if isSet { + set = append(set, flagname) + } + } + if len(set) == 1 { + continue + } + + // Sort values, so they can be tested/scripted against consistently. + sort.Strings(set) + return fmt.Errorf("only one flag in the group [%v] can be set; %v were all set", flagList, set) + } + return nil +} + func sortedKeys(m map[string]map[string]bool) []string { keys := make([]string, len(m)) i := 0 @@ -185,9 +228,11 @@ func (c *Command) enforceFlagGroupsForCompletion() { flags := c.Flags() groupStatus := map[string]map[string]bool{} mutuallyExclusiveGroupStatus := map[string]map[string]bool{} + mutuallyExclusiveRequiredGroupStatus := map[string]map[string]bool{} c.Flags().VisitAll(func(pflag *flag.Flag) { processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) + processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveRequired, mutuallyExclusiveRequiredGroupStatus) }) // If a flag that is part of a group is present, we make all the other flags