2022-09-16 11:55:56 +00:00
|
|
|
// Copyright 2013-2022 The Cobra Authors
|
2022-04-17 21:04:57 +00:00
|
|
|
//
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
2022-09-16 11:55:56 +00:00
|
|
|
//
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
2022-04-17 21:04:57 +00:00
|
|
|
//
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
|
|
|
|
package cobra
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
|
|
|
|
|
|
|
flag "github.com/spf13/pflag"
|
|
|
|
)
|
|
|
|
|
2022-08-11 13:35:14 +00:00
|
|
|
// MarkFlagsRequiredTogether creates a relationship between flags, which ensures
|
|
|
|
// that if any of flags with names from flagNames is set, other flags must be set too.
|
2022-04-17 21:04:57 +00:00
|
|
|
func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
|
2022-08-11 13:35:14 +00:00
|
|
|
c.addFlagGroup(&requiredTogetherFlagGroup{
|
|
|
|
flagNames: flagNames,
|
|
|
|
})
|
2022-04-17 21:04:57 +00:00
|
|
|
}
|
|
|
|
|
2022-08-11 13:35:14 +00:00
|
|
|
// MarkFlagsMutuallyExclusive creates a relationship between flags, which ensures
|
|
|
|
// that if any of flags with names from flagNames is set, other flags must not be set.
|
2022-04-17 21:04:57 +00:00
|
|
|
func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
|
2022-08-11 13:35:14 +00:00
|
|
|
c.addFlagGroup(&mutuallyExclusiveFlagGroup{
|
|
|
|
flagNames: flagNames,
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
// addFlagGroup merges persistent flags of the command and adds flagGroup into command's flagGroups list.
|
|
|
|
// Panics, if flagGroup g contains the name of the flag, which is not defined in the Command c.
|
|
|
|
func (c *Command) addFlagGroup(g flagGroup) {
|
2022-04-17 21:04:57 +00:00
|
|
|
c.mergePersistentFlags()
|
2022-08-11 13:35:14 +00:00
|
|
|
|
|
|
|
for _, flagName := range g.AssignedFlagNames() {
|
|
|
|
if c.Flags().Lookup(flagName) == nil {
|
|
|
|
panic(fmt.Sprintf("flag %q is not defined", flagName))
|
2022-04-17 21:04:57 +00:00
|
|
|
}
|
|
|
|
}
|
2022-08-11 13:35:14 +00:00
|
|
|
|
|
|
|
c.flagGroups = append(c.flagGroups, g)
|
2022-04-17 21:04:57 +00:00
|
|
|
}
|
|
|
|
|
2022-08-11 13:35:14 +00:00
|
|
|
// ValidateFlagGroups runs validation for each group from command's flagGroups list,
|
|
|
|
// and returns the first error encountered, or nil, if there were no validation errors.
|
2022-09-27 10:27:48 +00:00
|
|
|
func (c *Command) ValidateFlagGroups() error {
|
2022-08-11 13:35:14 +00:00
|
|
|
setFlags := makeSetFlagsSet(c.Flags())
|
|
|
|
for _, group := range c.flagGroups {
|
|
|
|
if err := group.ValidateSetFlags(setFlags); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2022-04-17 21:04:57 +00:00
|
|
|
}
|
2022-08-11 13:35:14 +00:00
|
|
|
return nil
|
|
|
|
}
|
2022-04-17 21:04:57 +00:00
|
|
|
|
2022-08-11 13:35:14 +00:00
|
|
|
// adjustByFlagGroupsForCompletions changes the command by each flagGroup from command's flagGroups list
|
|
|
|
// to make the further command completions generation more convenient.
|
|
|
|
// Does nothing, if Command.DisableFlagParsing is true.
|
|
|
|
func (c *Command) adjustByFlagGroupsForCompletions() {
|
|
|
|
if c.DisableFlagParsing {
|
|
|
|
return
|
2022-04-17 21:04:57 +00:00
|
|
|
}
|
2022-08-11 13:35:14 +00:00
|
|
|
|
|
|
|
for _, group := range c.flagGroups {
|
|
|
|
group.AdjustCommandForCompletions(c)
|
2022-04-17 21:04:57 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-08-11 13:35:14 +00:00
|
|
|
type flagGroup interface {
|
|
|
|
// ValidateSetFlags checks whether the combination of flags that have been set is valid.
|
|
|
|
// If not, an error is returned.
|
|
|
|
ValidateSetFlags(setFlags setFlagsSet) error
|
|
|
|
|
|
|
|
// AssignedFlagNames returns a full list of flag names that have been assigned to the group.
|
|
|
|
AssignedFlagNames() []string
|
|
|
|
|
|
|
|
// AdjustCommandForCompletions updates the command to generate more convenient for this group completions.
|
|
|
|
AdjustCommandForCompletions(c *Command)
|
2022-04-17 21:04:57 +00:00
|
|
|
}
|
|
|
|
|
2022-08-11 13:35:14 +00:00
|
|
|
// requiredTogetherFlagGroup is a flag group whose members depend on each other:
// as soon as one of the flags listed in flagNames is set, all of them must be set.
type requiredTogetherFlagGroup struct{ flagNames []string }
|
2022-04-17 21:04:57 +00:00
|
|
|
|
2022-08-11 13:35:14 +00:00
|
|
|
func (g *requiredTogetherFlagGroup) AssignedFlagNames() []string {
|
|
|
|
return g.flagNames
|
|
|
|
}
|
|
|
|
func (g *requiredTogetherFlagGroup) ValidateSetFlags(setFlags setFlagsSet) error {
|
|
|
|
unset := setFlags.selectUnsetFlagNamesFrom(g.flagNames)
|
2022-04-17 21:04:57 +00:00
|
|
|
|
2022-08-11 13:35:14 +00:00
|
|
|
if unsetCount := len(unset); unsetCount != 0 && unsetCount != len(g.flagNames) {
|
|
|
|
return fmt.Errorf("flags %v must be set together, but %v were not set", g.flagNames, unset)
|
|
|
|
}
|
2022-04-17 21:04:57 +00:00
|
|
|
|
2022-08-11 13:35:14 +00:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
func (g *requiredTogetherFlagGroup) AdjustCommandForCompletions(c *Command) {
|
|
|
|
setFlags := makeSetFlagsSet(c.Flags())
|
|
|
|
if setFlags.hasAnyFrom(g.flagNames) {
|
|
|
|
for _, requiredFlagName := range g.flagNames {
|
|
|
|
_ = c.MarkFlagRequired(requiredFlagName)
|
2022-04-17 21:04:57 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-08-11 13:35:14 +00:00
|
|
|
// mutuallyExclusiveFlagGroup is a flag group whose members exclude each other:
// no more than one of the flags listed in flagNames may be set at the same time.
type mutuallyExclusiveFlagGroup struct{ flagNames []string }
|
2022-04-17 21:04:57 +00:00
|
|
|
|
2022-08-11 13:35:14 +00:00
|
|
|
func (g *mutuallyExclusiveFlagGroup) AssignedFlagNames() []string {
|
|
|
|
return g.flagNames
|
|
|
|
}
|
|
|
|
func (g *mutuallyExclusiveFlagGroup) ValidateSetFlags(setFlags setFlagsSet) error {
|
|
|
|
set := setFlags.selectSetFlagNamesFrom(g.flagNames)
|
2022-04-17 21:04:57 +00:00
|
|
|
|
2022-08-11 13:35:14 +00:00
|
|
|
if len(set) > 1 {
|
|
|
|
return fmt.Errorf("exactly one of the flags %v can be set, but %v were set", g.flagNames, set)
|
2022-04-17 21:04:57 +00:00
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
2022-08-11 13:35:14 +00:00
|
|
|
func (g *mutuallyExclusiveFlagGroup) AdjustCommandForCompletions(c *Command) {
|
|
|
|
setFlags := makeSetFlagsSet(c.Flags())
|
|
|
|
firstSetFlagName, hasAny := setFlags.selectFirstSetFlagNameFrom(g.flagNames)
|
|
|
|
if hasAny {
|
|
|
|
for _, exclusiveFlagName := range g.flagNames {
|
|
|
|
if exclusiveFlagName != firstSetFlagName {
|
|
|
|
c.Flags().Lookup(exclusiveFlagName).Hidden = true
|
2022-04-17 21:04:57 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-08-11 13:35:14 +00:00
|
|
|
// setFlagsSet is a set of flag names, keyed by name with empty struct values.
// It holds the names of the flags that have been set in a flag.FlagSet and
// offers lookups and selections over those names.
type setFlagsSet map[string]struct{}
|
2022-06-21 02:04:28 +00:00
|
|
|
|
2022-08-11 13:35:14 +00:00
|
|
|
// makeSetFlagsSet creates setFlagsSet of names of the flags that have been set in the given flag.FlagSet.
|
|
|
|
func makeSetFlagsSet(fs *flag.FlagSet) setFlagsSet {
|
|
|
|
s := make(setFlagsSet)
|
2022-06-21 02:04:28 +00:00
|
|
|
|
2022-08-11 13:35:14 +00:00
|
|
|
// Visit flags that have been set and add them to the set
|
|
|
|
fs.Visit(func(f *flag.Flag) {
|
|
|
|
s[f.Name] = struct{}{}
|
2022-06-21 02:04:28 +00:00
|
|
|
})
|
|
|
|
|
2022-08-11 13:35:14 +00:00
|
|
|
return s
|
|
|
|
}
|
|
|
|
func (s setFlagsSet) has(flagName string) bool {
|
|
|
|
_, ok := s[flagName]
|
|
|
|
return ok
|
|
|
|
}
|
|
|
|
func (s setFlagsSet) hasAnyFrom(flagNames []string) bool {
|
|
|
|
for _, flagName := range flagNames {
|
|
|
|
if s.has(flagName) {
|
|
|
|
return true
|
2022-06-21 02:04:28 +00:00
|
|
|
}
|
|
|
|
}
|
2022-08-11 13:35:14 +00:00
|
|
|
return false
|
|
|
|
}
|
|
|
|
func (s setFlagsSet) selectFirstSetFlagNameFrom(flagNames []string) (string, bool) {
|
|
|
|
for _, flagName := range flagNames {
|
|
|
|
if s.has(flagName) {
|
|
|
|
return flagName, true
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return "", false
|
|
|
|
}
|
|
|
|
func (s setFlagsSet) selectSetFlagNamesFrom(flagNames []string) (setFlagNames []string) {
|
|
|
|
for _, flagName := range flagNames {
|
|
|
|
if s.has(flagName) {
|
|
|
|
setFlagNames = append(setFlagNames, flagName)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
func (s setFlagsSet) selectUnsetFlagNamesFrom(flagNames []string) (unsetFlagNames []string) {
|
|
|
|
for _, flagName := range flagNames {
|
|
|
|
if !s.has(flagName) {
|
|
|
|
unsetFlagNames = append(unsetFlagNames, flagName)
|
2022-06-21 02:04:28 +00:00
|
|
|
}
|
|
|
|
}
|
2022-08-11 13:35:14 +00:00
|
|
|
return
|
2022-06-21 02:04:28 +00:00
|
|
|
}
|