diff --git a/flag_groups.go b/flag_groups.go index 49721430..62eb6404 100644 --- a/flag_groups.go +++ b/flag_groups.go @@ -23,10 +23,10 @@ import ( ) const ( - requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set" - oneRequiredAnnotation = "cobra_annotation_one_required" - mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive" - ifPresentThenOthersRequiredAnnotation = "cobra_annotation_if_present_then_others_required" + annotationGroupRequired = "cobra_annotation_required_if_others_set" + annotationRequiredOne = "cobra_annotation_one_required" + annotationMutuallyExclusive = "cobra_annotation_mutually_exclusive" + annotationGroupDependent = "cobra_annotation_if_present_then_others_required" ) // MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors @@ -38,7 +38,7 @@ func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) { if f == nil { panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v)) } - if err := c.Flags().SetAnnotation(v, requiredAsGroupAnnotation, append(f.Annotations[requiredAsGroupAnnotation], strings.Join(flagNames, " "))); err != nil { + if err := c.Flags().SetAnnotation(v, annotationGroupRequired, append(f.Annotations[annotationGroupRequired], strings.Join(flagNames, " "))); err != nil { // Only errs if the flag isn't found. panic(err) } @@ -54,7 +54,7 @@ func (c *Command) MarkFlagsOneRequired(flagNames ...string) { if f == nil { panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v)) } - if err := c.Flags().SetAnnotation(v, oneRequiredAnnotation, append(f.Annotations[oneRequiredAnnotation], strings.Join(flagNames, " "))); err != nil { + if err := c.Flags().SetAnnotation(v, annotationRequiredOne, append(f.Annotations[annotationRequiredOne], strings.Join(flagNames, " "))); err != nil { // Only errs if the flag isn't found. panic(err) } @@ -71,7 +71,7 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v)) } // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. - if err := c.Flags().SetAnnotation(v, mutuallyExclusiveAnnotation, append(f.Annotations[mutuallyExclusiveAnnotation], strings.Join(flagNames, " "))); err != nil { + if err := c.Flags().SetAnnotation(v, annotationMutuallyExclusive, append(f.Annotations[annotationMutuallyExclusive], strings.Join(flagNames, " "))); err != nil { panic(err) } } @@ -90,7 +90,7 @@ func (c *Command) MarkIfFlagPresentThenOthersRequired(flagNames ...string) { panic(fmt.Sprintf("Failed to find flag %q and mark it as being in an if present then others required flag group", v)) } // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. - if err := c.Flags().SetAnnotation(v, ifPresentThenOthersRequiredAnnotation, append(f.Annotations[ifPresentThenOthersRequiredAnnotation], strings.Join(flagNames, " "))); err != nil { + if err := c.Flags().SetAnnotation(v, annotationGroupDependent, append(f.Annotations[annotationGroupDependent], strings.Join(flagNames, " "))); err != nil { panic(err) } } @@ -112,10 +112,10 @@ func (c *Command) ValidateFlagGroups() error { mutuallyExclusiveGroupStatus := map[string]map[string]bool{} ifPresentThenOthersRequiredGroupStatus := map[string]map[string]bool{} flags.VisitAll(func(pflag *flag.Flag) { - processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus) - processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus) - processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus) - processFlagForGroupAnnotation(flags, pflag, ifPresentThenOthersRequiredAnnotation, ifPresentThenOthersRequiredGroupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationGroupRequired, groupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationRequiredOne, oneRequiredGroupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationMutuallyExclusive, mutuallyExclusiveGroupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationGroupDependent, ifPresentThenOthersRequiredGroupStatus) }) if err := validateRequiredFlagGroups(groupStatus); err != nil { @@ -291,10 +291,10 @@ func (c *Command) enforceFlagGroupsForCompletion() { mutuallyExclusiveGroupStatus := map[string]map[string]bool{} ifPresentThenRequiredGroupStatus := map[string]map[string]bool{} c.Flags().VisitAll(func(pflag *flag.Flag) { - processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus) - processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus) - processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus) - processFlagForGroupAnnotation(flags, pflag, ifPresentThenOthersRequiredAnnotation, ifPresentThenRequiredGroupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationGroupRequired, groupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationRequiredOne, oneRequiredGroupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationMutuallyExclusive, mutuallyExclusiveGroupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationGroupDependent, ifPresentThenRequiredGroupStatus) }) // If a flag that is part of a group is present, we make all the other flags diff --git a/flag_groups_test.go b/flag_groups_test.go index 158a138c..7b602bc9 100644 --- a/flag_groups_test.go +++ b/flag_groups_test.go @@ -213,3 +213,72 @@ func TestValidateFlagGroups(t *testing.T) { }) } } + +func TestMarkIfFlagPresentThenOthersRequiredAnnotations(t *testing.T) { + // Create a new command with some flags. + cmd := &Command{ + Use: "testcmd", + } + f := cmd.Flags() + f.String("a", "", "flag a") + f.String("b", "", "flag b") + f.String("c", "", "flag c") + + // Call the function with one group: ["a", "b"]. + cmd.MarkIfFlagPresentThenOthersRequired("a", "b") + + // Check that flag "a" has the correct annotation. + aFlag := f.Lookup("a") + if aFlag == nil { + t.Fatal("Flag 'a' not found") + } + annA := aFlag.Annotations[annotationGroupDependent] + expected1 := "a b" // since strings.Join(["a","b"], " ") yields "a b" + if len(annA) != 1 || annA[0] != expected1 { + t.Errorf("Expected flag 'a' annotation to be [%q], got %v", expected1, annA) + } + + // Also check that flag "b" has the correct annotation. + bFlag := f.Lookup("b") + if bFlag == nil { + t.Fatal("Flag 'b' not found") + } + annB := bFlag.Annotations[annotationGroupDependent] + if len(annB) != 1 || annB[0] != expected1 { + t.Errorf("Expected flag 'b' annotation to be [%q], got %v", expected1, annB) + } + + // Now, call MarkIfFlagPresentThenOthersRequired again with a different group involving "a" and "c". + cmd.MarkIfFlagPresentThenOthersRequired("a", "c") + + // The annotation for flag "a" should now have both groups: "a b" and "a c" + annA = aFlag.Annotations[annotationGroupDependent] + expectedAnnotations := []string{"a b", "a c"} + if len(annA) != 2 { + t.Errorf("Expected 2 annotations on flag 'a', got %v", annA) + } + // Check that both expected annotation strings are present. + for _, expected := range expectedAnnotations { + found := false + for _, ann := range annA { + if ann == expected { + found = true + break + } + } + if !found { + t.Errorf("Expected annotation %q not found on flag 'a': %v", expected, annA) + } + } + + // Similarly, check that flag "c" now has the annotation "a c". + cFlag := f.Lookup("c") + if cFlag == nil { + t.Fatal("Flag 'c' not found") + } + annC := cFlag.Annotations[annotationGroupDependent] + expected2 := "a c" + if len(annC) != 1 || annC[0] != expected2 { + t.Errorf("Expected flag 'c' annotation to be [%q], got %v", expected2, annC) + } +}