diff --git a/.gitignore b/.gitignore index 36d1a84d..1b8c7c26 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,18 @@ _cgo_export.* _testmain.go +# Vim files https://github.com/github/gitignore/blob/master/Global/Vim.gitignore +# swap +[._]*.s[a-w][a-z] +[._]s[a-w][a-z] +# session +Session.vim +# temporary +.netrwhist +*~ +# auto-generated tag files +tags + *.exe cobra.test diff --git a/bash_completions.go b/bash_completions.go index 3f33bb0e..236dee67 100644 --- a/bash_completions.go +++ b/bash_completions.go @@ -116,12 +116,12 @@ __handle_reply() fi local completions - if [[ ${#must_have_one_flag[@]} -ne 0 ]]; then - completions=("${must_have_one_flag[@]}") - elif [[ ${#must_have_one_noun[@]} -ne 0 ]]; then + completions=("${commands[@]}") + if [[ ${#must_have_one_noun[@]} -ne 0 ]]; then completions=("${must_have_one_noun[@]}") - else - completions=("${commands[@]}") + fi + if [[ ${#must_have_one_flag[@]} -ne 0 ]]; then + completions+=("${must_have_one_flag[@]}") fi COMPREPLY=( $(compgen -W "${completions[*]}" -- "$cur") ) @@ -167,6 +167,11 @@ __handle_flag() must_have_one_flag=() fi + # if you set a flag which only applies to this command, don't show subcommands + if __contains_word "${flagname}" "${local_nonpersistent_flags[@]}"; then + commands=() + fi + # keep flag value with flagname as flaghash if [ -n "${flagvalue}" ] ; then flaghash[${flagname}]=${flagvalue} @@ -263,6 +268,7 @@ func postscript(w io.Writer, name string) error { local c=0 local flags=() local two_word_flags=() + local local_nonpersistent_flags=() local flags_with_completion=() local flags_completion=() local commands=("%s") @@ -360,7 +366,7 @@ func writeFlagHandler(name string, annotations map[string][]string, w io.Writer) } func writeShortFlag(flag *pflag.Flag, w io.Writer) error { - b := (flag.Value.Type() == "bool") + b := (len(flag.NoOptDefVal) > 0) name := flag.Shorthand format := " " if !b { @@ -374,7 +380,7 @@ func writeShortFlag(flag *pflag.Flag, w io.Writer) error { } func writeFlag(flag *pflag.Flag, w io.Writer) error { - b := (flag.Value.Type() == "bool") + b := (len(flag.NoOptDefVal) > 0) name := flag.Name format := " flags+=(\"--%s" if !b { @@ -387,9 +393,24 @@ func writeFlag(flag *pflag.Flag, w io.Writer) error { return writeFlagHandler("--"+name, flag.Annotations, w) } +func writeLocalNonPersistentFlag(flag *pflag.Flag, w io.Writer) error { + b := (len(flag.NoOptDefVal) > 0) + name := flag.Name + format := " local_nonpersistent_flags+=(\"--%s" + if !b { + format += "=" + } + format += "\")\n" + if _, err := fmt.Fprintf(w, format, name); err != nil { + return err + } + return nil +} + func writeFlags(cmd *Command, w io.Writer) error { _, err := fmt.Fprintf(w, ` flags=() two_word_flags=() + local_nonpersistent_flags=() flags_with_completion=() flags_completion=() @@ -397,6 +418,7 @@ func writeFlags(cmd *Command, w io.Writer) error { if err != nil { return err } + localNonPersistentFlags := cmd.LocalNonPersistentFlags() var visitErr error cmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) { if err := writeFlag(flag, w); err != nil { @@ -409,6 +431,12 @@ func writeFlags(cmd *Command, w io.Writer) error { return } } + if localNonPersistentFlags.Lookup(flag.Name) != nil { + if err := writeLocalNonPersistentFlag(flag, w); err != nil { + visitErr = err + return + } + } }) if visitErr != nil { return visitErr diff --git a/command.go b/command.go index e1517aa2..dc57d630 100644 --- a/command.go +++ b/command.go @@ -1032,6 +1032,19 @@ func (c *Command) Flags() *flag.FlagSet { return c.flags } +// LocalNonPersistentFlags are flags specific to this command which will NOT persist to subcommands +func (c *Command) LocalNonPersistentFlags() *flag.FlagSet { + persistentFlags := c.PersistentFlags() + + out := flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.LocalFlags().VisitAll(func(f *flag.Flag) { + if persistentFlags.Lookup(f.Name) == nil { + out.AddFlag(f) + } + }) + return out +} + // Get the local FlagSet specifically set in the current command func (c *Command) LocalFlags() *flag.FlagSet { c.mergePersistentFlags()