mirror of
https://github.com/spf13/cobra
synced 2024-11-24 14:47:12 +00:00
Added flags groups which are mutually exclusive but required
Fixed method naming
This commit is contained in:
parent
69083f81b2
commit
15e2ae2ceb
1 changed files with 47 additions and 2 deletions
|
@ -22,8 +22,9 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
requiredAsGroup = "cobra_annotation_required_if_others_set"
|
||||
mutuallyExclusive = "cobra_annotation_mutually_exclusive"
|
||||
requiredAsGroup = "cobra_annotation_required_if_others_set"
|
||||
mutuallyExclusive = "cobra_annotation_mutually_exclusive"
|
||||
mutuallyExclusiveRequired = "cobra_annotation_mutually_exclusive_required"
|
||||
)
|
||||
|
||||
// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
|
||||
|
@ -42,6 +43,22 @@ func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
|
|||
}
|
||||
}
|
||||
|
||||
// MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors
|
||||
// if the command is invoked with more than one flag from the given set of flags.
|
||||
func (c *Command) MarkFlagsMutuallyExclusiveRequired(flagNames ...string) {
|
||||
c.mergePersistentFlags()
|
||||
for _, v := range flagNames {
|
||||
f := c.Flags().Lookup(v)
|
||||
if f == nil {
|
||||
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive required flag group", v))
|
||||
}
|
||||
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
|
||||
if err := c.Flags().SetAnnotation(v, mutuallyExclusiveRequired, append(f.Annotations[mutuallyExclusiveRequired], strings.Join(flagNames, " "))); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors
|
||||
// if the command is invoked with more than one flag from the given set of flags.
|
||||
func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
|
||||
|
@ -71,9 +88,11 @@ func (c *Command) validateFlagGroups() error {
|
|||
// then a map of each flag name and whether it is set or not.
|
||||
groupStatus := map[string]map[string]bool{}
|
||||
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
|
||||
mutuallyExclusiveRequiredGroupStatus := map[string]map[string]bool{}
|
||||
flags.VisitAll(func(pflag *flag.Flag) {
|
||||
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
|
||||
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
|
||||
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveRequired, mutuallyExclusiveRequiredGroupStatus)
|
||||
})
|
||||
|
||||
if err := validateRequiredFlagGroups(groupStatus); err != nil {
|
||||
|
@ -82,6 +101,9 @@ func (c *Command) validateFlagGroups() error {
|
|||
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateExclusiveRequiredFlagGroups(mutuallyExclusiveRequiredGroupStatus); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -162,6 +184,27 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func validateExclusiveRequiredFlagGroups(data map[string]map[string]bool) error {
|
||||
keys := sortedKeys(data)
|
||||
for _, flagList := range keys {
|
||||
flagnameAndStatus := data[flagList]
|
||||
var set []string
|
||||
for flagname, isSet := range flagnameAndStatus {
|
||||
if isSet {
|
||||
set = append(set, flagname)
|
||||
}
|
||||
}
|
||||
if len(set) == 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Sort values, so they can be tested/scripted against consistently.
|
||||
sort.Strings(set)
|
||||
return fmt.Errorf("only one flag in the group [%v] can be set; %v were all set", flagList, set)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sortedKeys(m map[string]map[string]bool) []string {
|
||||
keys := make([]string, len(m))
|
||||
i := 0
|
||||
|
@ -185,9 +228,11 @@ func (c *Command) enforceFlagGroupsForCompletion() {
|
|||
flags := c.Flags()
|
||||
groupStatus := map[string]map[string]bool{}
|
||||
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
|
||||
mutuallyExclusiveRequiredGroupStatus := map[string]map[string]bool{}
|
||||
c.Flags().VisitAll(func(pflag *flag.Flag) {
|
||||
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
|
||||
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
|
||||
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveRequired, mutuallyExclusiveRequiredGroupStatus)
|
||||
})
|
||||
|
||||
// If a flag that is part of a group is present, we make all the other flags
|
||||
|
|
Loading…
Reference in a new issue