mirror of
https://github.com/spf13/cobra
synced 2025-04-02 04:49:17 +00:00
Merge 8d0e71e82e
into ceb39aba25
This commit is contained in:
commit
9d80823f30
2 changed files with 186 additions and 24 deletions
|
@ -23,9 +23,10 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set"
|
||||
oneRequiredAnnotation = "cobra_annotation_one_required"
|
||||
mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive"
|
||||
annotationGroupRequired = "cobra_annotation_required_if_others_set"
|
||||
annotationRequiredOne = "cobra_annotation_one_required"
|
||||
annotationMutuallyExclusive = "cobra_annotation_mutually_exclusive"
|
||||
annotationGroupDependent = "cobra_annotation_if_present_then_others_required"
|
||||
)
|
||||
|
||||
// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
|
||||
|
@ -37,7 +38,7 @@ func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
|
|||
if f == nil {
|
||||
panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v))
|
||||
}
|
||||
if err := c.Flags().SetAnnotation(v, requiredAsGroupAnnotation, append(f.Annotations[requiredAsGroupAnnotation], strings.Join(flagNames, " "))); err != nil {
|
||||
if err := c.Flags().SetAnnotation(v, annotationGroupRequired, append(f.Annotations[annotationGroupRequired], strings.Join(flagNames, " "))); err != nil {
|
||||
// Only errs if the flag isn't found.
|
||||
panic(err)
|
||||
}
|
||||
|
@ -53,7 +54,7 @@ func (c *Command) MarkFlagsOneRequired(flagNames ...string) {
|
|||
if f == nil {
|
||||
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v))
|
||||
}
|
||||
if err := c.Flags().SetAnnotation(v, oneRequiredAnnotation, append(f.Annotations[oneRequiredAnnotation], strings.Join(flagNames, " "))); err != nil {
|
||||
if err := c.Flags().SetAnnotation(v, annotationRequiredOne, append(f.Annotations[annotationRequiredOne], strings.Join(flagNames, " "))); err != nil {
|
||||
// Only errs if the flag isn't found.
|
||||
panic(err)
|
||||
}
|
||||
|
@ -70,7 +71,26 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
|
|||
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))
|
||||
}
|
||||
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
|
||||
if err := c.Flags().SetAnnotation(v, mutuallyExclusiveAnnotation, append(f.Annotations[mutuallyExclusiveAnnotation], strings.Join(flagNames, " "))); err != nil {
|
||||
if err := c.Flags().SetAnnotation(v, annotationMutuallyExclusive, append(f.Annotations[annotationMutuallyExclusive], strings.Join(flagNames, " "))); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MarkIfFlagPresentThenOthersRequired marks the given flags so that if the first flag is set,
|
||||
// all the other flags become required.
|
||||
func (c *Command) MarkIfFlagPresentThenOthersRequired(flagNames ...string) {
|
||||
if len(flagNames) < 2 {
|
||||
panic("MarkIfFlagPresentThenRequired requires at least two flags")
|
||||
}
|
||||
c.mergePersistentFlags()
|
||||
for _, v := range flagNames {
|
||||
f := c.Flags().Lookup(v)
|
||||
if f == nil {
|
||||
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in an if present then others required flag group", v))
|
||||
}
|
||||
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
|
||||
if err := c.Flags().SetAnnotation(v, annotationGroupDependent, append(f.Annotations[annotationGroupDependent], strings.Join(flagNames, " "))); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
@ -90,10 +110,12 @@ func (c *Command) ValidateFlagGroups() error {
|
|||
groupStatus := map[string]map[string]bool{}
|
||||
oneRequiredGroupStatus := map[string]map[string]bool{}
|
||||
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
|
||||
ifPresentThenOthersRequiredGroupStatus := map[string]map[string]bool{}
|
||||
flags.VisitAll(func(pflag *flag.Flag) {
|
||||
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
|
||||
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
|
||||
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
|
||||
processFlagForGroupAnnotation(flags, pflag, annotationGroupRequired, groupStatus)
|
||||
processFlagForGroupAnnotation(flags, pflag, annotationRequiredOne, oneRequiredGroupStatus)
|
||||
processFlagForGroupAnnotation(flags, pflag, annotationMutuallyExclusive, mutuallyExclusiveGroupStatus)
|
||||
processFlagForGroupAnnotation(flags, pflag, annotationGroupDependent, ifPresentThenOthersRequiredGroupStatus)
|
||||
})
|
||||
|
||||
if err := validateRequiredFlagGroups(groupStatus); err != nil {
|
||||
|
@ -105,6 +127,9 @@ func (c *Command) ValidateFlagGroups() error {
|
|||
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateIfPresentThenRequiredFlagGroups(ifPresentThenOthersRequiredGroupStatus); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -206,6 +231,38 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func validateIfPresentThenRequiredFlagGroups(data map[string]map[string]bool) error {
|
||||
for flagList, flagnameAndStatus := range data {
|
||||
flags := strings.Split(flagList, " ")
|
||||
primaryFlag := flags[0]
|
||||
remainingFlags := flags[1:]
|
||||
|
||||
// Handle missing primary flag entry
|
||||
if _, exists := flagnameAndStatus[primaryFlag]; !exists {
|
||||
flagnameAndStatus[primaryFlag] = false
|
||||
}
|
||||
|
||||
// Check if the primary flag is set
|
||||
if flagnameAndStatus[primaryFlag] {
|
||||
var unset []string
|
||||
for _, flag := range remainingFlags {
|
||||
if !flagnameAndStatus[flag] {
|
||||
unset = append(unset, flag)
|
||||
}
|
||||
}
|
||||
|
||||
// If any dependent flags are unset, trigger an error
|
||||
if len(unset) > 0 {
|
||||
return fmt.Errorf(
|
||||
"%v is set, the following flags must be provided: %v",
|
||||
primaryFlag, unset,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sortedKeys(m map[string]map[string]bool) []string {
|
||||
keys := make([]string, len(m))
|
||||
i := 0
|
||||
|
@ -221,6 +278,7 @@ func sortedKeys(m map[string]map[string]bool) []string {
|
|||
// - when a flag in a group is present, other flags in the group will be marked required
|
||||
// - when none of the flags in a one-required group are present, all flags in the group will be marked required
|
||||
// - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
|
||||
// - when the first flag in an if-present-then-required group is present, the other flags will be marked as required
|
||||
// This allows the standard completion logic to behave appropriately for flag groups
|
||||
func (c *Command) enforceFlagGroupsForCompletion() {
|
||||
if c.DisableFlagParsing {
|
||||
|
@ -231,10 +289,12 @@ func (c *Command) enforceFlagGroupsForCompletion() {
|
|||
groupStatus := map[string]map[string]bool{}
|
||||
oneRequiredGroupStatus := map[string]map[string]bool{}
|
||||
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
|
||||
ifPresentThenRequiredGroupStatus := map[string]map[string]bool{}
|
||||
c.Flags().VisitAll(func(pflag *flag.Flag) {
|
||||
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
|
||||
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
|
||||
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
|
||||
processFlagForGroupAnnotation(flags, pflag, annotationGroupRequired, groupStatus)
|
||||
processFlagForGroupAnnotation(flags, pflag, annotationRequiredOne, oneRequiredGroupStatus)
|
||||
processFlagForGroupAnnotation(flags, pflag, annotationMutuallyExclusive, mutuallyExclusiveGroupStatus)
|
||||
processFlagForGroupAnnotation(flags, pflag, annotationGroupDependent, ifPresentThenRequiredGroupStatus)
|
||||
})
|
||||
|
||||
// If a flag that is part of a group is present, we make all the other flags
|
||||
|
@ -287,4 +347,17 @@ func (c *Command) enforceFlagGroupsForCompletion() {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If a flag that is marked as if-present-then-required is present, make other flags in the group required
|
||||
for flagList, flagnameAndStatus := range ifPresentThenRequiredGroupStatus {
|
||||
flags := strings.Split(flagList, " ")
|
||||
primaryFlag := flags[0]
|
||||
remainingFlags := flags[1:]
|
||||
|
||||
if flagnameAndStatus[primaryFlag] {
|
||||
for _, fName := range remainingFlags {
|
||||
_ = c.MarkFlagRequired(fName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -43,22 +43,25 @@ func TestValidateFlagGroups(t *testing.T) {
|
|||
|
||||
// Each test case uses a unique command from the function above.
|
||||
testcases := []struct {
|
||||
desc string
|
||||
flagGroupsRequired []string
|
||||
flagGroupsOneRequired []string
|
||||
flagGroupsExclusive []string
|
||||
subCmdFlagGroupsRequired []string
|
||||
subCmdFlagGroupsOneRequired []string
|
||||
subCmdFlagGroupsExclusive []string
|
||||
args []string
|
||||
expectErr string
|
||||
desc string
|
||||
flagGroupsRequired []string
|
||||
flagGroupsOneRequired []string
|
||||
flagGroupsExclusive []string
|
||||
flagGroupsIfPresentThenRequired []string
|
||||
subCmdFlagGroupsRequired []string
|
||||
subCmdFlagGroupsOneRequired []string
|
||||
subCmdFlagGroupsExclusive []string
|
||||
subCmdFlagGroupsIfPresentThenRequired []string
|
||||
args []string
|
||||
expectErr string
|
||||
}{
|
||||
{
|
||||
desc: "No flags no problem",
|
||||
}, {
|
||||
desc: "No flags no problem even with conflicting groups",
|
||||
flagGroupsRequired: []string{"a b"},
|
||||
flagGroupsExclusive: []string{"a b"},
|
||||
desc: "No flags no problem even with conflicting groups",
|
||||
flagGroupsRequired: []string{"a b"},
|
||||
flagGroupsExclusive: []string{"a b"},
|
||||
flagGroupsIfPresentThenRequired: []string{"a b", "b a"},
|
||||
}, {
|
||||
desc: "Required flag group not satisfied",
|
||||
flagGroupsRequired: []string{"a b c"},
|
||||
|
@ -74,6 +77,11 @@ func TestValidateFlagGroups(t *testing.T) {
|
|||
flagGroupsExclusive: []string{"a b c"},
|
||||
args: []string{"--a=foo", "--b=foo"},
|
||||
expectErr: "if any flags in the group [a b c] are set none of the others can be; [a b] were all set",
|
||||
}, {
|
||||
desc: "If present then others required flag group not satisfied",
|
||||
flagGroupsIfPresentThenRequired: []string{"a b"},
|
||||
args: []string{"--a=foo"},
|
||||
expectErr: "a is set, the following flags must be provided: [b]",
|
||||
}, {
|
||||
desc: "Multiple required flag group not satisfied returns first error",
|
||||
flagGroupsRequired: []string{"a b c", "a d"},
|
||||
|
@ -89,6 +97,12 @@ func TestValidateFlagGroups(t *testing.T) {
|
|||
flagGroupsExclusive: []string{"a b c", "a d"},
|
||||
args: []string{"--a=foo", "--c=foo", "--d=foo"},
|
||||
expectErr: `if any flags in the group [a b c] are set none of the others can be; [a c] were all set`,
|
||||
},
|
||||
{
|
||||
desc: "Multiple if present then others required flag group not satisfied returns first error",
|
||||
flagGroupsIfPresentThenRequired: []string{"a b", "d e"},
|
||||
args: []string{"--a=foo", "--f=foo"},
|
||||
expectErr: `a is set, the following flags must be provided: [b]`,
|
||||
}, {
|
||||
desc: "Validation of required groups occurs on groups in sorted order",
|
||||
flagGroupsRequired: []string{"a d", "a b", "a c"},
|
||||
|
@ -182,6 +196,12 @@ func TestValidateFlagGroups(t *testing.T) {
|
|||
for _, flagGroup := range tc.subCmdFlagGroupsExclusive {
|
||||
sub.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...)
|
||||
}
|
||||
for _, flagGroup := range tc.flagGroupsIfPresentThenRequired {
|
||||
c.MarkIfFlagPresentThenOthersRequired(strings.Split(flagGroup, " ")...)
|
||||
}
|
||||
for _, flagGroup := range tc.subCmdFlagGroupsIfPresentThenRequired {
|
||||
sub.MarkIfFlagPresentThenOthersRequired(strings.Split(flagGroup, " ")...)
|
||||
}
|
||||
c.SetArgs(tc.args)
|
||||
err := c.Execute()
|
||||
switch {
|
||||
|
@ -193,3 +213,72 @@ func TestValidateFlagGroups(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkIfFlagPresentThenOthersRequiredAnnotations(t *testing.T) {
|
||||
// Create a new command with some flags.
|
||||
cmd := &Command{
|
||||
Use: "testcmd",
|
||||
}
|
||||
f := cmd.Flags()
|
||||
f.String("a", "", "flag a")
|
||||
f.String("b", "", "flag b")
|
||||
f.String("c", "", "flag c")
|
||||
|
||||
// Call the function with one group: ["a", "b"].
|
||||
cmd.MarkIfFlagPresentThenOthersRequired("a", "b")
|
||||
|
||||
// Check that flag "a" has the correct annotation.
|
||||
aFlag := f.Lookup("a")
|
||||
if aFlag == nil {
|
||||
t.Fatal("Flag 'a' not found")
|
||||
}
|
||||
annA := aFlag.Annotations[annotationGroupDependent]
|
||||
expected1 := "a b" // since strings.Join(["a","b"], " ") yields "a b"
|
||||
if len(annA) != 1 || annA[0] != expected1 {
|
||||
t.Errorf("Expected flag 'a' annotation to be [%q], got %v", expected1, annA)
|
||||
}
|
||||
|
||||
// Also check that flag "b" has the correct annotation.
|
||||
bFlag := f.Lookup("b")
|
||||
if bFlag == nil {
|
||||
t.Fatal("Flag 'b' not found")
|
||||
}
|
||||
annB := bFlag.Annotations[annotationGroupDependent]
|
||||
if len(annB) != 1 || annB[0] != expected1 {
|
||||
t.Errorf("Expected flag 'b' annotation to be [%q], got %v", expected1, annB)
|
||||
}
|
||||
|
||||
// Now, call MarkIfFlagPresentThenOthersRequired again with a different group involving "a" and "c".
|
||||
cmd.MarkIfFlagPresentThenOthersRequired("a", "c")
|
||||
|
||||
// The annotation for flag "a" should now have both groups: "a b" and "a c"
|
||||
annA = aFlag.Annotations[annotationGroupDependent]
|
||||
expectedAnnotations := []string{"a b", "a c"}
|
||||
if len(annA) != 2 {
|
||||
t.Errorf("Expected 2 annotations on flag 'a', got %v", annA)
|
||||
}
|
||||
// Check that both expected annotation strings are present.
|
||||
for _, expected := range expectedAnnotations {
|
||||
found := false
|
||||
for _, ann := range annA {
|
||||
if ann == expected {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected annotation %q not found on flag 'a': %v", expected, annA)
|
||||
}
|
||||
}
|
||||
|
||||
// Similarly, check that flag "c" now has the annotation "a c".
|
||||
cFlag := f.Lookup("c")
|
||||
if cFlag == nil {
|
||||
t.Fatal("Flag 'c' not found")
|
||||
}
|
||||
annC := cFlag.Annotations[annotationGroupDependent]
|
||||
expected2 := "a c"
|
||||
if len(annC) != 1 || annC[0] != expected2 {
|
||||
t.Errorf("Expected flag 'c' annotation to be [%q], got %v", expected2, annC)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue