This commit is contained in:
Nir Soffer 2024-05-03 18:27:07 +00:00 committed by GitHub
commit d247184674
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -23,9 +23,9 @@ import (
) )
const ( const (
requiredAsGroup = "cobra_annotation_required_if_others_set" requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set"
oneRequired = "cobra_annotation_one_required" oneRequiredAnnotation = "cobra_annotation_one_required"
mutuallyExclusive = "cobra_annotation_mutually_exclusive" mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive"
) )
// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors // MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
@ -37,7 +37,7 @@ func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
if f == nil { if f == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v)) panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v))
} }
if err := c.Flags().SetAnnotation(v, requiredAsGroup, append(f.Annotations[requiredAsGroup], strings.Join(flagNames, " "))); err != nil { if err := c.Flags().SetAnnotation(v, requiredAsGroupAnnotation, append(f.Annotations[requiredAsGroupAnnotation], strings.Join(flagNames, " "))); err != nil {
// Only errs if the flag isn't found. // Only errs if the flag isn't found.
panic(err) panic(err)
} }
@ -53,7 +53,7 @@ func (c *Command) MarkFlagsOneRequired(flagNames ...string) {
if f == nil { if f == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v)) panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v))
} }
if err := c.Flags().SetAnnotation(v, oneRequired, append(f.Annotations[oneRequired], strings.Join(flagNames, " "))); err != nil { if err := c.Flags().SetAnnotation(v, oneRequiredAnnotation, append(f.Annotations[oneRequiredAnnotation], strings.Join(flagNames, " "))); err != nil {
// Only errs if the flag isn't found. // Only errs if the flag isn't found.
panic(err) panic(err)
} }
@ -70,7 +70,7 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v)) panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))
} }
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
if err := c.Flags().SetAnnotation(v, mutuallyExclusive, append(f.Annotations[mutuallyExclusive], strings.Join(flagNames, " "))); err != nil { if err := c.Flags().SetAnnotation(v, mutuallyExclusiveAnnotation, append(f.Annotations[mutuallyExclusiveAnnotation], strings.Join(flagNames, " "))); err != nil {
panic(err) panic(err)
} }
} }
@ -91,9 +91,9 @@ func (c *Command) ValidateFlagGroups() error {
oneRequiredGroupStatus := map[string]map[string]bool{} oneRequiredGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{} mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
flags.VisitAll(func(pflag *flag.Flag) { flags.VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus) processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
}) })
if err := validateRequiredFlagGroups(groupStatus); err != nil { if err := validateRequiredFlagGroups(groupStatus); err != nil {
@ -232,9 +232,9 @@ func (c *Command) enforceFlagGroupsForCompletion() {
oneRequiredGroupStatus := map[string]map[string]bool{} oneRequiredGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{} mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
c.Flags().VisitAll(func(pflag *flag.Flag) { c.Flags().VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus) processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
}) })
// If a flag that is part of a group is present, we make all the other flags // If a flag that is part of a group is present, we make all the other flags