spf13--cobra/flag_groups.go
evermake 98c9b4c903 refactor(flag_groups): flag groups implementation changed
This commit changes the logic of the flag groups feature. The new implementation is cleaner, more readable and more extensible (hopefully not just in my opinion).

The following changes have been made:

1. Main change:
Annotating flags with the "cobra_annotation_required_if_others_set" and "cobra_annotation_mutually_exclusive" annotations was removed, along with all the related hard-to-understand "hacks" that combined flags back into groups during validation.
Instead, a `flagGroups` field was added to the `Command` struct. The `flagGroups` field is a list of values of the (new) `flagGroup` interface type, each of which represents a "relationship" between flags within the command.

2. "Required together" and "mutually exclusive" groups logic was updated by implementing `requiredTogetherFlagGroup` and `mutuallyExclusiveFlagGroup` `flagGroup`s.

3. `enforceFlagGroupsForCompletion` `Command`'s method was renamed to `adjustByFlagGroupsForCompletions`.

4. The error messages for failed group validation were changed:
  - `"if any flags in the group [...] are set they must all be set; missing [...]"` to `"flags [...] must be set together, but [...] were not set"`
  - `"if any flags in the group [...] are set none of the others can be; [...] were all set"` to `"exactly one of the flags [...] can be set, but [...] were set"`

5. The error messages reported when a flag is not found while marking a group were updated from "Failed to find flag %q and mark it as being required in a flag group" and "Failed to find flag %q and mark it as being in a mutually exclusive flag group" to "flag %q is not defined"

6. `TestValidateFlagGroups` test was updated in `flag_groups_test.go`.
  - `getCmd` function was updated and test flag names were changed to improve readability
  - 2 testcases (`Validation of required groups occurs on groups in sorted order` and `Validation of exclusive groups occurs on groups in sorted order`) were removed, because group validation now occurs in the same order in which the groups were registered
  - other 16 testcases are preserved with updated descriptions, error messages

The completion generation tests that contain flag-group-related testcases, the updated flag group tests, and all other tests pass.

API was not changed: `MarkFlagsRequiredTogether` and `MarkFlagsMutuallyExclusive` functions have the same signatures.
2022-10-07 12:09:03 +03:00

196 lines
6.1 KiB
Go

// Copyright 2013-2022 The Cobra Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cobra
import (
"fmt"
flag "github.com/spf13/pflag"
)
// MarkFlagsRequiredTogether registers a "required together" relationship
// between the named flags: if any one of them is set, all of the others must
// be set too.
//
// The names are copied before being stored, so the caller may reuse or modify
// a slice passed as flagNames... afterwards without affecting the group.
// Registration is done by addFlagGroup, which panics if a name does not refer
// to a flag defined on the command.
func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
	// A call of the form f(names...) hands over the caller's slice itself, not
	// a copy. Detach from it so that later writes by the caller cannot swap in
	// a name that skipped the definition check performed at registration.
	names := append([]string(nil), flagNames...)
	c.addFlagGroup(&requiredTogetherFlagGroup{
		flagNames: names,
	})
}
// MarkFlagsMutuallyExclusive registers a "mutually exclusive" relationship
// between the named flags: if any one of them is set, none of the others may
// be set.
//
// The names are copied before being stored, so the caller may reuse or modify
// a slice passed as flagNames... afterwards without affecting the group.
// Registration is done by addFlagGroup, which panics if a name does not refer
// to a flag defined on the command.
func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
	// A call of the form f(names...) hands over the caller's slice itself, not
	// a copy. Detach from it so that later writes by the caller cannot swap in
	// a name that skipped the definition check performed at registration.
	names := append([]string(nil), flagNames...)
	c.addFlagGroup(&mutuallyExclusiveFlagGroup{
		flagNames: names,
	})
}
// addFlagGroup registers g in the command's flagGroups list.
//
// It first calls mergePersistentFlags and then looks up every flag name
// assigned to g in c.Flags(). If any name has no matching flag, it panics with
// the message `flag "<name>" is not defined` and g is not registered.
func (c *Command) addFlagGroup(g flagGroup) {
	c.mergePersistentFlags()

	for _, name := range g.AssignedFlagNames() {
		if found := c.Flags().Lookup(name); found != nil {
			continue
		}
		panic(fmt.Sprintf("flag %q is not defined", name))
	}

	c.flagGroups = append(c.flagGroups, g)
}
// ValidateFlagGroups checks the flags that have been set on the command
// against every registered flag group, in the order the groups appear in the
// command's flagGroups list.
//
// It returns the error of the first group whose validation fails, or nil when
// every group accepts the current combination of set flags.
func (c *Command) ValidateFlagGroups() error {
	// The set of flags that have been set is computed once and shared by all groups.
	changed := makeSetFlagsSet(c.Flags())

	for _, g := range c.flagGroups {
		err := g.ValidateSetFlags(changed)
		if err != nil {
			return err
		}
	}

	return nil
}
// adjustByFlagGroupsForCompletions lets every registered flag group adjust the
// command before completions are generated, by calling each group's
// AdjustCommandForCompletions in registration order.
//
// It is a no-op when Command.DisableFlagParsing is true.
func (c *Command) adjustByFlagGroupsForCompletions() {
	if !c.DisableFlagParsing {
		for _, g := range c.flagGroups {
			g.AdjustCommandForCompletions(c)
		}
	}
}
// flagGroup describes a relationship between a set of flags of one command.
// The implementations in this file are requiredTogetherFlagGroup and
// mutuallyExclusiveFlagGroup.
type flagGroup interface {
// ValidateSetFlags checks whether the combination of flags that have been set is valid.
// setFlags holds the names of the flags that have been set.
// If the combination is not valid, a non-nil error is returned.
ValidateSetFlags(setFlags setFlagsSet) error
// AssignedFlagNames returns the full list of flag names that have been assigned to the group.
AssignedFlagNames() []string
// AdjustCommandForCompletions updates the command c so that the completions
// generated for it take this group into account.
AdjustCommandForCompletions(c *Command)
}
// requiredTogetherFlagGroup is the flagGroup for flags that are required
// together: if any flag of the group is set, all of them must be set.
type requiredTogetherFlagGroup struct {
// flagNames lists the names of the flags that belong to the group.
flagNames []string
}
// AssignedFlagNames returns the names of the flags assigned to the group.
// The returned slice is the group's own slice, not a copy.
func (g *requiredTogetherFlagGroup) AssignedFlagNames() []string {
	names := g.flagNames
	return names
}
// ValidateSetFlags accepts the set flags when either every flag of the group
// or none of them has been set. When only some of them have been set, it
// returns an error that lists the whole group and the flags that are missing.
func (g *requiredTogetherFlagGroup) ValidateSetFlags(setFlags setFlagsSet) error {
	missing := setFlags.selectUnsetFlagNamesFrom(g.flagNames)

	allSet := len(missing) == 0
	noneSet := len(missing) == len(g.flagNames)
	if allSet || noneSet {
		return nil
	}

	return fmt.Errorf("flags %v must be set together, but %v were not set", g.flagNames, missing)
}
// AdjustCommandForCompletions marks every flag of the group as required on c
// as soon as at least one flag of the group has been set. When no flag of the
// group has been set, c is left untouched.
func (g *requiredTogetherFlagGroup) AdjustCommandForCompletions(c *Command) {
	if !makeSetFlagsSet(c.Flags()).hasAnyFrom(g.flagNames) {
		return
	}

	for _, name := range g.flagNames {
		// The error result is deliberately discarded (best effort).
		// NOTE(review): presumably it can only fail for unknown flags, which
		// addFlagGroup rejects at registration time; confirm.
		_ = c.MarkFlagRequired(name)
	}
}
// mutuallyExclusiveFlagGroup is the flagGroup for flags that are mutually
// exclusive: at most one flag of the group may be set at a time.
type mutuallyExclusiveFlagGroup struct {
// flagNames lists the names of the flags that belong to the group.
flagNames []string
}
// AssignedFlagNames returns the names of the flags assigned to the group.
// The returned slice is the group's own slice, not a copy.
func (g *mutuallyExclusiveFlagGroup) AssignedFlagNames() []string {
	names := g.flagNames
	return names
}
// ValidateSetFlags accepts the set flags when at most one flag of the group
// has been set. When two or more have been set, it returns an error that lists
// the whole group and the flags of the group that were set.
func (g *mutuallyExclusiveFlagGroup) ValidateSetFlags(setFlags setFlagsSet) error {
	present := setFlags.selectSetFlagNamesFrom(g.flagNames)
	if len(present) <= 1 {
		return nil
	}

	return fmt.Errorf("exactly one of the flags %v can be set, but %v were set", g.flagNames, present)
}
// AdjustCommandForCompletions hides the other flags of the group once one of
// them has been set. The flag that stays visible is the first name, in the
// group's own order, that has been set; every other flag of the group gets
// Hidden set to true on c. When no flag of the group has been set, c is left
// untouched.
func (g *mutuallyExclusiveFlagGroup) AdjustCommandForCompletions(c *Command) {
	keep, found := makeSetFlagsSet(c.Flags()).selectFirstSetFlagNameFrom(g.flagNames)
	if !found {
		return
	}

	for _, name := range g.flagNames {
		if name == keep {
			// The flag that has been set stays visible.
			continue
		}
		c.Flags().Lookup(name).Hidden = true
	}
}
// setFlagsSet is a helper set type, keyed by flag name with empty-struct values.
// It is intended to store the names of the flags that have been set in a
// flag.FlagSet and to perform lookups and checks on those names.
type setFlagsSet map[string]struct{}
// makeSetFlagsSet builds a setFlagsSet holding the name of every flag that
// fs.Visit reports, i.e. the flags that have been set in fs.
func makeSetFlagsSet(fs *flag.FlagSet) setFlagsSet {
	names := make(setFlagsSet)

	// record stores the name of one visited flag in the result set.
	record := func(f *flag.Flag) {
		names[f.Name] = struct{}{}
	}
	fs.Visit(record)

	return names
}
// has reports whether flagName is a member of the set.
func (s setFlagsSet) has(flagName string) bool {
	if _, present := s[flagName]; present {
		return true
	}
	return false
}
// hasAnyFrom reports whether at least one of flagNames is a member of the set.
func (s setFlagsSet) hasAnyFrom(flagNames []string) bool {
	// Only the "found" result matters here; which name matched is irrelevant.
	_, found := s.selectFirstSetFlagNameFrom(flagNames)
	return found
}
// selectFirstSetFlagNameFrom returns the first name of flagNames, in the order
// given, that is a member of the set, together with true. When no name is a
// member, it returns the empty string and false.
func (s setFlagsSet) selectFirstSetFlagNameFrom(flagNames []string) (string, bool) {
	for i := range flagNames {
		candidate := flagNames[i]
		if !s.has(candidate) {
			continue
		}
		return candidate, true
	}
	return "", false
}
// selectSetFlagNamesFrom returns those names of flagNames that are members of
// the set, in the order given (duplicates in flagNames are kept). The result is
// nil when no name is a member.
func (s setFlagsSet) selectSetFlagNamesFrom(flagNames []string) (setFlagNames []string) {
	for i := range flagNames {
		if name := flagNames[i]; s.has(name) {
			setFlagNames = append(setFlagNames, name)
		}
	}
	return setFlagNames
}
// selectUnsetFlagNamesFrom returns those names of flagNames that are NOT
// members of the set, in the order given (duplicates in flagNames are kept).
// The result is nil when every name is a member.
func (s setFlagsSet) selectUnsetFlagNamesFrom(flagNames []string) (unsetFlagNames []string) {
	for i := range flagNames {
		name := flagNames[i]
		if s.has(name) {
			continue
		}
		unsetFlagNames = append(unsetFlagNames, name)
	}
	return unsetFlagNames
}