mirror of
https://github.com/spf13/cobra
synced 2024-11-24 14:47:12 +00:00
Create and utilize mergePersistentFlags method
This commit is contained in:
parent
ed6206272d
commit
061ba30a84
1 changed files with 25 additions and 19 deletions
44
cobra.go
44
cobra.go
|
@ -19,7 +19,7 @@ package cobra
|
|||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
flag "github.com/ogier/pflag"
|
||||
flag "github.com/spf13/pflag"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
@ -252,7 +252,7 @@ func (c *Command) PersistentFlags() *flag.FlagSet {
|
|||
}
|
||||
c.pflags.SetOutput(c.flagErrorBuf)
|
||||
}
|
||||
return c.flags
|
||||
return c.pflags
|
||||
}
|
||||
|
||||
// Intended for use in testing
|
||||
|
@ -265,22 +265,11 @@ func (c *Command) ResetFlags() {
|
|||
}
|
||||
|
||||
func (c *Command) HasFlags() bool {
|
||||
return hasFlags(c.flags)
|
||||
return c.Flags().HasFlags()
|
||||
}
|
||||
|
||||
func (c *Command) HasPersistentFlags() bool {
|
||||
return hasFlags(c.pflags)
|
||||
}
|
||||
|
||||
// Is this set of flags not empty
|
||||
func hasFlags(f *flag.FlagSet) bool {
|
||||
if f == nil {
|
||||
return false
|
||||
}
|
||||
if f.NFlag() != 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
return c.PersistentFlags().HasFlags()
|
||||
}
|
||||
|
||||
// Climbs up the command tree looking for matching flag
|
||||
|
@ -308,10 +297,8 @@ func (c *Command) persistentFlag(name string) (flag *flag.Flag) {
|
|||
|
||||
// Parses persistent flag tree & local flags
|
||||
func (c *Command) ParseFlags(args []string) (err error) {
|
||||
err = c.ParsePersistentFlags(args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.mergePersistentFlags()
|
||||
|
||||
err = c.Flags().Parse(args)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -319,6 +306,25 @@ func (c *Command) ParseFlags(args []string) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Command) mergePersistentFlags() {
|
||||
var rmerge func(x *Command)
|
||||
|
||||
rmerge = func(x *Command) {
|
||||
if x.HasPersistentFlags() {
|
||||
x.PersistentFlags().VisitAll(func(f *flag.Flag) {
|
||||
if c.Flags().Lookup(f.Name) == nil {
|
||||
c.Flags().AddFlag(f)
|
||||
}
|
||||
})
|
||||
}
|
||||
if x.HasParent() {
|
||||
rmerge(x.parent)
|
||||
}
|
||||
}
|
||||
|
||||
rmerge(c)
|
||||
}
|
||||
|
||||
// Climbs up the command tree parsing flags from top to bottom
|
||||
func (c *Command) ParsePersistentFlags(args []string) (err error) {
|
||||
if !c.HasParent() || (c.parent.HasPersistentFlags() && c.parent.PersistentFlags().Parsed()) {
|
||||
|
|
Loading…
Reference in a new issue