diff --git a/cobra_test.go b/cobra_test.go index 6c443e7a..b78a6566 100644 --- a/cobra_test.go +++ b/cobra_test.go @@ -675,7 +675,7 @@ func TestPersistentFlags(t *testing.T) { fullSetupTest("echo times -s again -c -p test here") if strings.Join(tt, " ") != "test here" { - t.Errorf("flags didn't leave proper args remaining..%s given", tt) + t.Errorf("flags didn't leave proper args remaining. %s given", tt) } if flags1 != "again" { diff --git a/command.go b/command.go index a4366562..52231c39 100644 --- a/command.go +++ b/command.go @@ -772,6 +772,7 @@ func (c *Command) initHelpCmd() { func (c *Command) ResetCommands() { c.commands = nil c.helpCommand = nil + c.parentsPflags = nil } // Sorts commands by their names. @@ -1083,7 +1084,6 @@ func (c *Command) LocalNonPersistentFlags() *flag.FlagSet { // LocalFlags returns the local FlagSet specifically set in the current command. func (c *Command) LocalFlags() *flag.FlagSet { c.mergePersistentFlags() - c.updateParentsPersistentFlags() if c.lflags == nil { c.lflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) @@ -1106,7 +1106,6 @@ func (c *Command) LocalFlags() *flag.FlagSet { // InheritedFlags returns all flags which were inherited from parents commands. func (c *Command) InheritedFlags() *flag.FlagSet { c.mergePersistentFlags() - c.updateParentsPersistentFlags() if c.iflags == nil { c.iflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) @@ -1226,10 +1225,37 @@ func (c *Command) Parent() *Command { return c.parent } -func (c *Command) updateParentsPersistentFlags() { +// mergePersistentFlags merges c.PersistentFlags() to c.Flags() +// and adds missing persistent flags of all parents. +func (c *Command) mergePersistentFlags() { + if c.HasPersistentFlags() { + c.PersistentFlags().VisitAll(func(f *flag.Flag) { + if c.Flags().Lookup(f.Name) == nil { + c.Flags().AddFlag(f) + } + }) + } + + added := c.updateParentsPflags() + if len(added) > 0 { + for _, f := range added { + if c.Flags().Lookup(f.Name) == nil { + c.Flags().AddFlag(f) + } + } + } +} + +// updateParentsPflags updates c.parentsPflags by adding +// new persistent flags of all parents and returns added flags. +// If c.parentsPflags == nil, it makes new. +// +// This function must be used ONLY in mergePersistentFlags. +func (c *Command) updateParentsPflags() (added []*flag.Flag) { if c.parentsPflags == nil { c.parentsPflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) c.parentsPflags.SetOutput(c.OutOrStderr()) + c.parentsPflags.SortFlags = false } c.VisitParents(func(x *Command) { @@ -1237,24 +1263,11 @@ func (c *Command) updateParentsPersistentFlags() { x.PersistentFlags().VisitAll(func(f *flag.Flag) { if c.parentsPflags.Lookup(f.Name) == nil { c.parentsPflags.AddFlag(f) + added = append(added, f) } }) } }) -} -func (c *Command) mergePersistentFlags() { - flags := c.Flags() - - merge := func(x *Command) { - if x.HasPersistentFlags() { - x.PersistentFlags().VisitAll(func(f *flag.Flag) { - if flags.Lookup(f.Name) == nil { - flags.AddFlag(f) - } - }) - } - } - merge(c) - c.VisitParents(merge) + return added }