Create and utilize mergePersistentFlags method

This commit is contained in:
spf13 2013-09-10 18:26:17 -04:00
parent ed6206272d
commit 061ba30a84

View file

@ -19,7 +19,7 @@ package cobra
import ( import (
"bytes" "bytes"
"fmt" "fmt"
flag "github.com/ogier/pflag" flag "github.com/spf13/pflag"
"os" "os"
"strings" "strings"
) )
@ -252,7 +252,7 @@ func (c *Command) PersistentFlags() *flag.FlagSet {
} }
c.pflags.SetOutput(c.flagErrorBuf) c.pflags.SetOutput(c.flagErrorBuf)
} }
return c.flags return c.pflags
} }
// Intended for use in testing // Intended for use in testing
@ -265,22 +265,11 @@ func (c *Command) ResetFlags() {
} }
func (c *Command) HasFlags() bool { func (c *Command) HasFlags() bool {
return hasFlags(c.flags) return c.Flags().HasFlags()
} }
func (c *Command) HasPersistentFlags() bool { func (c *Command) HasPersistentFlags() bool {
return hasFlags(c.pflags) return c.PersistentFlags().HasFlags()
}
// Is this set of flags not empty
func hasFlags(f *flag.FlagSet) bool {
if f == nil {
return false
}
if f.NFlag() != 0 {
return true
}
return false
} }
// Climbs up the command tree looking for matching flag // Climbs up the command tree looking for matching flag
@ -308,10 +297,8 @@ func (c *Command) persistentFlag(name string) (flag *flag.Flag) {
// Parses persistent flag tree & local flags // Parses persistent flag tree & local flags
func (c *Command) ParseFlags(args []string) (err error) { func (c *Command) ParseFlags(args []string) (err error) {
err = c.ParsePersistentFlags(args) c.mergePersistentFlags()
if err != nil {
return err
}
err = c.Flags().Parse(args) err = c.Flags().Parse(args)
if err != nil { if err != nil {
return err return err
@ -319,6 +306,25 @@ func (c *Command) ParseFlags(args []string) (err error) {
return nil return nil
} }
func (c *Command) mergePersistentFlags() {
var rmerge func(x *Command)
rmerge = func(x *Command) {
if x.HasPersistentFlags() {
x.PersistentFlags().VisitAll(func(f *flag.Flag) {
if c.Flags().Lookup(f.Name) == nil {
c.Flags().AddFlag(f)
}
})
}
if x.HasParent() {
rmerge(x.parent)
}
}
rmerge(c)
}
// Climbs up the command tree parsing flags from top to bottom // Climbs up the command tree parsing flags from top to bottom
func (c *Command) ParsePersistentFlags(args []string) (err error) { func (c *Command) ParsePersistentFlags(args []string) (err error) {
if !c.HasParent() || (c.parent.HasPersistentFlags() && c.parent.PersistentFlags().Parsed()) { if !c.HasParent() || (c.parent.HasPersistentFlags() && c.parent.PersistentFlags().Parsed()) {