Make parentsPflags more flexible

updateParentsPflags returns nothing, so you can use it independent of
mergePersistentFlags. A little performance impact.
This commit is contained in:
Albert Nigmatzianov 2017-04-19 11:17:48 +02:00
parent 3e61377cd5
commit e135867f96

View file

@ -1089,16 +1089,14 @@ func (c *Command) LocalFlags() *flag.FlagSet {
c.lflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) c.lflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError)
c.lflags.SetOutput(c.OutOrStderr()) c.lflags.SetOutput(c.OutOrStderr())
} }
c.lflags.SortFlags = c.Flags().SortFlags
flags := c.Flags()
c.lflags.SortFlags = flags.SortFlags
addToLocal := func(f *flag.Flag) { addToLocal := func(f *flag.Flag) {
if c.lflags.Lookup(f.Name) == nil && c.parentsPflags.Lookup(f.Name) == nil { if c.lflags.Lookup(f.Name) == nil && c.parentsPflags.Lookup(f.Name) == nil {
c.lflags.AddFlag(f) c.lflags.AddFlag(f)
} }
} }
c.flags.VisitAll(addToLocal) c.Flags().VisitAll(addToLocal)
c.PersistentFlags().VisitAll(addToLocal) c.PersistentFlags().VisitAll(addToLocal)
return c.lflags return c.lflags
} }
@ -1112,13 +1110,11 @@ func (c *Command) InheritedFlags() *flag.FlagSet {
} }
local := c.LocalFlags() local := c.LocalFlags()
c.parentsPflags.VisitAll(func(f *flag.Flag) { c.parentsPflags.VisitAll(func(f *flag.Flag) {
if c.iflags.Lookup(f.Name) == nil && local.Lookup(f.Name) == nil { if c.iflags.Lookup(f.Name) == nil && local.Lookup(f.Name) == nil {
c.iflags.AddFlag(f) c.iflags.AddFlag(f)
} }
}) })
return c.iflags return c.iflags
} }
@ -1204,8 +1200,9 @@ func (c *Command) persistentFlag(name string) (flag *flag.Flag) {
flag = c.PersistentFlags().Lookup(name) flag = c.PersistentFlags().Lookup(name)
} }
if flag == nil && c.HasParent() { if flag == nil {
flag = c.parent.persistentFlag(name) c.updateParentsPflags()
flag = c.parentsPflags.Lookup(name)
} }
return return
} }
@ -1229,23 +1226,14 @@ func (c *Command) Parent() *Command {
// and adds missing persistent flags of all parents. // and adds missing persistent flags of all parents.
func (c *Command) mergePersistentFlags() { func (c *Command) mergePersistentFlags() {
c.Flags().AddFlagSet(c.PersistentFlags()) c.Flags().AddFlagSet(c.PersistentFlags())
c.updateParentsPflags()
added := c.updateParentsPflags() c.Flags().AddFlagSet(c.parentsPflags)
if len(added) > 0 {
for _, f := range added {
if c.Flags().Lookup(f.Name) == nil {
c.Flags().AddFlag(f)
}
}
}
} }
// updateParentsPflags updates c.parentsPflags by adding // updateParentsPflags updates c.parentsPflags by adding
// new persistent flags of all parents and returns added flags. // new persistent flags of all parents.
// If c.parentsPflags == nil, it makes new. // If c.parentsPflags == nil, it makes new.
// func (c *Command) updateParentsPflags() {
// This function must be used ONLY in mergePersistentFlags.
func (c *Command) updateParentsPflags() (added []*flag.Flag) {
if c.parentsPflags == nil { if c.parentsPflags == nil {
c.parentsPflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) c.parentsPflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError)
c.parentsPflags.SetOutput(c.OutOrStderr()) c.parentsPflags.SetOutput(c.OutOrStderr())
@ -1254,16 +1242,7 @@ func (c *Command) updateParentsPflags() (added []*flag.Flag) {
c.Root().PersistentFlags().AddFlagSet(flag.CommandLine) c.Root().PersistentFlags().AddFlagSet(flag.CommandLine)
c.VisitParents(func(x *Command) { c.VisitParents(func(parent *Command) {
if x.HasPersistentFlags() { c.parentsPflags.AddFlagSet(parent.PersistentFlags())
x.PersistentFlags().VisitAll(func(f *flag.Flag) {
if c.parentsPflags.Lookup(f.Name) == nil {
c.parentsPflags.AddFlag(f)
added = append(added, f)
}
})
}
}) })
return added
} }