From 68b6b24f0c9926a779f7113d1040396a30fdedaf Mon Sep 17 00:00:00 2001 From: John Schnake Date: Sun, 17 Apr 2022 16:04:57 -0500 Subject: [PATCH] Add ability to mark flags as required or exclusive as a group (#1654) This change adds two features for dealing with flags: - requiring flags be provided as a group (or not at all) - requiring flags be mutually exclusive of each other By utilizing the flag annotations we can mark which flag groups a flag is a part of and during the parsing process we track which ones we have seen or not. A flag may be a part of multiple groups. The list of flags and the type of group (required together or exclusive) make it a unique group. Signed-off-by: John Schnake --- command.go | 4 + flag_groups.go | 174 ++++++++++++++++++++++++++++++++++++++++++++ flag_groups_test.go | 151 ++++++++++++++++++++++++++++++++++++++ go.sum | 8 +- user_guide.md | 24 ++++++ 5 files changed, 354 insertions(+), 7 deletions(-) create mode 100644 flag_groups.go create mode 100644 flag_groups_test.go diff --git a/command.go b/command.go index ee5365bc..0f4511f3 100644 --- a/command.go +++ b/command.go @@ -863,6 +863,10 @@ func (c *Command) execute(a []string) (err error) { if err := c.validateRequiredFlags(); err != nil { return err } + if err := c.validateFlagGroups(); err != nil { + return err + } + if c.RunE != nil { if err := c.RunE(c, argWoFlags); err != nil { return err diff --git a/flag_groups.go b/flag_groups.go new file mode 100644 index 00000000..1dba424a --- /dev/null +++ b/flag_groups.go @@ -0,0 +1,174 @@ +// Copyright © 2022 Steve Francia . +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cobra + +import ( + "fmt" + "sort" + "strings" + + flag "github.com/spf13/pflag" +) + +const ( + requiredAsGroup = "cobra_annotation_required_if_others_set" + mutuallyExclusive = "cobra_annotation_mutually_exclusive" +) + +// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors +// if the command is invoked with a subset (but not all) of the given flags. +func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) { + c.mergePersistentFlags() + for _, v := range flagNames { + f := c.Flags().Lookup(v) + if f == nil { + panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v)) + } + if err := c.Flags().SetAnnotation(v, requiredAsGroup, append(f.Annotations[requiredAsGroup], strings.Join(flagNames, " "))); err != nil { + // Only errs if the flag isn't found. + panic(err) + } + } +} + +// MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors +// if the command is invoked with more than one flag from the given set of flags. +func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { + c.mergePersistentFlags() + for _, v := range flagNames { + f := c.Flags().Lookup(v) + if f == nil { + panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v)) + } + // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. + if err := c.Flags().SetAnnotation(v, mutuallyExclusive, append(f.Annotations[mutuallyExclusive], strings.Join(flagNames, " "))); err != nil { + panic(err) + } + } +} + +// validateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic and returns the +// first error encountered. +func (c *Command) validateFlagGroups() error { + if c.DisableFlagParsing { + return nil + } + + flags := c.Flags() + + // groupStatus format is the list of flags as a unique ID, + // then a map of each flag name and whether it is set or not. + groupStatus := map[string]map[string]bool{} + mutuallyExclusiveGroupStatus := map[string]map[string]bool{} + flags.VisitAll(func(pflag *flag.Flag) { + processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) + processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) + }) + + if err := validateRequiredFlagGroups(groupStatus); err != nil { + return err + } + if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil { + return err + } + return nil +} + +func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool { + for _, fname := range flagnames { + f := fs.Lookup(fname) + if f == nil { + return false + } + } + return true +} + +func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) { + groupInfo, found := pflag.Annotations[annotation] + if found { + for _, group := range groupInfo { + if groupStatus[group] == nil { + flagnames := strings.Split(group, " ") + + // Only consider this flag group at all if all the flags are defined. + if !hasAllFlags(flags, flagnames...) { + continue + } + + groupStatus[group] = map[string]bool{} + for _, name := range flagnames { + groupStatus[group][name] = false + } + } + + groupStatus[group][pflag.Name] = pflag.Changed + } + } +} + +func validateRequiredFlagGroups(data map[string]map[string]bool) error { + keys := sortedKeys(data) + for _, flagList := range keys { + flagnameAndStatus := data[flagList] + + unset := []string{} + for flagname, isSet := range flagnameAndStatus { + if !isSet { + unset = append(unset, flagname) + } + } + if len(unset) == len(flagnameAndStatus) || len(unset) == 0 { + continue + } + + // Sort values, so they can be tested/scripted against consistently. + sort.Strings(unset) + return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset) + } + + return nil +} + +func validateExclusiveFlagGroups(data map[string]map[string]bool) error { + keys := sortedKeys(data) + for _, flagList := range keys { + flagnameAndStatus := data[flagList] + var set []string + for flagname, isSet := range flagnameAndStatus { + if isSet { + set = append(set, flagname) + } + } + if len(set) == 0 || len(set) == 1 { + continue + } + + // Sort values, so they can be tested/scripted against consistently. + sort.Strings(set) + return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set) + } + return nil +} + +func sortedKeys(m map[string]map[string]bool) []string { + keys := make([]string, len(m)) + i := 0 + for k := range m { + keys[i] = k + i++ + } + sort.Strings(keys) + return keys +} diff --git a/flag_groups_test.go b/flag_groups_test.go new file mode 100644 index 00000000..404ede56 --- /dev/null +++ b/flag_groups_test.go @@ -0,0 +1,151 @@ +// Copyright © 2022 Steve Francia . +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cobra + +import ( + "strings" + "testing" +) + +func TestValidateFlagGroups(t *testing.T) { + getCmd := func() *Command { + c := &Command{ + Use: "testcmd", + Run: func(cmd *Command, args []string) { + }} + // Define lots of flags to utilize for testing. + for _, v := range []string{"a", "b", "c", "d"} { + c.Flags().String(v, "", "") + } + for _, v := range []string{"e", "f", "g"} { + c.PersistentFlags().String(v, "", "") + } + subC := &Command{ + Use: "subcmd", + Run: func(cmd *Command, args []string) { + }} + subC.Flags().String("subonly", "", "") + c.AddCommand(subC) + return c + } + + // Each test case uses a unique command from the function above. + testcases := []struct { + desc string + flagGroupsRequired []string + flagGroupsExclusive []string + subCmdFlagGroupsRequired []string + subCmdFlagGroupsExclusive []string + args []string + expectErr string + }{ + { + desc: "No flags no problem", + }, { + desc: "No flags no problem even with conflicting groups", + flagGroupsRequired: []string{"a b"}, + flagGroupsExclusive: []string{"a b"}, + }, { + desc: "Required flag group not satisfied", + flagGroupsRequired: []string{"a b c"}, + args: []string{"--a=foo"}, + expectErr: "if any flags in the group [a b c] are set they must all be set; missing [b c]", + }, { + desc: "Exclusive flag group not satisfied", + flagGroupsExclusive: []string{"a b c"}, + args: []string{"--a=foo", "--b=foo"}, + expectErr: "if any flags in the group [a b c] are set none of the others can be; [a b] were all set", + }, { + desc: "Multiple required flag group not satisfied returns first error", + flagGroupsRequired: []string{"a b c", "a d"}, + args: []string{"--c=foo", "--d=foo"}, + expectErr: `if any flags in the group [a b c] are set they must all be set; missing [a b]`, + }, { + desc: "Multiple exclusive flag group not satisfied returns first error", + flagGroupsExclusive: []string{"a b c", "a d"}, + args: []string{"--a=foo", "--c=foo", "--d=foo"}, + expectErr: `if any flags in the group [a b c] are set none of the others can be; [a c] were all set`, + }, { + desc: "Validation of required groups occurs on groups in sorted order", + flagGroupsRequired: []string{"a d", "a b", "a c"}, + args: []string{"--a=foo"}, + expectErr: `if any flags in the group [a b] are set they must all be set; missing [b]`, + }, { + desc: "Validation of exclusive groups occurs on groups in sorted order", + flagGroupsExclusive: []string{"a d", "a b", "a c"}, + args: []string{"--a=foo", "--b=foo", "--c=foo"}, + expectErr: `if any flags in the group [a b] are set none of the others can be; [a b] were all set`, + }, { + desc: "Persistent flags utilize both features and can fail required groups", + flagGroupsRequired: []string{"a e", "e f"}, + flagGroupsExclusive: []string{"f g"}, + args: []string{"--a=foo", "--f=foo", "--g=foo"}, + expectErr: `if any flags in the group [a e] are set they must all be set; missing [e]`, + }, { + desc: "Persistent flags utilize both features and can fail mutually exclusive groups", + flagGroupsRequired: []string{"a e", "e f"}, + flagGroupsExclusive: []string{"f g"}, + args: []string{"--a=foo", "--e=foo", "--f=foo", "--g=foo"}, + expectErr: `if any flags in the group [f g] are set none of the others can be; [f g] were all set`, + }, { + desc: "Persistent flags utilize both features and can pass", + flagGroupsRequired: []string{"a e", "e f"}, + flagGroupsExclusive: []string{"f g"}, + args: []string{"--a=foo", "--e=foo", "--f=foo"}, + }, { + desc: "Subcmds can use required groups using inherited flags", + subCmdFlagGroupsRequired: []string{"e subonly"}, + args: []string{"subcmd", "--e=foo", "--subonly=foo"}, + }, { + desc: "Subcmds can use exclusive groups using inherited flags", + subCmdFlagGroupsExclusive: []string{"e subonly"}, + args: []string{"subcmd", "--e=foo", "--subonly=foo"}, + expectErr: "if any flags in the group [e subonly] are set none of the others can be; [e subonly] were all set", + }, { + desc: "Subcmds can use exclusive groups using inherited flags and pass", + subCmdFlagGroupsExclusive: []string{"e subonly"}, + args: []string{"subcmd", "--e=foo"}, + }, { + desc: "Flag groups not applied if not found on invoked command", + subCmdFlagGroupsRequired: []string{"e subonly"}, + args: []string{"--e=foo"}, + }, + } + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + c := getCmd() + sub := c.Commands()[0] + for _, flagGroup := range tc.flagGroupsRequired { + c.MarkFlagsRequiredTogether(strings.Split(flagGroup, " ")...) + } + for _, flagGroup := range tc.flagGroupsExclusive { + c.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...) + } + for _, flagGroup := range tc.subCmdFlagGroupsRequired { + sub.MarkFlagsRequiredTogether(strings.Split(flagGroup, " ")...) + } + for _, flagGroup := range tc.subCmdFlagGroupsExclusive { + sub.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...) + } + c.SetArgs(tc.args) + err := c.Execute() + switch { + case err == nil && len(tc.expectErr) > 0: + t.Errorf("Expected error %q but got nil", tc.expectErr) + case err != nil && err.Error() != tc.expectErr: + t.Errorf("Expected error %q but got %q", tc.expectErr, err) + } + }) + } +} diff --git a/go.sum b/go.sum index 431058ed..6d534596 100644 --- a/go.sum +++ b/go.sum @@ -2,17 +2,11 @@ github.com/cpuguy83/go-md2man/v2 v2.0.1 h1:r/myEWzV9lfsM1tFLgDyu0atFtJ1fXn261LKY github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= -github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs= -github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/user_guide.md b/user_guide.md index cbf9a897..56a1e9c6 100644 --- a/user_guide.md +++ b/user_guide.md @@ -300,6 +300,30 @@ rootCmd.PersistentFlags().StringVarP(&Region, "region", "r", "", "AWS region (re rootCmd.MarkPersistentFlagRequired("region") ``` +### Flag Groups + +If you have different flags that must be provided together (e.g. if they provide the `--username` flag they MUST provide the `--password` flag as well) then +Cobra can enforce that requirement: +```go +rootCmd.Flags().StringVarP(&u, "username", "u", "", "Username (required if password is set)") +rootCmd.Flags().StringVarP(&pw, "password", "p", "", "Password (required if username is set)") +rootCmd.MarkFlagsRequiredTogether("username", "password") +``` + +You can also prevent different flags from being provided together if they represent mutually +exclusive options such as specifying an output format as either `--json` or `--yaml` but never both: +```go +rootCmd.Flags().BoolVar(&u, "json", false, "Output in JSON") +rootCmd.Flags().BoolVar(&pw, "yaml", false, "Output in YAML") +rootCmd.MarkFlagsMutuallyExclusive("json", "yaml") +``` + +In both of these cases: + - both local and persistent flags can be used + - **NOTE:** the group is only enforced on commands where every flag is defined + - a flag may appear in multiple groups + - a group may contain any number of flags + ## Positional and Custom Arguments Validation of positional arguments can be specified using the `Args` field