mirror of
https://github.com/spf13/cobra
synced 2024-11-24 22:57:12 +00:00
add MarkIfFlagPresentThenOthersRequired
This commit is contained in:
parent
3a5efaede9
commit
ca3972ed08
1 changed files with 76 additions and 3 deletions
|
@ -23,9 +23,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set"
|
requiredAsGroup = "cobra_annotation_required_if_others_set"
|
||||||
oneRequiredAnnotation = "cobra_annotation_one_required"
|
oneRequired = "cobra_annotation_one_required"
|
||||||
mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive"
|
mutuallyExclusive = "cobra_annotation_mutually_exclusive"
|
||||||
|
ifPresentThenOthersRequired = "cobra_annotation_if_present_then_others_required"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
|
// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
|
||||||
|
@ -76,6 +77,25 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarkIfFlagPresentThenOthersRequired marks the given flags so that if the first flag is set,
|
||||||
|
// all the other flags become required.
|
||||||
|
func (c *Command) MarkIfFlagPresentThenOthersRequired(flagNames ...string) {
|
||||||
|
if len(flagNames) < 2 {
|
||||||
|
panic("MarkIfFlagPresentThenRequired requires at least two flags")
|
||||||
|
}
|
||||||
|
c.mergePersistentFlags()
|
||||||
|
for _, v := range flagNames {
|
||||||
|
f := c.Flags().Lookup(v)
|
||||||
|
if f == nil {
|
||||||
|
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in an if present then others required flag group", v))
|
||||||
|
}
|
||||||
|
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
|
||||||
|
if err := c.Flags().SetAnnotation(v, ifPresentThenOthersRequired, append(f.Annotations[ifPresentThenOthersRequired], strings.Join(flagNames, " "))); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the
|
// ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the
|
||||||
// first error encountered.
|
// first error encountered.
|
||||||
func (c *Command) ValidateFlagGroups() error {
|
func (c *Command) ValidateFlagGroups() error {
|
||||||
|
@ -90,10 +110,12 @@ func (c *Command) ValidateFlagGroups() error {
|
||||||
groupStatus := map[string]map[string]bool{}
|
groupStatus := map[string]map[string]bool{}
|
||||||
oneRequiredGroupStatus := map[string]map[string]bool{}
|
oneRequiredGroupStatus := map[string]map[string]bool{}
|
||||||
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
|
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
|
||||||
|
ifPresentThenOthersRequiredGroupStatus := map[string]map[string]bool{}
|
||||||
flags.VisitAll(func(pflag *flag.Flag) {
|
flags.VisitAll(func(pflag *flag.Flag) {
|
||||||
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
|
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
|
||||||
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
|
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
|
||||||
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
|
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
|
||||||
|
processFlagForGroupAnnotation(flags, pflag, ifPresentThenOthersRequired, ifPresentThenOthersRequiredGroupStatus)
|
||||||
})
|
})
|
||||||
|
|
||||||
if err := validateRequiredFlagGroups(groupStatus); err != nil {
|
if err := validateRequiredFlagGroups(groupStatus); err != nil {
|
||||||
|
@ -105,6 +127,9 @@ func (c *Command) ValidateFlagGroups() error {
|
||||||
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
|
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if err := validateIfPresentThenRequiredFlagGroups(ifPresentThenOthersRequiredGroupStatus); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -206,6 +231,38 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateIfPresentThenRequiredFlagGroups(data map[string]map[string]bool) error {
|
||||||
|
for flagList, flagnameAndStatus := range data {
|
||||||
|
flags := strings.Split(flagList, " ")
|
||||||
|
primaryFlag := flags[0]
|
||||||
|
remainingFlags := flags[1:]
|
||||||
|
|
||||||
|
// Handle missing primary flag entry
|
||||||
|
if _, exists := flagnameAndStatus[primaryFlag]; !exists {
|
||||||
|
flagnameAndStatus[primaryFlag] = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the primary flag is set
|
||||||
|
if flagnameAndStatus[primaryFlag] {
|
||||||
|
var unset []string
|
||||||
|
for _, flag := range remainingFlags {
|
||||||
|
if !flagnameAndStatus[flag] {
|
||||||
|
unset = append(unset, flag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If any dependent flags are unset, trigger an error
|
||||||
|
if len(unset) > 0 {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"if the first flag in the group [%v] is set, all other flags must be set; the following flags are not set: %v",
|
||||||
|
flagList, unset,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func sortedKeys(m map[string]map[string]bool) []string {
|
func sortedKeys(m map[string]map[string]bool) []string {
|
||||||
keys := make([]string, len(m))
|
keys := make([]string, len(m))
|
||||||
i := 0
|
i := 0
|
||||||
|
@ -221,6 +278,7 @@ func sortedKeys(m map[string]map[string]bool) []string {
|
||||||
// - when a flag in a group is present, other flags in the group will be marked required
|
// - when a flag in a group is present, other flags in the group will be marked required
|
||||||
// - when none of the flags in a one-required group are present, all flags in the group will be marked required
|
// - when none of the flags in a one-required group are present, all flags in the group will be marked required
|
||||||
// - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
|
// - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
|
||||||
|
// - when the first flag in an if-present-then-required group is present, the second flag will be marked as required
|
||||||
// This allows the standard completion logic to behave appropriately for flag groups
|
// This allows the standard completion logic to behave appropriately for flag groups
|
||||||
func (c *Command) enforceFlagGroupsForCompletion() {
|
func (c *Command) enforceFlagGroupsForCompletion() {
|
||||||
if c.DisableFlagParsing {
|
if c.DisableFlagParsing {
|
||||||
|
@ -231,10 +289,12 @@ func (c *Command) enforceFlagGroupsForCompletion() {
|
||||||
groupStatus := map[string]map[string]bool{}
|
groupStatus := map[string]map[string]bool{}
|
||||||
oneRequiredGroupStatus := map[string]map[string]bool{}
|
oneRequiredGroupStatus := map[string]map[string]bool{}
|
||||||
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
|
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
|
||||||
|
ifPresentThenRequiredGroupStatus := map[string]map[string]bool{}
|
||||||
c.Flags().VisitAll(func(pflag *flag.Flag) {
|
c.Flags().VisitAll(func(pflag *flag.Flag) {
|
||||||
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
|
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
|
||||||
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
|
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
|
||||||
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
|
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
|
||||||
|
processFlagForGroupAnnotation(flags, pflag, ifPresentThenOthersRequired, ifPresentThenRequiredGroupStatus)
|
||||||
})
|
})
|
||||||
|
|
||||||
// If a flag that is part of a group is present, we make all the other flags
|
// If a flag that is part of a group is present, we make all the other flags
|
||||||
|
@ -287,4 +347,17 @@ func (c *Command) enforceFlagGroupsForCompletion() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If a flag that is marked as if-present-then-required is present, make other flags in the group required
|
||||||
|
for flagList, flagnameAndStatus := range ifPresentThenRequiredGroupStatus {
|
||||||
|
flags := strings.Split(flagList, " ")
|
||||||
|
primaryFlag := flags[0]
|
||||||
|
remainingFlags := flags[1:]
|
||||||
|
|
||||||
|
if flagnameAndStatus[primaryFlag] {
|
||||||
|
for _, fName := range remainingFlags {
|
||||||
|
_ = c.MarkFlagRequired(fName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue