mirror of
https://github.com/spf13/cobra
synced 2024-11-24 14:47:12 +00:00
3d8ac432bd
* Avoid redundant string splits There likely isn't actually more than once to split in the source strings in these cases, but avoid doing so anyway as we're only interested in the first. * Avoid redundant completion output target evaluations The target is not to be changed while outputting completions, so resolve it only once. * Avoid redundant active help enablement evaluations The enablement state is not to be changed during completion output, so evaluate it only once. * Preallocate some slices and maps with known size * Avoid some unnecessary looping * Use strings.Builder to construct suggestions
290 lines
9.2 KiB
Go
290 lines
9.2 KiB
Go
// Copyright 2013-2023 The Cobra Authors
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package cobra
|
|
|
|
import (
|
|
"fmt"
|
|
"sort"
|
|
"strings"
|
|
|
|
flag "github.com/spf13/pflag"
|
|
)
|
|
|
|
const (
|
|
requiredAsGroup = "cobra_annotation_required_if_others_set"
|
|
oneRequired = "cobra_annotation_one_required"
|
|
mutuallyExclusive = "cobra_annotation_mutually_exclusive"
|
|
)
|
|
|
|
// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
|
|
// if the command is invoked with a subset (but not all) of the given flags.
|
|
func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
|
|
c.mergePersistentFlags()
|
|
for _, v := range flagNames {
|
|
f := c.Flags().Lookup(v)
|
|
if f == nil {
|
|
panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v))
|
|
}
|
|
if err := c.Flags().SetAnnotation(v, requiredAsGroup, append(f.Annotations[requiredAsGroup], strings.Join(flagNames, " "))); err != nil {
|
|
// Only errs if the flag isn't found.
|
|
panic(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// MarkFlagsOneRequired marks the given flags with annotations so that Cobra errors
|
|
// if the command is invoked without at least one flag from the given set of flags.
|
|
func (c *Command) MarkFlagsOneRequired(flagNames ...string) {
|
|
c.mergePersistentFlags()
|
|
for _, v := range flagNames {
|
|
f := c.Flags().Lookup(v)
|
|
if f == nil {
|
|
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v))
|
|
}
|
|
if err := c.Flags().SetAnnotation(v, oneRequired, append(f.Annotations[oneRequired], strings.Join(flagNames, " "))); err != nil {
|
|
// Only errs if the flag isn't found.
|
|
panic(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors
|
|
// if the command is invoked with more than one flag from the given set of flags.
|
|
func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
|
|
c.mergePersistentFlags()
|
|
for _, v := range flagNames {
|
|
f := c.Flags().Lookup(v)
|
|
if f == nil {
|
|
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))
|
|
}
|
|
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
|
|
if err := c.Flags().SetAnnotation(v, mutuallyExclusive, append(f.Annotations[mutuallyExclusive], strings.Join(flagNames, " "))); err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the
|
|
// first error encountered.
|
|
func (c *Command) ValidateFlagGroups() error {
|
|
if c.DisableFlagParsing {
|
|
return nil
|
|
}
|
|
|
|
flags := c.Flags()
|
|
|
|
// groupStatus format is the list of flags as a unique ID,
|
|
// then a map of each flag name and whether it is set or not.
|
|
groupStatus := map[string]map[string]bool{}
|
|
oneRequiredGroupStatus := map[string]map[string]bool{}
|
|
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
|
|
flags.VisitAll(func(pflag *flag.Flag) {
|
|
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
|
|
processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus)
|
|
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
|
|
})
|
|
|
|
if err := validateRequiredFlagGroups(groupStatus); err != nil {
|
|
return err
|
|
}
|
|
if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil {
|
|
return err
|
|
}
|
|
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool {
|
|
for _, fname := range flagnames {
|
|
f := fs.Lookup(fname)
|
|
if f == nil {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) {
|
|
groupInfo, found := pflag.Annotations[annotation]
|
|
if found {
|
|
for _, group := range groupInfo {
|
|
if groupStatus[group] == nil {
|
|
flagnames := strings.Split(group, " ")
|
|
|
|
// Only consider this flag group at all if all the flags are defined.
|
|
if !hasAllFlags(flags, flagnames...) {
|
|
continue
|
|
}
|
|
|
|
groupStatus[group] = make(map[string]bool, len(flagnames))
|
|
for _, name := range flagnames {
|
|
groupStatus[group][name] = false
|
|
}
|
|
}
|
|
|
|
groupStatus[group][pflag.Name] = pflag.Changed
|
|
}
|
|
}
|
|
}
|
|
|
|
func validateRequiredFlagGroups(data map[string]map[string]bool) error {
|
|
keys := sortedKeys(data)
|
|
for _, flagList := range keys {
|
|
flagnameAndStatus := data[flagList]
|
|
|
|
unset := []string{}
|
|
for flagname, isSet := range flagnameAndStatus {
|
|
if !isSet {
|
|
unset = append(unset, flagname)
|
|
}
|
|
}
|
|
if len(unset) == len(flagnameAndStatus) || len(unset) == 0 {
|
|
continue
|
|
}
|
|
|
|
// Sort values, so they can be tested/scripted against consistently.
|
|
sort.Strings(unset)
|
|
return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func validateOneRequiredFlagGroups(data map[string]map[string]bool) error {
|
|
keys := sortedKeys(data)
|
|
for _, flagList := range keys {
|
|
flagnameAndStatus := data[flagList]
|
|
var set []string
|
|
for flagname, isSet := range flagnameAndStatus {
|
|
if isSet {
|
|
set = append(set, flagname)
|
|
}
|
|
}
|
|
if len(set) >= 1 {
|
|
continue
|
|
}
|
|
|
|
// Sort values, so they can be tested/scripted against consistently.
|
|
sort.Strings(set)
|
|
return fmt.Errorf("at least one of the flags in the group [%v] is required", flagList)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
|
|
keys := sortedKeys(data)
|
|
for _, flagList := range keys {
|
|
flagnameAndStatus := data[flagList]
|
|
var set []string
|
|
for flagname, isSet := range flagnameAndStatus {
|
|
if isSet {
|
|
set = append(set, flagname)
|
|
}
|
|
}
|
|
if len(set) == 0 || len(set) == 1 {
|
|
continue
|
|
}
|
|
|
|
// Sort values, so they can be tested/scripted against consistently.
|
|
sort.Strings(set)
|
|
return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func sortedKeys(m map[string]map[string]bool) []string {
|
|
keys := make([]string, len(m))
|
|
i := 0
|
|
for k := range m {
|
|
keys[i] = k
|
|
i++
|
|
}
|
|
sort.Strings(keys)
|
|
return keys
|
|
}
|
|
|
|
// enforceFlagGroupsForCompletion will do the following:
|
|
// - when a flag in a group is present, other flags in the group will be marked required
|
|
// - when none of the flags in a one-required group are present, all flags in the group will be marked required
|
|
// - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
|
|
// This allows the standard completion logic to behave appropriately for flag groups
|
|
func (c *Command) enforceFlagGroupsForCompletion() {
|
|
if c.DisableFlagParsing {
|
|
return
|
|
}
|
|
|
|
flags := c.Flags()
|
|
groupStatus := map[string]map[string]bool{}
|
|
oneRequiredGroupStatus := map[string]map[string]bool{}
|
|
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
|
|
c.Flags().VisitAll(func(pflag *flag.Flag) {
|
|
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
|
|
processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus)
|
|
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
|
|
})
|
|
|
|
// If a flag that is part of a group is present, we make all the other flags
|
|
// of that group required so that the shell completion suggests them automatically
|
|
for flagList, flagnameAndStatus := range groupStatus {
|
|
for _, isSet := range flagnameAndStatus {
|
|
if isSet {
|
|
// One of the flags of the group is set, mark the other ones as required
|
|
for _, fName := range strings.Split(flagList, " ") {
|
|
_ = c.MarkFlagRequired(fName)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// If none of the flags of a one-required group are present, we make all the flags
|
|
// of that group required so that the shell completion suggests them automatically
|
|
for flagList, flagnameAndStatus := range oneRequiredGroupStatus {
|
|
isSet := false
|
|
|
|
for _, isSet = range flagnameAndStatus {
|
|
if isSet {
|
|
break
|
|
}
|
|
}
|
|
|
|
// None of the flags of the group are set, mark all flags in the group
|
|
// as required
|
|
if !isSet {
|
|
for _, fName := range strings.Split(flagList, " ") {
|
|
_ = c.MarkFlagRequired(fName)
|
|
}
|
|
}
|
|
}
|
|
|
|
// If a flag that is mutually exclusive to others is present, we hide the other
|
|
// flags of that group so the shell completion does not suggest them
|
|
for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {
|
|
for flagName, isSet := range flagnameAndStatus {
|
|
if isSet {
|
|
// One of the flags of the mutually exclusive group is set, mark the other ones as hidden
|
|
// Don't mark the flag that is already set as hidden because it may be an
|
|
// array or slice flag and therefore must continue being suggested
|
|
for _, fName := range strings.Split(flagList, " ") {
|
|
if fName != flagName {
|
|
flag := c.Flags().Lookup(fName)
|
|
flag.Hidden = true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|