mirror of
https://github.com/spf13/cobra
synced 2024-11-24 22:57:12 +00:00
Merge pull request #288 from eparis/flags-vs-commands
Do not display bash commands if local flag specified
This commit is contained in:
commit
c21ee9da52
3 changed files with 60 additions and 7 deletions
12
.gitignore
vendored
12
.gitignore
vendored
|
@ -19,6 +19,18 @@ _cgo_export.*
|
||||||
|
|
||||||
_testmain.go
|
_testmain.go
|
||||||
|
|
||||||
|
# Vim files https://github.com/github/gitignore/blob/master/Global/Vim.gitignore
|
||||||
|
# swap
|
||||||
|
[._]*.s[a-w][a-z]
|
||||||
|
[._]s[a-w][a-z]
|
||||||
|
# session
|
||||||
|
Session.vim
|
||||||
|
# temporary
|
||||||
|
.netrwhist
|
||||||
|
*~
|
||||||
|
# auto-generated tag files
|
||||||
|
tags
|
||||||
|
|
||||||
*.exe
|
*.exe
|
||||||
|
|
||||||
cobra.test
|
cobra.test
|
||||||
|
|
|
@ -116,12 +116,12 @@ __handle_reply()
|
||||||
fi
|
fi
|
||||||
|
|
||||||
local completions
|
local completions
|
||||||
if [[ ${#must_have_one_flag[@]} -ne 0 ]]; then
|
|
||||||
completions=("${must_have_one_flag[@]}")
|
|
||||||
elif [[ ${#must_have_one_noun[@]} -ne 0 ]]; then
|
|
||||||
completions=("${must_have_one_noun[@]}")
|
|
||||||
else
|
|
||||||
completions=("${commands[@]}")
|
completions=("${commands[@]}")
|
||||||
|
if [[ ${#must_have_one_noun[@]} -ne 0 ]]; then
|
||||||
|
completions=("${must_have_one_noun[@]}")
|
||||||
|
fi
|
||||||
|
if [[ ${#must_have_one_flag[@]} -ne 0 ]]; then
|
||||||
|
completions+=("${must_have_one_flag[@]}")
|
||||||
fi
|
fi
|
||||||
COMPREPLY=( $(compgen -W "${completions[*]}" -- "$cur") )
|
COMPREPLY=( $(compgen -W "${completions[*]}" -- "$cur") )
|
||||||
|
|
||||||
|
@ -167,6 +167,11 @@ __handle_flag()
|
||||||
must_have_one_flag=()
|
must_have_one_flag=()
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# if you set a flag which only applies to this command, don't show subcommands
|
||||||
|
if __contains_word "${flagname}" "${local_nonpersistent_flags[@]}"; then
|
||||||
|
commands=()
|
||||||
|
fi
|
||||||
|
|
||||||
# keep flag value with flagname as flaghash
|
# keep flag value with flagname as flaghash
|
||||||
if [ -n "${flagvalue}" ] ; then
|
if [ -n "${flagvalue}" ] ; then
|
||||||
flaghash[${flagname}]=${flagvalue}
|
flaghash[${flagname}]=${flagvalue}
|
||||||
|
@ -263,6 +268,7 @@ func postscript(w io.Writer, name string) error {
|
||||||
local c=0
|
local c=0
|
||||||
local flags=()
|
local flags=()
|
||||||
local two_word_flags=()
|
local two_word_flags=()
|
||||||
|
local local_nonpersistent_flags=()
|
||||||
local flags_with_completion=()
|
local flags_with_completion=()
|
||||||
local flags_completion=()
|
local flags_completion=()
|
||||||
local commands=("%s")
|
local commands=("%s")
|
||||||
|
@ -360,7 +366,7 @@ func writeFlagHandler(name string, annotations map[string][]string, w io.Writer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeShortFlag(flag *pflag.Flag, w io.Writer) error {
|
func writeShortFlag(flag *pflag.Flag, w io.Writer) error {
|
||||||
b := (flag.Value.Type() == "bool")
|
b := (len(flag.NoOptDefVal) > 0)
|
||||||
name := flag.Shorthand
|
name := flag.Shorthand
|
||||||
format := " "
|
format := " "
|
||||||
if !b {
|
if !b {
|
||||||
|
@ -374,7 +380,7 @@ func writeShortFlag(flag *pflag.Flag, w io.Writer) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeFlag(flag *pflag.Flag, w io.Writer) error {
|
func writeFlag(flag *pflag.Flag, w io.Writer) error {
|
||||||
b := (flag.Value.Type() == "bool")
|
b := (len(flag.NoOptDefVal) > 0)
|
||||||
name := flag.Name
|
name := flag.Name
|
||||||
format := " flags+=(\"--%s"
|
format := " flags+=(\"--%s"
|
||||||
if !b {
|
if !b {
|
||||||
|
@ -387,9 +393,24 @@ func writeFlag(flag *pflag.Flag, w io.Writer) error {
|
||||||
return writeFlagHandler("--"+name, flag.Annotations, w)
|
return writeFlagHandler("--"+name, flag.Annotations, w)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func writeLocalNonPersistentFlag(flag *pflag.Flag, w io.Writer) error {
|
||||||
|
b := (len(flag.NoOptDefVal) > 0)
|
||||||
|
name := flag.Name
|
||||||
|
format := " local_nonpersistent_flags+=(\"--%s"
|
||||||
|
if !b {
|
||||||
|
format += "="
|
||||||
|
}
|
||||||
|
format += "\")\n"
|
||||||
|
if _, err := fmt.Fprintf(w, format, name); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func writeFlags(cmd *Command, w io.Writer) error {
|
func writeFlags(cmd *Command, w io.Writer) error {
|
||||||
_, err := fmt.Fprintf(w, ` flags=()
|
_, err := fmt.Fprintf(w, ` flags=()
|
||||||
two_word_flags=()
|
two_word_flags=()
|
||||||
|
local_nonpersistent_flags=()
|
||||||
flags_with_completion=()
|
flags_with_completion=()
|
||||||
flags_completion=()
|
flags_completion=()
|
||||||
|
|
||||||
|
@ -397,6 +418,7 @@ func writeFlags(cmd *Command, w io.Writer) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
localNonPersistentFlags := cmd.LocalNonPersistentFlags()
|
||||||
var visitErr error
|
var visitErr error
|
||||||
cmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) {
|
cmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) {
|
||||||
if err := writeFlag(flag, w); err != nil {
|
if err := writeFlag(flag, w); err != nil {
|
||||||
|
@ -409,6 +431,12 @@ func writeFlags(cmd *Command, w io.Writer) error {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if localNonPersistentFlags.Lookup(flag.Name) != nil {
|
||||||
|
if err := writeLocalNonPersistentFlag(flag, w); err != nil {
|
||||||
|
visitErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
if visitErr != nil {
|
if visitErr != nil {
|
||||||
return visitErr
|
return visitErr
|
||||||
|
|
13
command.go
13
command.go
|
@ -1032,6 +1032,19 @@ func (c *Command) Flags() *flag.FlagSet {
|
||||||
return c.flags
|
return c.flags
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LocalNonPersistentFlags are flags specific to this command which will NOT persist to subcommands
|
||||||
|
func (c *Command) LocalNonPersistentFlags() *flag.FlagSet {
|
||||||
|
persistentFlags := c.PersistentFlags()
|
||||||
|
|
||||||
|
out := flag.NewFlagSet(c.Name(), flag.ContinueOnError)
|
||||||
|
c.LocalFlags().VisitAll(func(f *flag.Flag) {
|
||||||
|
if persistentFlags.Lookup(f.Name) == nil {
|
||||||
|
out.AddFlag(f)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
// Get the local FlagSet specifically set in the current command
|
// Get the local FlagSet specifically set in the current command
|
||||||
func (c *Command) LocalFlags() *flag.FlagSet {
|
func (c *Command) LocalFlags() *flag.FlagSet {
|
||||||
c.mergePersistentFlags()
|
c.mergePersistentFlags()
|
||||||
|
|
Loading…
Reference in a new issue