Added flags groups which are mutually exclusive but required

Fixed method naming
This commit is contained in:
Himadri Bhattacharjee 2022-07-02 03:34:23 +00:00
parent 69083f81b2
commit 15e2ae2ceb

View file

@ -24,6 +24,7 @@ import (
const ( const (
requiredAsGroup = "cobra_annotation_required_if_others_set" requiredAsGroup = "cobra_annotation_required_if_others_set"
mutuallyExclusive = "cobra_annotation_mutually_exclusive" mutuallyExclusive = "cobra_annotation_mutually_exclusive"
mutuallyExclusiveRequired = "cobra_annotation_mutually_exclusive_required"
) )
// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors // MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
@ -42,6 +43,22 @@ func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
} }
} }
// MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors
// if the command is invoked with more than one flag from the given set of flags.
func (c *Command) MarkFlagsMutuallyExclusiveRequired(flagNames ...string) {
c.mergePersistentFlags()
for _, v := range flagNames {
f := c.Flags().Lookup(v)
if f == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive required flag group", v))
}
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
if err := c.Flags().SetAnnotation(v, mutuallyExclusiveRequired, append(f.Annotations[mutuallyExclusiveRequired], strings.Join(flagNames, " "))); err != nil {
panic(err)
}
}
}
// MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors // MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors
// if the command is invoked with more than one flag from the given set of flags. // if the command is invoked with more than one flag from the given set of flags.
func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
@ -71,9 +88,11 @@ func (c *Command) validateFlagGroups() error {
// then a map of each flag name and whether it is set or not. // then a map of each flag name and whether it is set or not.
groupStatus := map[string]map[string]bool{} groupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{} mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveRequiredGroupStatus := map[string]map[string]bool{}
flags.VisitAll(func(pflag *flag.Flag) { flags.VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveRequired, mutuallyExclusiveRequiredGroupStatus)
}) })
if err := validateRequiredFlagGroups(groupStatus); err != nil { if err := validateRequiredFlagGroups(groupStatus); err != nil {
@ -82,6 +101,9 @@ func (c *Command) validateFlagGroups() error {
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil { if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
return err return err
} }
if err := validateExclusiveRequiredFlagGroups(mutuallyExclusiveRequiredGroupStatus); err != nil {
return err
}
return nil return nil
} }
@ -162,6 +184,27 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
return nil return nil
} }
func validateExclusiveRequiredFlagGroups(data map[string]map[string]bool) error {
keys := sortedKeys(data)
for _, flagList := range keys {
flagnameAndStatus := data[flagList]
var set []string
for flagname, isSet := range flagnameAndStatus {
if isSet {
set = append(set, flagname)
}
}
if len(set) == 1 {
continue
}
// Sort values, so they can be tested/scripted against consistently.
sort.Strings(set)
return fmt.Errorf("only one flag in the group [%v] can be set; %v were all set", flagList, set)
}
return nil
}
func sortedKeys(m map[string]map[string]bool) []string { func sortedKeys(m map[string]map[string]bool) []string {
keys := make([]string, len(m)) keys := make([]string, len(m))
i := 0 i := 0
@ -185,9 +228,11 @@ func (c *Command) enforceFlagGroupsForCompletion() {
flags := c.Flags() flags := c.Flags()
groupStatus := map[string]map[string]bool{} groupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{} mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveRequiredGroupStatus := map[string]map[string]bool{}
c.Flags().VisitAll(func(pflag *flag.Flag) { c.Flags().VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveRequired, mutuallyExclusiveRequiredGroupStatus)
}) })
// If a flag that is part of a group is present, we make all the other flags // If a flag that is part of a group is present, we make all the other flags