From ca3972ed08a92875d11b81807a211024e9df331f Mon Sep 17 00:00:00 2001 From: faizan-siddiqui <siddiqui.faizan96@gmail.com> Date: Wed, 23 Oct 2024 16:51:02 +1100 Subject: [PATCH 1/8] add MarkIfFlagPresentThenOthersRequired --- flag_groups.go | 79 ++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 76 insertions(+), 3 deletions(-) diff --git a/flag_groups.go b/flag_groups.go index 560612fd..87ae7915 100644 --- a/flag_groups.go +++ b/flag_groups.go @@ -23,9 +23,10 @@ import ( ) const ( - requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set" - oneRequiredAnnotation = "cobra_annotation_one_required" - mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive" + requiredAsGroup = "cobra_annotation_required_if_others_set" + oneRequired = "cobra_annotation_one_required" + mutuallyExclusive = "cobra_annotation_mutually_exclusive" + ifPresentThenOthersRequired = "cobra_annotation_if_present_then_others_required" ) // MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors @@ -76,6 +77,25 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { } } +// MarkIfFlagPresentThenOthersRequired marks the given flags so that if the first flag is set, +// all the other flags become required. +func (c *Command) MarkIfFlagPresentThenOthersRequired(flagNames ...string) { + if len(flagNames) < 2 { + panic("MarkIfFlagPresentThenRequired requires at least two flags") + } + c.mergePersistentFlags() + for _, v := range flagNames { + f := c.Flags().Lookup(v) + if f == nil { + panic(fmt.Sprintf("Failed to find flag %q and mark it as being in an if present then others required flag group", v)) + } + // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. + if err := c.Flags().SetAnnotation(v, ifPresentThenOthersRequired, append(f.Annotations[ifPresentThenOthersRequired], strings.Join(flagNames, " "))); err != nil { + panic(err) + } + } +} + // ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the // first error encountered. func (c *Command) ValidateFlagGroups() error { @@ -90,10 +110,12 @@ func (c *Command) ValidateFlagGroups() error { groupStatus := map[string]map[string]bool{} oneRequiredGroupStatus := map[string]map[string]bool{} mutuallyExclusiveGroupStatus := map[string]map[string]bool{} + ifPresentThenOthersRequiredGroupStatus := map[string]map[string]bool{} flags.VisitAll(func(pflag *flag.Flag) { processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus) processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus) processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus) + processFlagForGroupAnnotation(flags, pflag, ifPresentThenOthersRequired, ifPresentThenOthersRequiredGroupStatus) }) if err := validateRequiredFlagGroups(groupStatus); err != nil { @@ -105,6 +127,9 @@ func (c *Command) ValidateFlagGroups() error { if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil { return err } + if err := validateIfPresentThenRequiredFlagGroups(ifPresentThenOthersRequiredGroupStatus); err != nil { + return err + } return nil } @@ -206,6 +231,38 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error { return nil } +func validateIfPresentThenRequiredFlagGroups(data map[string]map[string]bool) error { + for flagList, flagnameAndStatus := range data { + flags := strings.Split(flagList, " ") + primaryFlag := flags[0] + remainingFlags := flags[1:] + + // Handle missing primary flag entry + if _, exists := flagnameAndStatus[primaryFlag]; !exists { + flagnameAndStatus[primaryFlag] = false + } + + // Check if the primary flag is set + if flagnameAndStatus[primaryFlag] { + var unset []string + for _, flag := range remainingFlags { + if !flagnameAndStatus[flag] { + unset = append(unset, flag) + } + } + + // If any dependent flags are unset, trigger an error + if len(unset) > 0 { + return fmt.Errorf( + "if the first flag in the group [%v] is set, all other flags must be set; the following flags are not set: %v", + flagList, unset, + ) + } + } + } + return nil +} + func sortedKeys(m map[string]map[string]bool) []string { keys := make([]string, len(m)) i := 0 @@ -221,6 +278,7 @@ func sortedKeys(m map[string]map[string]bool) []string { // - when a flag in a group is present, other flags in the group will be marked required // - when none of the flags in a one-required group are present, all flags in the group will be marked required // - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden +// - when the first flag in an if-present-then-required group is present, the second flag will be marked as required // This allows the standard completion logic to behave appropriately for flag groups func (c *Command) enforceFlagGroupsForCompletion() { if c.DisableFlagParsing { @@ -231,10 +289,12 @@ func (c *Command) enforceFlagGroupsForCompletion() { groupStatus := map[string]map[string]bool{} oneRequiredGroupStatus := map[string]map[string]bool{} mutuallyExclusiveGroupStatus := map[string]map[string]bool{} + ifPresentThenRequiredGroupStatus := map[string]map[string]bool{} c.Flags().VisitAll(func(pflag *flag.Flag) { processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus) processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus) processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus) + processFlagForGroupAnnotation(flags, pflag, ifPresentThenOthersRequired, ifPresentThenRequiredGroupStatus) }) // If a flag that is part of a group is present, we make all the other flags @@ -287,4 +347,17 @@ func (c *Command) enforceFlagGroupsForCompletion() { } } } + + // If a flag that is marked as if-present-then-required is present, make other flags in the group required + for flagList, flagnameAndStatus := range ifPresentThenRequiredGroupStatus { + flags := strings.Split(flagList, " ") + primaryFlag := flags[0] + remainingFlags := flags[1:] + + if flagnameAndStatus[primaryFlag] { + for _, fName := range remainingFlags { + _ = c.MarkFlagRequired(fName) + } + } + } } From 49443501edb6f459eb4453fc7b7d20c502ec9962 Mon Sep 17 00:00:00 2001 From: faizan-siddiqui <siddiqui.faizan96@gmail.com> Date: Wed, 23 Oct 2024 17:00:26 +1100 Subject: [PATCH 2/8] added test cases --- flag_groups_test.go | 49 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 12 deletions(-) diff --git a/flag_groups_test.go b/flag_groups_test.go index cffa8552..a1c79ce3 100644 --- a/flag_groups_test.go +++ b/flag_groups_test.go @@ -43,22 +43,25 @@ func TestValidateFlagGroups(t *testing.T) { // Each test case uses a unique command from the function above. testcases := []struct { - desc string - flagGroupsRequired []string - flagGroupsOneRequired []string - flagGroupsExclusive []string - subCmdFlagGroupsRequired []string - subCmdFlagGroupsOneRequired []string - subCmdFlagGroupsExclusive []string - args []string - expectErr string + desc string + flagGroupsRequired []string + flagGroupsOneRequired []string + flagGroupsExclusive []string + flagGroupsIfPresentThenRequired []string + subCmdFlagGroupsRequired []string + subCmdFlagGroupsOneRequired []string + subCmdFlagGroupsExclusive []string + subCmdFlagGroupsIfPresentThenRequired []string + args []string + expectErr string }{ { desc: "No flags no problem", }, { - desc: "No flags no problem even with conflicting groups", - flagGroupsRequired: []string{"a b"}, - flagGroupsExclusive: []string{"a b"}, + desc: "No flags no problem even with conflicting groups", + flagGroupsRequired: []string{"a b"}, + flagGroupsExclusive: []string{"a b"}, + flagGroupsIfPresentThenRequired: []string{"a b"}, }, { desc: "Required flag group not satisfied", flagGroupsRequired: []string{"a b c"}, @@ -74,6 +77,11 @@ func TestValidateFlagGroups(t *testing.T) { flagGroupsExclusive: []string{"a b c"}, args: []string{"--a=foo", "--b=foo"}, expectErr: "if any flags in the group [a b c] are set none of the others can be; [a b] were all set", + }, { + desc: "If present then others required flag group not satisfied", + flagGroupsIfPresentThenRequired: []string{"a b"}, + args: []string{"--a=foo"}, + expectErr: "if the first flag in the group [a b] is set, all other flags must be set; the following flags are not set: [b]", }, { desc: "Multiple required flag group not satisfied returns first error", flagGroupsRequired: []string{"a b c", "a d"}, @@ -89,6 +97,12 @@ func TestValidateFlagGroups(t *testing.T) { flagGroupsExclusive: []string{"a b c", "a d"}, args: []string{"--a=foo", "--c=foo", "--d=foo"}, expectErr: `if any flags in the group [a b c] are set none of the others can be; [a c] were all set`, + }, + { + desc: "Multiple if present then others required flag group not satisfied returns first error", + flagGroupsIfPresentThenRequired: []string{"a b", "d e"}, + args: []string{"--a=foo", "--f=foo"}, + expectErr: `if the first flag in the group [a b] is set, all other flags must be set; the following flags are not set: [b]`, }, { desc: "Validation of required groups occurs on groups in sorted order", flagGroupsRequired: []string{"a d", "a b", "a c"}, @@ -104,6 +118,11 @@ func TestValidateFlagGroups(t *testing.T) { flagGroupsExclusive: []string{"a d", "a b", "a c"}, args: []string{"--a=foo", "--b=foo", "--c=foo"}, expectErr: `if any flags in the group [a b] are set none of the others can be; [a b] were all set`, + }, { + desc: "Validation of if present then others required groups occurs on groups in sorted order", + flagGroupsIfPresentThenRequired: []string{"a d", "a b", "a c"}, + args: []string{"--a=foo"}, + expectErr: `if the first flag in the group [a b] is set, all other flags must be set; the following flags are not set: [b]`, }, { desc: "Persistent flags utilize required and exclusive groups and can fail required groups", flagGroupsRequired: []string{"a e", "e f"}, @@ -182,6 +201,12 @@ func TestValidateFlagGroups(t *testing.T) { for _, flagGroup := range tc.subCmdFlagGroupsExclusive { sub.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...) } + for _, flagGroup := range tc.flagGroupsIfPresentThenRequired { + c.MarkIfFlagPresentThenOthersRequired(strings.Split(flagGroup, " ")...) + } + for _, flagGroup := range tc.subCmdFlagGroupsIfPresentThenRequired { + sub.MarkIfFlagPresentThenOthersRequired(strings.Split(flagGroup, " ")...) + } c.SetArgs(tc.args) err := c.Execute() switch { From 23cadf7ab4524c81986a9324de19a1b6dac51ad7 Mon Sep 17 00:00:00 2001 From: faizan-siddiqui <siddiqui.faizan96@gmail.com> Date: Wed, 23 Oct 2024 17:16:12 +1100 Subject: [PATCH 3/8] update test --- flag_groups_test.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/flag_groups_test.go b/flag_groups_test.go index a1c79ce3..e704a92f 100644 --- a/flag_groups_test.go +++ b/flag_groups_test.go @@ -118,11 +118,6 @@ func TestValidateFlagGroups(t *testing.T) { flagGroupsExclusive: []string{"a d", "a b", "a c"}, args: []string{"--a=foo", "--b=foo", "--c=foo"}, expectErr: `if any flags in the group [a b] are set none of the others can be; [a b] were all set`, - }, { - desc: "Validation of if present then others required groups occurs on groups in sorted order", - flagGroupsIfPresentThenRequired: []string{"a d", "a b", "a c"}, - args: []string{"--a=foo"}, - expectErr: `if the first flag in the group [a b] is set, all other flags must be set; the following flags are not set: [b]`, }, { desc: "Persistent flags utilize required and exclusive groups and can fail required groups", flagGroupsRequired: []string{"a e", "e f"}, From ad5e8c357903f0a238c40d89e5b06ddad729e3d7 Mon Sep 17 00:00:00 2001 From: faizan-siddiqui <siddiqui.faizan96@gmail.com> Date: Thu, 16 Jan 2025 13:24:47 +1100 Subject: [PATCH 4/8] update error message to be more readable Co-authored-by: ccoVeille <3875889+ccoVeille@users.noreply.github.com> --- flag_groups.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flag_groups.go b/flag_groups.go index 87ae7915..df0cc128 100644 --- a/flag_groups.go +++ b/flag_groups.go @@ -254,8 +254,8 @@ func validateIfPresentThenRequiredFlagGroups(data map[string]map[string]bool) er // If any dependent flags are unset, trigger an error if len(unset) > 0 { return fmt.Errorf( - "if the first flag in the group [%v] is set, all other flags must be set; the following flags are not set: %v", - flagList, unset, + "%v is set, the following flags must be provided: %v", + flagList[0], unset, ) } } From cb45935c4b2ce59b0bdfd0259459762ad1473802 Mon Sep 17 00:00:00 2001 From: faizan-siddiqui <siddiqui.faizan96@gmail.com> Date: Thu, 16 Jan 2025 13:25:06 +1100 Subject: [PATCH 5/8] add more verbose test cases Co-authored-by: ccoVeille <3875889+ccoVeille@users.noreply.github.com> --- flag_groups_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flag_groups_test.go b/flag_groups_test.go index e704a92f..8178603f 100644 --- a/flag_groups_test.go +++ b/flag_groups_test.go @@ -61,7 +61,7 @@ func TestValidateFlagGroups(t *testing.T) { desc: "No flags no problem even with conflicting groups", flagGroupsRequired: []string{"a b"}, flagGroupsExclusive: []string{"a b"}, - flagGroupsIfPresentThenRequired: []string{"a b"}, + flagGroupsIfPresentThenRequired: []string{"a b", "b a"}, }, { desc: "Required flag group not satisfied", flagGroupsRequired: []string{"a b c"}, From 1c44c42794b98678896beaa00fb0334d6a947bca Mon Sep 17 00:00:00 2001 From: faizan-siddiqui <siddiqui.faizan96@gmail.com> Date: Thu, 16 Jan 2025 13:25:44 +1100 Subject: [PATCH 6/8] fix comment to say other flags will be marked required instead of second flag Co-authored-by: ccoVeille <3875889+ccoVeille@users.noreply.github.com> --- flag_groups.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flag_groups.go b/flag_groups.go index df0cc128..4a0875bd 100644 --- a/flag_groups.go +++ b/flag_groups.go @@ -278,7 +278,7 @@ func sortedKeys(m map[string]map[string]bool) []string { // - when a flag in a group is present, other flags in the group will be marked required // - when none of the flags in a one-required group are present, all flags in the group will be marked required // - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden -// - when the first flag in an if-present-then-required group is present, the second flag will be marked as required +// - when the first flag in an if-present-then-required group is present, the other flags will be marked as required // This allows the standard completion logic to behave appropriately for flag groups func (c *Command) enforceFlagGroupsForCompletion() { if c.DisableFlagParsing { From dd7bc42aa87314192b6926dd5ab7f4d9dd06f451 Mon Sep 17 00:00:00 2001 From: faizan-siddiqui <siddiqui.faizan96@gmail.com> Date: Thu, 16 Jan 2025 02:39:38 +0000 Subject: [PATCH 7/8] fix test error message --- flag_groups.go | 18 +++++++++--------- flag_groups_test.go | 4 ++-- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/flag_groups.go b/flag_groups.go index 4a0875bd..49721430 100644 --- a/flag_groups.go +++ b/flag_groups.go @@ -23,10 +23,10 @@ import ( ) const ( - requiredAsGroup = "cobra_annotation_required_if_others_set" - oneRequired = "cobra_annotation_one_required" - mutuallyExclusive = "cobra_annotation_mutually_exclusive" - ifPresentThenOthersRequired = "cobra_annotation_if_present_then_others_required" + requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set" + oneRequiredAnnotation = "cobra_annotation_one_required" + mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive" + ifPresentThenOthersRequiredAnnotation = "cobra_annotation_if_present_then_others_required" ) // MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors @@ -90,7 +90,7 @@ func (c *Command) MarkIfFlagPresentThenOthersRequired(flagNames ...string) { panic(fmt.Sprintf("Failed to find flag %q and mark it as being in an if present then others required flag group", v)) } // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. - if err := c.Flags().SetAnnotation(v, ifPresentThenOthersRequired, append(f.Annotations[ifPresentThenOthersRequired], strings.Join(flagNames, " "))); err != nil { + if err := c.Flags().SetAnnotation(v, ifPresentThenOthersRequiredAnnotation, append(f.Annotations[ifPresentThenOthersRequiredAnnotation], strings.Join(flagNames, " "))); err != nil { panic(err) } } @@ -115,7 +115,7 @@ func (c *Command) ValidateFlagGroups() error { processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus) processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus) processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus) - processFlagForGroupAnnotation(flags, pflag, ifPresentThenOthersRequired, ifPresentThenOthersRequiredGroupStatus) + processFlagForGroupAnnotation(flags, pflag, ifPresentThenOthersRequiredAnnotation, ifPresentThenOthersRequiredGroupStatus) }) if err := validateRequiredFlagGroups(groupStatus); err != nil { @@ -255,7 +255,7 @@ func validateIfPresentThenRequiredFlagGroups(data map[string]map[string]bool) er if len(unset) > 0 { return fmt.Errorf( "%v is set, the following flags must be provided: %v", - flagList[0], unset, + primaryFlag, unset, ) } } @@ -294,7 +294,7 @@ func (c *Command) enforceFlagGroupsForCompletion() { processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus) processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus) processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus) - processFlagForGroupAnnotation(flags, pflag, ifPresentThenOthersRequired, ifPresentThenRequiredGroupStatus) + processFlagForGroupAnnotation(flags, pflag, ifPresentThenOthersRequiredAnnotation, ifPresentThenRequiredGroupStatus) }) // If a flag that is part of a group is present, we make all the other flags @@ -347,7 +347,7 @@ func (c *Command) enforceFlagGroupsForCompletion() { } } } - + // If a flag that is marked as if-present-then-required is present, make other flags in the group required for flagList, flagnameAndStatus := range ifPresentThenRequiredGroupStatus { flags := strings.Split(flagList, " ") diff --git a/flag_groups_test.go b/flag_groups_test.go index 8178603f..158a138c 100644 --- a/flag_groups_test.go +++ b/flag_groups_test.go @@ -81,7 +81,7 @@ func TestValidateFlagGroups(t *testing.T) { desc: "If present then others required flag group not satisfied", flagGroupsIfPresentThenRequired: []string{"a b"}, args: []string{"--a=foo"}, - expectErr: "if the first flag in the group [a b] is set, all other flags must be set; the following flags are not set: [b]", + expectErr: "a is set, the following flags must be provided: [b]", }, { desc: "Multiple required flag group not satisfied returns first error", flagGroupsRequired: []string{"a b c", "a d"}, @@ -102,7 +102,7 @@ func TestValidateFlagGroups(t *testing.T) { desc: "Multiple if present then others required flag group not satisfied returns first error", flagGroupsIfPresentThenRequired: []string{"a b", "d e"}, args: []string{"--a=foo", "--f=foo"}, - expectErr: `if the first flag in the group [a b] is set, all other flags must be set; the following flags are not set: [b]`, + expectErr: `a is set, the following flags must be provided: [b]`, }, { desc: "Validation of required groups occurs on groups in sorted order", flagGroupsRequired: []string{"a d", "a b", "a c"}, From 578b0744b26ea16661fd7a59ab1eda64d03dd19d Mon Sep 17 00:00:00 2001 From: faizan-siddiqui <siddiqui.faizan96@gmail.com> Date: Tue, 18 Feb 2025 05:17:53 +0000 Subject: [PATCH 8/8] add explicit test for testing annotation --- flag_groups.go | 32 ++++++++++----------- flag_groups_test.go | 69 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 16 deletions(-) diff --git a/flag_groups.go b/flag_groups.go index 49721430..62eb6404 100644 --- a/flag_groups.go +++ b/flag_groups.go @@ -23,10 +23,10 @@ import ( ) const ( - requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set" - oneRequiredAnnotation = "cobra_annotation_one_required" - mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive" - ifPresentThenOthersRequiredAnnotation = "cobra_annotation_if_present_then_others_required" + annotationGroupRequired = "cobra_annotation_required_if_others_set" + annotationRequiredOne = "cobra_annotation_one_required" + annotationMutuallyExclusive = "cobra_annotation_mutually_exclusive" + annotationGroupDependent = "cobra_annotation_if_present_then_others_required" ) // MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors @@ -38,7 +38,7 @@ func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) { if f == nil { panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v)) } - if err := c.Flags().SetAnnotation(v, requiredAsGroupAnnotation, append(f.Annotations[requiredAsGroupAnnotation], strings.Join(flagNames, " "))); err != nil { + if err := c.Flags().SetAnnotation(v, annotationGroupRequired, append(f.Annotations[annotationGroupRequired], strings.Join(flagNames, " "))); err != nil { // Only errs if the flag isn't found. panic(err) } @@ -54,7 +54,7 @@ func (c *Command) MarkFlagsOneRequired(flagNames ...string) { if f == nil { panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v)) } - if err := c.Flags().SetAnnotation(v, oneRequiredAnnotation, append(f.Annotations[oneRequiredAnnotation], strings.Join(flagNames, " "))); err != nil { + if err := c.Flags().SetAnnotation(v, annotationRequiredOne, append(f.Annotations[annotationRequiredOne], strings.Join(flagNames, " "))); err != nil { // Only errs if the flag isn't found. panic(err) } @@ -71,7 +71,7 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v)) } // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. - if err := c.Flags().SetAnnotation(v, mutuallyExclusiveAnnotation, append(f.Annotations[mutuallyExclusiveAnnotation], strings.Join(flagNames, " "))); err != nil { + if err := c.Flags().SetAnnotation(v, annotationMutuallyExclusive, append(f.Annotations[annotationMutuallyExclusive], strings.Join(flagNames, " "))); err != nil { panic(err) } } @@ -90,7 +90,7 @@ func (c *Command) MarkIfFlagPresentThenOthersRequired(flagNames ...string) { panic(fmt.Sprintf("Failed to find flag %q and mark it as being in an if present then others required flag group", v)) } // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. - if err := c.Flags().SetAnnotation(v, ifPresentThenOthersRequiredAnnotation, append(f.Annotations[ifPresentThenOthersRequiredAnnotation], strings.Join(flagNames, " "))); err != nil { + if err := c.Flags().SetAnnotation(v, annotationGroupDependent, append(f.Annotations[annotationGroupDependent], strings.Join(flagNames, " "))); err != nil { panic(err) } } @@ -112,10 +112,10 @@ func (c *Command) ValidateFlagGroups() error { mutuallyExclusiveGroupStatus := map[string]map[string]bool{} ifPresentThenOthersRequiredGroupStatus := map[string]map[string]bool{} flags.VisitAll(func(pflag *flag.Flag) { - processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus) - processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus) - processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus) - processFlagForGroupAnnotation(flags, pflag, ifPresentThenOthersRequiredAnnotation, ifPresentThenOthersRequiredGroupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationGroupRequired, groupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationRequiredOne, oneRequiredGroupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationMutuallyExclusive, mutuallyExclusiveGroupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationGroupDependent, ifPresentThenOthersRequiredGroupStatus) }) if err := validateRequiredFlagGroups(groupStatus); err != nil { @@ -291,10 +291,10 @@ func (c *Command) enforceFlagGroupsForCompletion() { mutuallyExclusiveGroupStatus := map[string]map[string]bool{} ifPresentThenRequiredGroupStatus := map[string]map[string]bool{} c.Flags().VisitAll(func(pflag *flag.Flag) { - processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus) - processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus) - processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus) - processFlagForGroupAnnotation(flags, pflag, ifPresentThenOthersRequiredAnnotation, ifPresentThenRequiredGroupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationGroupRequired, groupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationRequiredOne, oneRequiredGroupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationMutuallyExclusive, mutuallyExclusiveGroupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationGroupDependent, ifPresentThenRequiredGroupStatus) }) // If a flag that is part of a group is present, we make all the other flags diff --git a/flag_groups_test.go b/flag_groups_test.go index 158a138c..7b602bc9 100644 --- a/flag_groups_test.go +++ b/flag_groups_test.go @@ -213,3 +213,72 @@ func TestValidateFlagGroups(t *testing.T) { }) } } + +func TestMarkIfFlagPresentThenOthersRequiredAnnotations(t *testing.T) { + // Create a new command with some flags. + cmd := &Command{ + Use: "testcmd", + } + f := cmd.Flags() + f.String("a", "", "flag a") + f.String("b", "", "flag b") + f.String("c", "", "flag c") + + // Call the function with one group: ["a", "b"]. + cmd.MarkIfFlagPresentThenOthersRequired("a", "b") + + // Check that flag "a" has the correct annotation. + aFlag := f.Lookup("a") + if aFlag == nil { + t.Fatal("Flag 'a' not found") + } + annA := aFlag.Annotations[annotationGroupDependent] + expected1 := "a b" // since strings.Join(["a","b"], " ") yields "a b" + if len(annA) != 1 || annA[0] != expected1 { + t.Errorf("Expected flag 'a' annotation to be [%q], got %v", expected1, annA) + } + + // Also check that flag "b" has the correct annotation. + bFlag := f.Lookup("b") + if bFlag == nil { + t.Fatal("Flag 'b' not found") + } + annB := bFlag.Annotations[annotationGroupDependent] + if len(annB) != 1 || annB[0] != expected1 { + t.Errorf("Expected flag 'b' annotation to be [%q], got %v", expected1, annB) + } + + // Now, call MarkIfFlagPresentThenOthersRequired again with a different group involving "a" and "c". + cmd.MarkIfFlagPresentThenOthersRequired("a", "c") + + // The annotation for flag "a" should now have both groups: "a b" and "a c" + annA = aFlag.Annotations[annotationGroupDependent] + expectedAnnotations := []string{"a b", "a c"} + if len(annA) != 2 { + t.Errorf("Expected 2 annotations on flag 'a', got %v", annA) + } + // Check that both expected annotation strings are present. + for _, expected := range expectedAnnotations { + found := false + for _, ann := range annA { + if ann == expected { + found = true + break + } + } + if !found { + t.Errorf("Expected annotation %q not found on flag 'a': %v", expected, annA) + } + } + + // Similarly, check that flag "c" now has the annotation "a c". + cFlag := f.Lookup("c") + if cFlag == nil { + t.Fatal("Flag 'c' not found") + } + annC := cFlag.Annotations[annotationGroupDependent] + expected2 := "a c" + if len(annC) != 1 || annC[0] != expected2 { + t.Errorf("Expected flag 'c' annotation to be [%q], got %v", expected2, annC) + } +}