diff --git a/cobra_test.go b/cobra_test.go index 629865c2..90e88982 100644 --- a/cobra_test.go +++ b/cobra_test.go @@ -11,11 +11,14 @@ var _ = fmt.Println var tp, te, tt, t1 []string var flagb1, flagb2, flagb3, flagbr bool -var flags1, flags2, flags3 string +var flags1, flags2a, flags2b, flags3 string var flagi1, flagi2, flagi3, flagir int var globalFlag1 bool var flagEcho, rootcalled bool +const strtwoParentHelp = "help message for parent flag strtwo" +const strtwoChildHelp = "help message for child flag strtwo" + var cmdPrint = &Command{ Use: "print [string to print]", Short: "Print anything to the screen", @@ -72,11 +75,12 @@ func flagInit() { cmdRootNoRun.ResetFlags() cmdRootSameName.ResetFlags() cmdRootWithRun.ResetFlags() + cmdRootNoRun.PersistentFlags().StringVarP(&flags2a, "strtwo", "t", "two", strtwoParentHelp) cmdEcho.Flags().IntVarP(&flagi1, "intone", "i", 123, "help message for flag intone") cmdTimes.Flags().IntVarP(&flagi2, "inttwo", "j", 234, "help message for flag inttwo") cmdPrint.Flags().IntVarP(&flagi3, "intthree", "i", 345, "help message for flag intthree") cmdEcho.PersistentFlags().StringVarP(&flags1, "strone", "s", "one", "help message for flag strone") - cmdTimes.PersistentFlags().StringVarP(&flags2, "strtwo", "t", "two", "help message for flag strtwo") + cmdTimes.PersistentFlags().StringVarP(&flags2b, "strtwo", "t", "2", strtwoChildHelp) cmdPrint.PersistentFlags().StringVarP(&flags3, "strthree", "s", "three", "help message for flag strthree") cmdEcho.Flags().BoolVarP(&flagb1, "boolone", "b", true, "help message for flag boolone") cmdTimes.Flags().BoolVarP(&flagb2, "booltwo", "c", false, "help message for flag booltwo") @@ -381,6 +385,17 @@ func TestChildCommandFlags(t *testing.T) { t.Errorf("Wrong error message displayed, \n %s", r.Output) } + // Testing with persistent flag overwritten by child + noRRSetupTest("echo times --strtwo=child one two") + + if flags2b != "child" { + t.Errorf("flag value should be child, %s given", flags2b) + } + + if flags2a != "two" { + t.Errorf("unset flag should have default value, expecting two, given %s", flags2a) + } + // Testing flag with invalid input r = noRRSetupTest("echo -i10E") @@ -437,6 +452,13 @@ func TestHelpCommand(t *testing.T) { checkResultContains(t, r, cmdTimes.Long) } +func TestChildCommandHelp(t *testing.T) { + c := noRRSetupTest("print --help") + checkResultContains(t, c, strtwoParentHelp) + r := noRRSetupTest("echo times --help") + checkResultContains(t, r, strtwoChildHelp) +} + func TestRunnableRootCommand(t *testing.T) { fullSetupTest("") @@ -486,6 +508,26 @@ func TestRootHelp(t *testing.T) { } +func TestFlagAccess(t *testing.T) { + initialize() + + local := cmdTimes.LocalFlags() + inherited := cmdTimes.InheritedFlags() + + for _, f := range []string{"inttwo", "strtwo", "booltwo"} { + if local.Lookup(f) == nil { + t.Errorf("LocalFlags expected to contain %s, Got: nil", f) + } + } + if inherited.Lookup("strone") == nil { + t.Errorf("InheritedFlags expected to contain strone, Got: nil") + } + if inherited.Lookup("strtwo") != nil { + t.Errorf("InheritedFlags shouldn not contain overwritten flag strtwo") + + } +} + func TestRootNoCommandHelp(t *testing.T) { x := rootOnlySetupTest("--help") diff --git a/command.go b/command.go index 6ac3ea70..fa9b9e5e 100644 --- a/command.go +++ b/command.go @@ -46,6 +46,8 @@ type Command struct { flags *flag.FlagSet // Set of flags childrens of this command will inherit pflags *flag.FlagSet + // Flags that are declared specifically by this command (not inherited). + lflags *flag.FlagSet // Run runs the command. // The args are the arguments after the command name. Run func(cmd *Command, args []string) @@ -218,8 +220,8 @@ Available Commands: {{range .Commands}}{{if .Runnable}} {{end}} {{ if .HasLocalFlags}}Flags: {{.LocalFlags.FlagUsages}}{{end}} -{{ if .HasAnyPersistentFlags}}Global Flags: -{{.AllPersistentFlags.FlagUsages}}{{end}}{{if .HasParent}}{{if and (gt .Commands 0) (gt .Parent.Commands 1) }} +{{ if .HasInheritedFlags}}Global Flags: +{{.InheritedFlags.FlagUsages}}{{end}}{{if .HasParent}}{{if and (gt .Commands 0) (gt .Parent.Commands 1) }} Additional help topics: {{if gt .Commands 0 }}{{range .Commands}}{{if not .Runnable}} {{rpad .CommandPath .CommandPathPadding}} {{.Short}}{{end}}{{end}}{{end}}{{if gt .Parent.Commands 1 }}{{range .Parent.Commands}}{{if .Runnable}}{{if not (eq .Name $cmd.Name) }}{{end}} {{rpad .CommandPath .CommandPathPadding}} {{.Short}}{{end}}{{end}}{{end}}{{end}} {{end}}{{ if .HasSubCommands }} @@ -726,14 +728,9 @@ func (c *Command) LocalFlags() *flag.FlagSet { c.mergePersistentFlags() local := flag.NewFlagSet(c.Name(), flag.ContinueOnError) - allPersistent := c.AllPersistentFlags() - - c.Flags().VisitAll(func(f *flag.Flag) { - if allPersistent.Lookup(f.Name) == nil { - local.AddFlag(f) - } + c.lflags.VisitAll(func(f *flag.Flag) { + local.AddFlag(f) }) - return local } @@ -741,44 +738,34 @@ func (c *Command) LocalFlags() *flag.FlagSet { func (c *Command) InheritedFlags() *flag.FlagSet { c.mergePersistentFlags() - local := flag.NewFlagSet(c.Name(), flag.ContinueOnError) + inherited := flag.NewFlagSet(c.Name(), flag.ContinueOnError) + local := c.LocalFlags() - var rmerge func(x *Command) + var rmerge func(x *Command) - rmerge = func(x *Command) { - if x.HasPersistentFlags() { - x.PersistentFlags().VisitAll(func(f *flag.Flag) { - if local.Lookup(f.Name) == nil { - local.AddFlag(f) - } - }) - } - if x.HasParent() { - rmerge(x.parent) - } - } + rmerge = func(x *Command) { + if x.HasPersistentFlags() { + x.PersistentFlags().VisitAll(func(f *flag.Flag) { + if inherited.Lookup(f.Name) == nil && local.Lookup(f.Name) == nil { + inherited.AddFlag(f) + } + }) + } + if x.HasParent() { + rmerge(x.parent) + } + } if c.HasParent() { rmerge(c.parent) } - return local + return inherited } // All Flags which were not inherited from parent commands func (c *Command) NonInheritedFlags() *flag.FlagSet { - c.mergePersistentFlags() - - local := flag.NewFlagSet(c.Name(), flag.ContinueOnError) - inheritedFlags := c.InheritedFlags() - - c.Flags().VisitAll(func(f *flag.Flag) { - if inheritedFlags.Lookup(f.Name) == nil { - local.AddFlag(f) - } - }) - - return local + return c.LocalFlags() } // Get the Persistent FlagSet specifically set in the current command @@ -793,29 +780,6 @@ func (c *Command) PersistentFlags() *flag.FlagSet { return c.pflags } -// Get the Persistent FlagSet traversing the Command hierarchy -func (c *Command) AllPersistentFlags() *flag.FlagSet { - allPersistent := flag.NewFlagSet(c.Name(), flag.ContinueOnError) - - var visit func(x *Command) - visit = func(x *Command) { - if x.HasPersistentFlags() { - x.PersistentFlags().VisitAll(func(f *flag.Flag) { - if allPersistent.Lookup(f.Name) == nil { - allPersistent.AddFlag(f) - } - }) - } - if x.HasParent() { - visit(x.parent) - } - } - - visit(c) - - return allPersistent -} - // For use in testing func (c *Command) ResetFlags() { c.flagErrorBuf = new(bytes.Buffer) @@ -836,16 +800,15 @@ func (c *Command) HasPersistentFlags() bool { return c.PersistentFlags().HasFlags() } -// Does the command hierarchy contain persistent flags -func (c *Command) HasAnyPersistentFlags() bool { - return c.AllPersistentFlags().HasFlags() -} - // Does the command has flags specifically declared locally func (c *Command) HasLocalFlags() bool { return c.LocalFlags().HasFlags() } +func (c *Command) HasInheritedFlags() bool { + return c.InheritedFlags().HasFlags() +} + // Climbs up the command tree looking for matching flag func (c *Command) Flag(name string) (flag *flag.Flag) { flag = c.Flags().Lookup(name) @@ -892,6 +855,19 @@ func (c *Command) Parent() *Command { func (c *Command) mergePersistentFlags() { var rmerge func(x *Command) + // Save the set of local flags + if c.lflags == nil { + c.lflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + if c.flagErrorBuf == nil { + c.flagErrorBuf = new(bytes.Buffer) + } + c.lflags.SetOutput(c.flagErrorBuf) + addtolocal := func(f *flag.Flag) { + c.lflags.AddFlag(f) + } + c.Flags().VisitAll(addtolocal) + c.PersistentFlags().VisitAll(addtolocal) + } rmerge = func(x *Command) { if x.HasPersistentFlags() { x.PersistentFlags().VisitAll(func(f *flag.Flag) {