diff --git a/command.go b/command.go index 675bb134..ad2254cf 100644 --- a/command.go +++ b/command.go @@ -139,6 +139,8 @@ type Command struct { iflags *flag.FlagSet // parentsPflags is all persistent flags of cmd's parents. parentsPflags *flag.FlagSet + // lnamedFlagSets contains local named flags. + lnamedFlagSets *NamedFlagSets // globNormFunc is the global normalization function // that we can use on every pflag set and children commands globNormFunc func(f *flag.FlagSet, name string) flag.NormalizedName @@ -514,7 +516,8 @@ Available Commands:{{range .Commands}}{{if (or .IsAvailableCommand (eq .Name "he {{rpad .Name .NamePadding }} {{.Short}}{{end}}{{end}}{{end}}{{if .HasAvailableLocalFlags}} Flags: -{{.LocalFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .HasAvailableInheritedFlags}} +{{.LocalNonNamedFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .HasAvailableNamedFlags}} +{{.NamedFlagSets.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .HasAvailableInheritedFlags}} Global Flags: {{.InheritedFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .HasHelpSubCommands}} @@ -578,6 +581,7 @@ func stripFlags(args []string, c *Command) []string { return args } c.mergePersistentFlags() + c.mergeNamedFlags() commands := []string{} flags := c.Flags() @@ -1046,6 +1050,7 @@ func (c *Command) validateRequiredFlags() error { // If c already has help flag, it will do nothing. func (c *Command) InitDefaultHelpFlag() { c.mergePersistentFlags() + c.mergeNamedFlags() if c.Flags().Lookup("help") == nil { usage := "help for " if c.Name() == "" { @@ -1067,6 +1072,7 @@ func (c *Command) InitDefaultVersionFlag() { } c.mergePersistentFlags() + c.mergeNamedFlags() if c.Flags().Lookup("version") == nil { usage := "version for " if c.Name() == "" { @@ -1475,6 +1481,31 @@ func (c *Command) Flags() *flag.FlagSet { return c.flags } +// NamedFlagSets returns all the named FlagSet that applies to this command. +func (c *Command) NamedFlagSets() *NamedFlagSets { + if c.lnamedFlagSets == nil { + c.lnamedFlagSets = NewNamedFlagSets(c.Name(), flag.ContinueOnError) + } + return c.lnamedFlagSets +} + +// NamedFlags returns the specific named FlagSet that applies to this command. +func (c *Command) NamedFlags(name string) *flag.FlagSet { + nfs := c.NamedFlagSets() + flags, ok := nfs.FlagSet(name) + if !ok { + if c.flagErrorBuf == nil { + c.flagErrorBuf = new(bytes.Buffer) + } + flags.SetOutput(c.flagErrorBuf) + if c.globNormFunc != nil { + flags.SetNormalizeFunc(c.globNormFunc) + } + } + + return flags +} + // LocalNonPersistentFlags are flags specific to this command which will NOT persist to subcommands. func (c *Command) LocalNonPersistentFlags() *flag.FlagSet { persistentFlags := c.PersistentFlags() @@ -1488,9 +1519,23 @@ func (c *Command) LocalNonPersistentFlags() *flag.FlagSet { return out } +// LocalNonNamedFlags are flags specific to this command which are NOT named. +func (c *Command) LocalNonNamedFlags() *flag.FlagSet { + namedFlags := c.NamedFlagSets().Flatten() + + out := flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.LocalFlags().VisitAll(func(f *flag.Flag) { + if namedFlags.Lookup(f.Name) == nil { + out.AddFlag(f) + } + }) + return out +} + // LocalFlags returns the local FlagSet specifically set in the current command. func (c *Command) LocalFlags() *flag.FlagSet { c.mergePersistentFlags() + c.mergeNamedFlags() if c.lflags == nil { c.lflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) @@ -1517,6 +1562,7 @@ func (c *Command) LocalFlags() *flag.FlagSet { // InheritedFlags returns all flags which were inherited from parent commands. func (c *Command) InheritedFlags() *flag.FlagSet { c.mergePersistentFlags() + c.mergeNamedFlags() if c.iflags == nil { c.iflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) @@ -1568,6 +1614,7 @@ func (c *Command) ResetFlags() { c.lflags = nil c.iflags = nil c.parentsPflags = nil + c.lnamedFlagSets = nil } // HasFlags checks if the command contains any flags (local plus persistent from the entire structure). @@ -1596,6 +1643,11 @@ func (c *Command) HasAvailableFlags() bool { return c.Flags().HasAvailableFlags() } +// HasAvailableNamedFlags checks if the command contains any named flags which are not hidden or deprecated. +func (c *Command) HasAvailableNamedFlags() bool { + return c.NamedFlagSets().Flatten().HasAvailableFlags() +} + // HasAvailablePersistentFlags checks if the command contains persistent flags which are not hidden or deprecated. func (c *Command) HasAvailablePersistentFlags() bool { return c.PersistentFlags().HasAvailableFlags() @@ -1648,6 +1700,7 @@ func (c *Command) ParseFlags(args []string) error { } beforeErrorBufLen := c.flagErrorBuf.Len() c.mergePersistentFlags() + c.mergeNamedFlags() // do it here after merging all flags and just before parse c.Flags().ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist) @@ -1674,6 +1727,11 @@ func (c *Command) mergePersistentFlags() { c.Flags().AddFlagSet(c.parentsPflags) } +// mergeNamedFlags merges c.NamedFlagSets() to c.Flags() +func (c *Command) mergeNamedFlags() { + c.Flags().AddFlagSet(c.NamedFlagSets().Flatten()) +} + // updateParentsPflags updates c.parentsPflags by adding // new persistent flags of all parents. // If c.parentsPflags == nil, it makes new. diff --git a/named_flag_sets.go b/named_flag_sets.go new file mode 100644 index 00000000..4681a1c5 --- /dev/null +++ b/named_flag_sets.go @@ -0,0 +1,89 @@ +package cobra + +import ( + "bytes" + "fmt" + "strings" + + "github.com/spf13/pflag" +) + +// NamedFlagSets stores named flag sets in the order of calling FlagSet. +type NamedFlagSets struct { + name string + errorHandling pflag.ErrorHandling + + // order is an ordered list of flag set names. + order []string + // flagSets stores the flag sets by name. + flagSets map[string]*pflag.FlagSet +} + +func NewNamedFlagSets(name string, errorHandling pflag.ErrorHandling) *NamedFlagSets { + return &NamedFlagSets{ + name: name, + errorHandling: errorHandling, + } +} + +// FlagSet returns the flag set with the given name and adds it to the +// ordered name list if it is not in there yet. +func (nfs *NamedFlagSets) FlagSet(name string) (*pflag.FlagSet, bool) { + if nfs.flagSets == nil { + nfs.flagSets = map[string]*pflag.FlagSet{} + } + var ok bool + if _, ok = nfs.flagSets[name]; !ok { + flagSet := pflag.NewFlagSet(name, nfs.errorHandling) + nfs.flagSets[name] = flagSet + nfs.order = append(nfs.order, name) + } + return nfs.flagSets[name], ok +} + +// Flatten returns a single flag set containing all the flag sets +// in the NamedFlagSet +func (nfs *NamedFlagSets) Flatten() *pflag.FlagSet { + out := pflag.NewFlagSet(nfs.name, nfs.errorHandling) + for _, fs := range nfs.flagSets { + out.AddFlagSet(fs) + } + return out +} + +// FlagUsages returns a string containing the usage information for all flags in +// the FlagSet +func (nfs *NamedFlagSets) FlagUsages() string { + return nfs.FlagUsagesWrapped(0) +} + +func (nfs *NamedFlagSets) FlagUsagesWrapped(cols int) string { + var buf bytes.Buffer + for _, name := range nfs.order { + fs := nfs.flagSets[name] + if !fs.HasFlags() { + continue + } + + wideFS := pflag.NewFlagSet("", pflag.ExitOnError) + wideFS.AddFlagSet(fs) + + var zzz string + if cols > 24 { + zzz = strings.Repeat("z", cols-24) + wideFS.Int(zzz, 0, strings.Repeat("z", cols-24)) + } + + s := fmt.Sprintf("\n%s Flags:\n%s", strings.ToUpper(name[:1])+name[1:], wideFS.FlagUsagesWrapped(cols)) + + if cols > 24 { + i := strings.Index(s, zzz) + lines := strings.Split(s[:i], "\n") + fmt.Fprint(&buf, strings.Join(lines[:len(lines)-1], "\n")) + fmt.Fprintln(&buf) + } else { + fmt.Fprint(&buf, s) + } + } + return buf.String() +}