1
0
Fork 0
mirror of https://github.com/spf13/cobra synced 2025-04-04 22:09:11 +00:00

add explicit test for testing annotation

This commit is contained in:
faizan-siddiqui 2025-02-18 05:17:53 +00:00
parent a9bed3a190
commit 578b0744b2
2 changed files with 85 additions and 16 deletions

View file

@ -23,10 +23,10 @@ import (
)
const (
requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set"
oneRequiredAnnotation = "cobra_annotation_one_required"
mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive"
ifPresentThenOthersRequiredAnnotation = "cobra_annotation_if_present_then_others_required"
annotationGroupRequired = "cobra_annotation_required_if_others_set"
annotationRequiredOne = "cobra_annotation_one_required"
annotationMutuallyExclusive = "cobra_annotation_mutually_exclusive"
annotationGroupDependent = "cobra_annotation_if_present_then_others_required"
)
// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
@ -38,7 +38,7 @@ func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
if f == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v))
}
if err := c.Flags().SetAnnotation(v, requiredAsGroupAnnotation, append(f.Annotations[requiredAsGroupAnnotation], strings.Join(flagNames, " "))); err != nil {
if err := c.Flags().SetAnnotation(v, annotationGroupRequired, append(f.Annotations[annotationGroupRequired], strings.Join(flagNames, " "))); err != nil {
// Only errs if the flag isn't found.
panic(err)
}
@ -54,7 +54,7 @@ func (c *Command) MarkFlagsOneRequired(flagNames ...string) {
if f == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v))
}
if err := c.Flags().SetAnnotation(v, oneRequiredAnnotation, append(f.Annotations[oneRequiredAnnotation], strings.Join(flagNames, " "))); err != nil {
if err := c.Flags().SetAnnotation(v, annotationRequiredOne, append(f.Annotations[annotationRequiredOne], strings.Join(flagNames, " "))); err != nil {
// Only errs if the flag isn't found.
panic(err)
}
@ -71,7 +71,7 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))
}
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
if err := c.Flags().SetAnnotation(v, mutuallyExclusiveAnnotation, append(f.Annotations[mutuallyExclusiveAnnotation], strings.Join(flagNames, " "))); err != nil {
if err := c.Flags().SetAnnotation(v, annotationMutuallyExclusive, append(f.Annotations[annotationMutuallyExclusive], strings.Join(flagNames, " "))); err != nil {
panic(err)
}
}
@ -90,7 +90,7 @@ func (c *Command) MarkIfFlagPresentThenOthersRequired(flagNames ...string) {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in an if present then others required flag group", v))
}
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
if err := c.Flags().SetAnnotation(v, ifPresentThenOthersRequiredAnnotation, append(f.Annotations[ifPresentThenOthersRequiredAnnotation], strings.Join(flagNames, " "))); err != nil {
if err := c.Flags().SetAnnotation(v, annotationGroupDependent, append(f.Annotations[annotationGroupDependent], strings.Join(flagNames, " "))); err != nil {
panic(err)
}
}
@ -112,10 +112,10 @@ func (c *Command) ValidateFlagGroups() error {
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
ifPresentThenOthersRequiredGroupStatus := map[string]map[string]bool{}
flags.VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
processFlagForGroupAnnotation(flags, pflag, ifPresentThenOthersRequiredAnnotation, ifPresentThenOthersRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, annotationGroupRequired, groupStatus)
processFlagForGroupAnnotation(flags, pflag, annotationRequiredOne, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, annotationMutuallyExclusive, mutuallyExclusiveGroupStatus)
processFlagForGroupAnnotation(flags, pflag, annotationGroupDependent, ifPresentThenOthersRequiredGroupStatus)
})
if err := validateRequiredFlagGroups(groupStatus); err != nil {
@ -291,10 +291,10 @@ func (c *Command) enforceFlagGroupsForCompletion() {
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
ifPresentThenRequiredGroupStatus := map[string]map[string]bool{}
c.Flags().VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
processFlagForGroupAnnotation(flags, pflag, ifPresentThenOthersRequiredAnnotation, ifPresentThenRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, annotationGroupRequired, groupStatus)
processFlagForGroupAnnotation(flags, pflag, annotationRequiredOne, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, annotationMutuallyExclusive, mutuallyExclusiveGroupStatus)
processFlagForGroupAnnotation(flags, pflag, annotationGroupDependent, ifPresentThenRequiredGroupStatus)
})
// If a flag that is part of a group is present, we make all the other flags

View file

@ -213,3 +213,72 @@ func TestValidateFlagGroups(t *testing.T) {
})
}
}
func TestMarkIfFlagPresentThenOthersRequiredAnnotations(t *testing.T) {
// Create a new command with some flags.
cmd := &Command{
Use: "testcmd",
}
f := cmd.Flags()
f.String("a", "", "flag a")
f.String("b", "", "flag b")
f.String("c", "", "flag c")
// Call the function with one group: ["a", "b"].
cmd.MarkIfFlagPresentThenOthersRequired("a", "b")
// Check that flag "a" has the correct annotation.
aFlag := f.Lookup("a")
if aFlag == nil {
t.Fatal("Flag 'a' not found")
}
annA := aFlag.Annotations[annotationGroupDependent]
expected1 := "a b" // since strings.Join(["a","b"], " ") yields "a b"
if len(annA) != 1 || annA[0] != expected1 {
t.Errorf("Expected flag 'a' annotation to be [%q], got %v", expected1, annA)
}
// Also check that flag "b" has the correct annotation.
bFlag := f.Lookup("b")
if bFlag == nil {
t.Fatal("Flag 'b' not found")
}
annB := bFlag.Annotations[annotationGroupDependent]
if len(annB) != 1 || annB[0] != expected1 {
t.Errorf("Expected flag 'b' annotation to be [%q], got %v", expected1, annB)
}
// Now, call MarkIfFlagPresentThenOthersRequired again with a different group involving "a" and "c".
cmd.MarkIfFlagPresentThenOthersRequired("a", "c")
// The annotation for flag "a" should now have both groups: "a b" and "a c"
annA = aFlag.Annotations[annotationGroupDependent]
expectedAnnotations := []string{"a b", "a c"}
if len(annA) != 2 {
t.Errorf("Expected 2 annotations on flag 'a', got %v", annA)
}
// Check that both expected annotation strings are present.
for _, expected := range expectedAnnotations {
found := false
for _, ann := range annA {
if ann == expected {
found = true
break
}
}
if !found {
t.Errorf("Expected annotation %q not found on flag 'a': %v", expected, annA)
}
}
// Similarly, check that flag "c" now has the annotation "a c".
cFlag := f.Lookup("c")
if cFlag == nil {
t.Fatal("Flag 'c' not found")
}
annC := cFlag.Annotations[annotationGroupDependent]
expected2 := "a c"
if len(annC) != 1 || annC[0] != expected2 {
t.Errorf("Expected flag 'c' annotation to be [%q], got %v", expected2, annC)
}
}