From ca3972ed08a92875d11b81807a211024e9df331f Mon Sep 17 00:00:00 2001 From: faizan-siddiqui Date: Wed, 23 Oct 2024 16:51:02 +1100 Subject: [PATCH] add MarkIfFlagPresentThenOthersRequired --- flag_groups.go | 79 ++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 76 insertions(+), 3 deletions(-) diff --git a/flag_groups.go b/flag_groups.go index 560612fd..87ae7915 100644 --- a/flag_groups.go +++ b/flag_groups.go @@ -23,9 +23,10 @@ import ( ) const ( - requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set" - oneRequiredAnnotation = "cobra_annotation_one_required" - mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive" + requiredAsGroup = "cobra_annotation_required_if_others_set" + oneRequired = "cobra_annotation_one_required" + mutuallyExclusive = "cobra_annotation_mutually_exclusive" + ifPresentThenOthersRequired = "cobra_annotation_if_present_then_others_required" ) // MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors @@ -76,6 +77,25 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { } } +// MarkIfFlagPresentThenOthersRequired marks the given flags so that if the first flag is set, +// all the other flags become required. +func (c *Command) MarkIfFlagPresentThenOthersRequired(flagNames ...string) { + if len(flagNames) < 2 { + panic("MarkIfFlagPresentThenRequired requires at least two flags") + } + c.mergePersistentFlags() + for _, v := range flagNames { + f := c.Flags().Lookup(v) + if f == nil { + panic(fmt.Sprintf("Failed to find flag %q and mark it as being in an if present then others required flag group", v)) + } + // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. + if err := c.Flags().SetAnnotation(v, ifPresentThenOthersRequired, append(f.Annotations[ifPresentThenOthersRequired], strings.Join(flagNames, " "))); err != nil { + panic(err) + } + } +} + // ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the // first error encountered. func (c *Command) ValidateFlagGroups() error { @@ -90,10 +110,12 @@ func (c *Command) ValidateFlagGroups() error { groupStatus := map[string]map[string]bool{} oneRequiredGroupStatus := map[string]map[string]bool{} mutuallyExclusiveGroupStatus := map[string]map[string]bool{} + ifPresentThenOthersRequiredGroupStatus := map[string]map[string]bool{} flags.VisitAll(func(pflag *flag.Flag) { processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus) processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus) processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus) + processFlagForGroupAnnotation(flags, pflag, ifPresentThenOthersRequired, ifPresentThenOthersRequiredGroupStatus) }) if err := validateRequiredFlagGroups(groupStatus); err != nil { @@ -105,6 +127,9 @@ func (c *Command) ValidateFlagGroups() error { if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil { return err } + if err := validateIfPresentThenRequiredFlagGroups(ifPresentThenOthersRequiredGroupStatus); err != nil { + return err + } return nil } @@ -206,6 +231,38 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error { return nil } +func validateIfPresentThenRequiredFlagGroups(data map[string]map[string]bool) error { + for flagList, flagnameAndStatus := range data { + flags := strings.Split(flagList, " ") + primaryFlag := flags[0] + remainingFlags := flags[1:] + + // Handle missing primary flag entry + if _, exists := flagnameAndStatus[primaryFlag]; !exists { + flagnameAndStatus[primaryFlag] = false + } + + // Check if the primary flag is set + if flagnameAndStatus[primaryFlag] { + var unset []string + for _, flag := range remainingFlags { + if !flagnameAndStatus[flag] { + unset = append(unset, flag) + } + } + + // If any dependent flags are unset, trigger an error + if len(unset) > 0 { + return fmt.Errorf( + "if the first flag in the group [%v] is set, all other flags must be set; the following flags are not set: %v", + flagList, unset, + ) + } + } + } + return nil +} + func sortedKeys(m map[string]map[string]bool) []string { keys := make([]string, len(m)) i := 0 @@ -221,6 +278,7 @@ func sortedKeys(m map[string]map[string]bool) []string { // - when a flag in a group is present, other flags in the group will be marked required // - when none of the flags in a one-required group are present, all flags in the group will be marked required // - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden +// - when the first flag in an if-present-then-required group is present, the second flag will be marked as required // This allows the standard completion logic to behave appropriately for flag groups func (c *Command) enforceFlagGroupsForCompletion() { if c.DisableFlagParsing { @@ -231,10 +289,12 @@ func (c *Command) enforceFlagGroupsForCompletion() { groupStatus := map[string]map[string]bool{} oneRequiredGroupStatus := map[string]map[string]bool{} mutuallyExclusiveGroupStatus := map[string]map[string]bool{} + ifPresentThenRequiredGroupStatus := map[string]map[string]bool{} c.Flags().VisitAll(func(pflag *flag.Flag) { processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus) processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus) processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus) + processFlagForGroupAnnotation(flags, pflag, ifPresentThenOthersRequired, ifPresentThenRequiredGroupStatus) }) // If a flag that is part of a group is present, we make all the other flags @@ -287,4 +347,17 @@ func (c *Command) enforceFlagGroupsForCompletion() { } } } + + // If a flag that is marked as if-present-then-required is present, make other flags in the group required + for flagList, flagnameAndStatus := range ifPresentThenRequiredGroupStatus { + flags := strings.Split(flagList, " ") + primaryFlag := flags[0] + remainingFlags := flags[1:] + + if flagnameAndStatus[primaryFlag] { + for _, fName := range remainingFlags { + _ = c.MarkFlagRequired(fName) + } + } + } }