mirror of
https://github.com/spf13/cobra
synced 2024-11-24 14:47:12 +00:00
Explicitly support local flags overwriting persistent/inherited flags
The current (desired) behavior when a Command specifies a flag that has the same name as a persistent/inherited flag, is that the local definition takes precedence. This change updates the various Flag subset functions to respect that behavior: * LocalFlags: now returns only the set of flags and persistent flags attached to the Command itself. * InheritedFlags: now returns only the set of persistent flags inherited from the Command's parent(s), excluding any that are overwritten by a local flag. * NonInheritedFlags: changed to an alias of LocalFlags. * AllPersistentFlags: removed as not very useful; it returned the set of all persistent flags attached to the Command and its parent(s). Default UsageTemplate updated to use LocalFlags and InheritedFlags
This commit is contained in:
parent
f8e1ec56bd
commit
5c9146990b
2 changed files with 84 additions and 66 deletions
|
@ -11,11 +11,14 @@ var _ = fmt.Println
|
|||
|
||||
var tp, te, tt, t1 []string
|
||||
var flagb1, flagb2, flagb3, flagbr bool
|
||||
var flags1, flags2, flags3 string
|
||||
var flags1, flags2a, flags2b, flags3 string
|
||||
var flagi1, flagi2, flagi3, flagir int
|
||||
var globalFlag1 bool
|
||||
var flagEcho, rootcalled bool
|
||||
|
||||
const strtwoParentHelp = "help message for parent flag strtwo"
|
||||
const strtwoChildHelp = "help message for child flag strtwo"
|
||||
|
||||
var cmdPrint = &Command{
|
||||
Use: "print [string to print]",
|
||||
Short: "Print anything to the screen",
|
||||
|
@ -72,11 +75,12 @@ func flagInit() {
|
|||
cmdRootNoRun.ResetFlags()
|
||||
cmdRootSameName.ResetFlags()
|
||||
cmdRootWithRun.ResetFlags()
|
||||
cmdRootNoRun.PersistentFlags().StringVarP(&flags2a, "strtwo", "t", "two", strtwoParentHelp)
|
||||
cmdEcho.Flags().IntVarP(&flagi1, "intone", "i", 123, "help message for flag intone")
|
||||
cmdTimes.Flags().IntVarP(&flagi2, "inttwo", "j", 234, "help message for flag inttwo")
|
||||
cmdPrint.Flags().IntVarP(&flagi3, "intthree", "i", 345, "help message for flag intthree")
|
||||
cmdEcho.PersistentFlags().StringVarP(&flags1, "strone", "s", "one", "help message for flag strone")
|
||||
cmdTimes.PersistentFlags().StringVarP(&flags2, "strtwo", "t", "two", "help message for flag strtwo")
|
||||
cmdTimes.PersistentFlags().StringVarP(&flags2b, "strtwo", "t", "2", strtwoChildHelp)
|
||||
cmdPrint.PersistentFlags().StringVarP(&flags3, "strthree", "s", "three", "help message for flag strthree")
|
||||
cmdEcho.Flags().BoolVarP(&flagb1, "boolone", "b", true, "help message for flag boolone")
|
||||
cmdTimes.Flags().BoolVarP(&flagb2, "booltwo", "c", false, "help message for flag booltwo")
|
||||
|
@ -381,6 +385,17 @@ func TestChildCommandFlags(t *testing.T) {
|
|||
t.Errorf("Wrong error message displayed, \n %s", r.Output)
|
||||
}
|
||||
|
||||
// Testing with persistent flag overwritten by child
|
||||
noRRSetupTest("echo times --strtwo=child one two")
|
||||
|
||||
if flags2b != "child" {
|
||||
t.Errorf("flag value should be child, %s given", flags2b)
|
||||
}
|
||||
|
||||
if flags2a != "two" {
|
||||
t.Errorf("unset flag should have default value, expecting two, given %s", flags2a)
|
||||
}
|
||||
|
||||
// Testing flag with invalid input
|
||||
r = noRRSetupTest("echo -i10E")
|
||||
|
||||
|
@ -437,6 +452,13 @@ func TestHelpCommand(t *testing.T) {
|
|||
checkResultContains(t, r, cmdTimes.Long)
|
||||
}
|
||||
|
||||
func TestChildCommandHelp(t *testing.T) {
|
||||
c := noRRSetupTest("print --help")
|
||||
checkResultContains(t, c, strtwoParentHelp)
|
||||
r := noRRSetupTest("echo times --help")
|
||||
checkResultContains(t, r, strtwoChildHelp)
|
||||
}
|
||||
|
||||
func TestRunnableRootCommand(t *testing.T) {
|
||||
fullSetupTest("")
|
||||
|
||||
|
@ -486,6 +508,26 @@ func TestRootHelp(t *testing.T) {
|
|||
|
||||
}
|
||||
|
||||
func TestFlagAccess(t *testing.T) {
|
||||
initialize()
|
||||
|
||||
local := cmdTimes.LocalFlags()
|
||||
inherited := cmdTimes.InheritedFlags()
|
||||
|
||||
for _, f := range []string{"inttwo", "strtwo", "booltwo"} {
|
||||
if local.Lookup(f) == nil {
|
||||
t.Errorf("LocalFlags expected to contain %s, Got: nil", f)
|
||||
}
|
||||
}
|
||||
if inherited.Lookup("strone") == nil {
|
||||
t.Errorf("InheritedFlags expected to contain strone, Got: nil")
|
||||
}
|
||||
if inherited.Lookup("strtwo") != nil {
|
||||
t.Errorf("InheritedFlags shouldn not contain overwritten flag strtwo")
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func TestRootNoCommandHelp(t *testing.T) {
|
||||
x := rootOnlySetupTest("--help")
|
||||
|
||||
|
|
104
command.go
104
command.go
|
@ -46,6 +46,8 @@ type Command struct {
|
|||
flags *flag.FlagSet
|
||||
// Set of flags childrens of this command will inherit
|
||||
pflags *flag.FlagSet
|
||||
// Flags that are declared specifically by this command (not inherited).
|
||||
lflags *flag.FlagSet
|
||||
// Run runs the command.
|
||||
// The args are the arguments after the command name.
|
||||
Run func(cmd *Command, args []string)
|
||||
|
@ -218,8 +220,8 @@ Available Commands: {{range .Commands}}{{if .Runnable}}
|
|||
{{end}}
|
||||
{{ if .HasLocalFlags}}Flags:
|
||||
{{.LocalFlags.FlagUsages}}{{end}}
|
||||
{{ if .HasAnyPersistentFlags}}Global Flags:
|
||||
{{.AllPersistentFlags.FlagUsages}}{{end}}{{if .HasParent}}{{if and (gt .Commands 0) (gt .Parent.Commands 1) }}
|
||||
{{ if .HasInheritedFlags}}Global Flags:
|
||||
{{.InheritedFlags.FlagUsages}}{{end}}{{if .HasParent}}{{if and (gt .Commands 0) (gt .Parent.Commands 1) }}
|
||||
Additional help topics: {{if gt .Commands 0 }}{{range .Commands}}{{if not .Runnable}} {{rpad .CommandPath .CommandPathPadding}} {{.Short}}{{end}}{{end}}{{end}}{{if gt .Parent.Commands 1 }}{{range .Parent.Commands}}{{if .Runnable}}{{if not (eq .Name $cmd.Name) }}{{end}}
|
||||
{{rpad .CommandPath .CommandPathPadding}} {{.Short}}{{end}}{{end}}{{end}}{{end}}
|
||||
{{end}}{{ if .HasSubCommands }}
|
||||
|
@ -726,14 +728,9 @@ func (c *Command) LocalFlags() *flag.FlagSet {
|
|||
c.mergePersistentFlags()
|
||||
|
||||
local := flag.NewFlagSet(c.Name(), flag.ContinueOnError)
|
||||
allPersistent := c.AllPersistentFlags()
|
||||
|
||||
c.Flags().VisitAll(func(f *flag.Flag) {
|
||||
if allPersistent.Lookup(f.Name) == nil {
|
||||
local.AddFlag(f)
|
||||
}
|
||||
c.lflags.VisitAll(func(f *flag.Flag) {
|
||||
local.AddFlag(f)
|
||||
})
|
||||
|
||||
return local
|
||||
}
|
||||
|
||||
|
@ -741,44 +738,34 @@ func (c *Command) LocalFlags() *flag.FlagSet {
|
|||
func (c *Command) InheritedFlags() *flag.FlagSet {
|
||||
c.mergePersistentFlags()
|
||||
|
||||
local := flag.NewFlagSet(c.Name(), flag.ContinueOnError)
|
||||
inherited := flag.NewFlagSet(c.Name(), flag.ContinueOnError)
|
||||
local := c.LocalFlags()
|
||||
|
||||
var rmerge func(x *Command)
|
||||
var rmerge func(x *Command)
|
||||
|
||||
rmerge = func(x *Command) {
|
||||
if x.HasPersistentFlags() {
|
||||
x.PersistentFlags().VisitAll(func(f *flag.Flag) {
|
||||
if local.Lookup(f.Name) == nil {
|
||||
local.AddFlag(f)
|
||||
}
|
||||
})
|
||||
}
|
||||
if x.HasParent() {
|
||||
rmerge(x.parent)
|
||||
}
|
||||
}
|
||||
rmerge = func(x *Command) {
|
||||
if x.HasPersistentFlags() {
|
||||
x.PersistentFlags().VisitAll(func(f *flag.Flag) {
|
||||
if inherited.Lookup(f.Name) == nil && local.Lookup(f.Name) == nil {
|
||||
inherited.AddFlag(f)
|
||||
}
|
||||
})
|
||||
}
|
||||
if x.HasParent() {
|
||||
rmerge(x.parent)
|
||||
}
|
||||
}
|
||||
|
||||
if c.HasParent() {
|
||||
rmerge(c.parent)
|
||||
}
|
||||
|
||||
return local
|
||||
return inherited
|
||||
}
|
||||
|
||||
// All Flags which were not inherited from parent commands
|
||||
func (c *Command) NonInheritedFlags() *flag.FlagSet {
|
||||
c.mergePersistentFlags()
|
||||
|
||||
local := flag.NewFlagSet(c.Name(), flag.ContinueOnError)
|
||||
inheritedFlags := c.InheritedFlags()
|
||||
|
||||
c.Flags().VisitAll(func(f *flag.Flag) {
|
||||
if inheritedFlags.Lookup(f.Name) == nil {
|
||||
local.AddFlag(f)
|
||||
}
|
||||
})
|
||||
|
||||
return local
|
||||
return c.LocalFlags()
|
||||
}
|
||||
|
||||
// Get the Persistent FlagSet specifically set in the current command
|
||||
|
@ -793,29 +780,6 @@ func (c *Command) PersistentFlags() *flag.FlagSet {
|
|||
return c.pflags
|
||||
}
|
||||
|
||||
// Get the Persistent FlagSet traversing the Command hierarchy
|
||||
func (c *Command) AllPersistentFlags() *flag.FlagSet {
|
||||
allPersistent := flag.NewFlagSet(c.Name(), flag.ContinueOnError)
|
||||
|
||||
var visit func(x *Command)
|
||||
visit = func(x *Command) {
|
||||
if x.HasPersistentFlags() {
|
||||
x.PersistentFlags().VisitAll(func(f *flag.Flag) {
|
||||
if allPersistent.Lookup(f.Name) == nil {
|
||||
allPersistent.AddFlag(f)
|
||||
}
|
||||
})
|
||||
}
|
||||
if x.HasParent() {
|
||||
visit(x.parent)
|
||||
}
|
||||
}
|
||||
|
||||
visit(c)
|
||||
|
||||
return allPersistent
|
||||
}
|
||||
|
||||
// For use in testing
|
||||
func (c *Command) ResetFlags() {
|
||||
c.flagErrorBuf = new(bytes.Buffer)
|
||||
|
@ -836,16 +800,15 @@ func (c *Command) HasPersistentFlags() bool {
|
|||
return c.PersistentFlags().HasFlags()
|
||||
}
|
||||
|
||||
// Does the command hierarchy contain persistent flags
|
||||
func (c *Command) HasAnyPersistentFlags() bool {
|
||||
return c.AllPersistentFlags().HasFlags()
|
||||
}
|
||||
|
||||
// Does the command has flags specifically declared locally
|
||||
func (c *Command) HasLocalFlags() bool {
|
||||
return c.LocalFlags().HasFlags()
|
||||
}
|
||||
|
||||
func (c *Command) HasInheritedFlags() bool {
|
||||
return c.InheritedFlags().HasFlags()
|
||||
}
|
||||
|
||||
// Climbs up the command tree looking for matching flag
|
||||
func (c *Command) Flag(name string) (flag *flag.Flag) {
|
||||
flag = c.Flags().Lookup(name)
|
||||
|
@ -892,6 +855,19 @@ func (c *Command) Parent() *Command {
|
|||
func (c *Command) mergePersistentFlags() {
|
||||
var rmerge func(x *Command)
|
||||
|
||||
// Save the set of local flags
|
||||
if c.lflags == nil {
|
||||
c.lflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError)
|
||||
if c.flagErrorBuf == nil {
|
||||
c.flagErrorBuf = new(bytes.Buffer)
|
||||
}
|
||||
c.lflags.SetOutput(c.flagErrorBuf)
|
||||
addtolocal := func(f *flag.Flag) {
|
||||
c.lflags.AddFlag(f)
|
||||
}
|
||||
c.Flags().VisitAll(addtolocal)
|
||||
c.PersistentFlags().VisitAll(addtolocal)
|
||||
}
|
||||
rmerge = func(x *Command) {
|
||||
if x.HasPersistentFlags() {
|
||||
x.PersistentFlags().VisitAll(func(f *flag.Flag) {
|
||||
|
|
Loading…
Reference in a new issue