Simplify bash_completions.go

Improve test coverage from 80% to 85%.
This commit is contained in:
Albert Nigmatzianov 2017-05-14 13:20:20 +02:00
parent 4cdb38c072
commit de6b168d98
2 changed files with 78 additions and 193 deletions

View file

@ -1,6 +1,7 @@
package cobra package cobra
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -18,12 +19,9 @@ const (
BashCompSubdirsInDir = "cobra_annotation_bash_completion_subdirs_in_dir" BashCompSubdirsInDir = "cobra_annotation_bash_completion_subdirs_in_dir"
) )
func preamble(out io.Writer, name string) error { func writePreamble(buf *bytes.Buffer, name string) {
_, err := fmt.Fprintf(out, "# bash completion for %-36s -*- shell-script -*-\n", name) buf.WriteString(fmt.Sprintf("# bash completion for %-36s -*- shell-script -*-\n", name))
if err != nil { buf.WriteString(`
return err
}
preamStr := `
__debug() __debug()
{ {
if [[ -n ${BASH_COMP_DEBUG_FILE} ]]; then if [[ -n ${BASH_COMP_DEBUG_FILE} ]]; then
@ -247,18 +245,13 @@ __handle_word()
__handle_word __handle_word
} }
` `)
_, err = fmt.Fprint(out, preamStr)
return err
} }
func postscript(w io.Writer, name string) error { func writePostscript(buf *bytes.Buffer, name string) {
name = strings.Replace(name, ":", "__", -1) name = strings.Replace(name, ":", "__", -1)
_, err := fmt.Fprintf(w, "__start_%s()\n", name) buf.WriteString(fmt.Sprintf("__start_%s()\n", name))
if err != nil { buf.WriteString(fmt.Sprintf(`{
return err
}
_, err = fmt.Fprintf(w, `{
local cur prev words cword local cur prev words cword
declare -A flaghash 2>/dev/null || : declare -A flaghash 2>/dev/null || :
if declare -F _init_completion >/dev/null 2>&1; then if declare -F _init_completion >/dev/null 2>&1; then
@ -282,197 +275,132 @@ func postscript(w io.Writer, name string) error {
__handle_word __handle_word
} }
`, name) `, name))
if err != nil { buf.WriteString(fmt.Sprintf(`if [[ $(type -t compopt) = "builtin" ]]; then
return err
}
_, err = fmt.Fprintf(w, `if [[ $(type -t compopt) = "builtin" ]]; then
complete -o default -F __start_%s %s complete -o default -F __start_%s %s
else else
complete -o default -o nospace -F __start_%s %s complete -o default -o nospace -F __start_%s %s
fi fi
`, name, name, name, name) `, name, name, name, name))
if err != nil { buf.WriteString("# ex: ts=4 sw=4 et filetype=sh\n")
return err
}
_, err = fmt.Fprintf(w, "# ex: ts=4 sw=4 et filetype=sh\n")
return err
} }
func writeCommands(cmd *Command, w io.Writer) error { func writeCommands(buf *bytes.Buffer, cmd *Command) {
if _, err := fmt.Fprintf(w, " commands=()\n"); err != nil { buf.WriteString(" commands=()\n")
return err
}
for _, c := range cmd.Commands() { for _, c := range cmd.Commands() {
if !c.IsAvailableCommand() || c == cmd.helpCommand { if !c.IsAvailableCommand() || c == cmd.helpCommand {
continue continue
} }
if _, err := fmt.Fprintf(w, " commands+=(%q)\n", c.Name()); err != nil { buf.WriteString(fmt.Sprintf(" commands+=(%q)\n", c.Name()))
return err
} }
} buf.WriteString("\n")
_, err := fmt.Fprintf(w, "\n")
return err
} }
func writeFlagHandler(name string, annotations map[string][]string, w io.Writer) error { func writeFlagHandler(buf *bytes.Buffer, name string, annotations map[string][]string) {
for key, value := range annotations { for key, value := range annotations {
switch key { switch key {
case BashCompFilenameExt: case BashCompFilenameExt:
_, err := fmt.Fprintf(w, " flags_with_completion+=(%q)\n", name) buf.WriteString(fmt.Sprintf(" flags_with_completion+=(%q)\n", name))
if err != nil {
return err
}
var ext string
if len(value) > 0 { if len(value) > 0 {
ext := "__handle_filename_extension_flag " + strings.Join(value, "|") ext = "__handle_filename_extension_flag " + strings.Join(value, "|")
_, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", ext)
} else { } else {
ext := "_filedir" ext = "_filedir"
_, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", ext)
}
if err != nil {
return err
} }
buf.WriteString(fmt.Sprintf(" flags_completion+=(%q)\n", ext))
case BashCompCustom: case BashCompCustom:
_, err := fmt.Fprintf(w, " flags_with_completion+=(%q)\n", name) buf.WriteString(fmt.Sprintf(" flags_with_completion+=(%q)\n", name))
if err != nil {
return err
}
if len(value) > 0 { if len(value) > 0 {
handlers := strings.Join(value, "; ") handlers := strings.Join(value, "; ")
_, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", handlers) buf.WriteString(fmt.Sprintf(" flags_completion+=(%q)\n", handlers))
} else { } else {
_, err = fmt.Fprintf(w, " flags_completion+=(:)\n") buf.WriteString(" flags_completion+=(:)\n")
}
if err != nil {
return err
} }
case BashCompSubdirsInDir: case BashCompSubdirsInDir:
_, err := fmt.Fprintf(w, " flags_with_completion+=(%q)\n", name) buf.WriteString(fmt.Sprintf(" flags_with_completion+=(%q)\n", name))
var ext string
if len(value) == 1 { if len(value) == 1 {
ext := "__handle_subdirs_in_dir_flag " + value[0] ext = "__handle_subdirs_in_dir_flag " + value[0]
_, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", ext)
} else { } else {
ext := "_filedir -d" ext = "_filedir -d"
_, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", ext)
} }
if err != nil { buf.WriteString(fmt.Sprintf(" flags_completion+=(%q)\n", ext))
return err
} }
} }
} }
return nil
}
func writeShortFlag(flag *pflag.Flag, w io.Writer) error { func writeShortFlag(buf *bytes.Buffer, flag *pflag.Flag) {
b := (len(flag.NoOptDefVal) > 0)
name := flag.Shorthand name := flag.Shorthand
format := " " format := " "
if !b { if len(flag.NoOptDefVal) == 0 {
format += "two_word_" format += "two_word_"
} }
format += "flags+=(\"-%s\")\n" format += "flags+=(\"-%s\")\n"
if _, err := fmt.Fprintf(w, format, name); err != nil { buf.WriteString(fmt.Sprintf(format, name))
return err writeFlagHandler(buf, "-"+name, flag.Annotations)
}
return writeFlagHandler("-"+name, flag.Annotations, w)
} }
func writeFlag(flag *pflag.Flag, w io.Writer) error { func writeFlag(buf *bytes.Buffer, flag *pflag.Flag) {
b := (len(flag.NoOptDefVal) > 0)
name := flag.Name name := flag.Name
format := " flags+=(\"--%s" format := " flags+=(\"--%s"
if !b { if len(flag.NoOptDefVal) == 0 {
format += "=" format += "="
} }
format += "\")\n" format += "\")\n"
if _, err := fmt.Fprintf(w, format, name); err != nil { buf.WriteString(fmt.Sprintf(format, name))
return err writeFlagHandler(buf, "--"+name, flag.Annotations)
}
return writeFlagHandler("--"+name, flag.Annotations, w)
} }
func writeLocalNonPersistentFlag(flag *pflag.Flag, w io.Writer) error { func writeLocalNonPersistentFlag(buf *bytes.Buffer, flag *pflag.Flag) {
b := (len(flag.NoOptDefVal) > 0)
name := flag.Name name := flag.Name
format := " local_nonpersistent_flags+=(\"--%s" format := " local_nonpersistent_flags+=(\"--%s"
if !b { if len(flag.NoOptDefVal) == 0 {
format += "=" format += "="
} }
format += "\")\n" format += "\")\n"
_, err := fmt.Fprintf(w, format, name) buf.WriteString(fmt.Sprintf(format, name))
return err
} }
func writeFlags(cmd *Command, w io.Writer) error { func writeFlags(buf *bytes.Buffer, cmd *Command) {
_, err := fmt.Fprintf(w, ` flags=() buf.WriteString(` flags=()
two_word_flags=() two_word_flags=()
local_nonpersistent_flags=() local_nonpersistent_flags=()
flags_with_completion=() flags_with_completion=()
flags_completion=() flags_completion=()
`) `)
if err != nil {
return err
}
localNonPersistentFlags := cmd.LocalNonPersistentFlags() localNonPersistentFlags := cmd.LocalNonPersistentFlags()
var visitErr error
cmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) { cmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) {
if nonCompletableFlag(flag) { if nonCompletableFlag(flag) {
return return
} }
if err := writeFlag(flag, w); err != nil { writeFlag(buf, flag)
visitErr = err
return
}
if len(flag.Shorthand) > 0 { if len(flag.Shorthand) > 0 {
if err := writeShortFlag(flag, w); err != nil { writeShortFlag(buf, flag)
visitErr = err
return
}
} }
if localNonPersistentFlags.Lookup(flag.Name) != nil { if localNonPersistentFlags.Lookup(flag.Name) != nil {
if err := writeLocalNonPersistentFlag(flag, w); err != nil { writeLocalNonPersistentFlag(buf, flag)
visitErr = err
return
}
} }
}) })
if visitErr != nil {
return visitErr
}
cmd.InheritedFlags().VisitAll(func(flag *pflag.Flag) { cmd.InheritedFlags().VisitAll(func(flag *pflag.Flag) {
if nonCompletableFlag(flag) { if nonCompletableFlag(flag) {
return return
} }
if err := writeFlag(flag, w); err != nil { writeFlag(buf, flag)
visitErr = err
return
}
if len(flag.Shorthand) > 0 { if len(flag.Shorthand) > 0 {
if err := writeShortFlag(flag, w); err != nil { writeShortFlag(buf, flag)
visitErr = err
return
}
} }
}) })
if visitErr != nil {
return visitErr buf.WriteString("\n")
} }
_, err = fmt.Fprintf(w, "\n") func writeRequiredFlag(buf *bytes.Buffer, cmd *Command) {
return err buf.WriteString(" must_have_one_flag=()\n")
}
func writeRequiredFlag(cmd *Command, w io.Writer) error {
if _, err := fmt.Fprintf(w, " must_have_one_flag=()\n"); err != nil {
return err
}
flags := cmd.NonInheritedFlags() flags := cmd.NonInheritedFlags()
var visitErr error
flags.VisitAll(func(flag *pflag.Flag) { flags.VisitAll(func(flag *pflag.Flag) {
if nonCompletableFlag(flag) { if nonCompletableFlag(flag) {
return return
@ -481,108 +409,69 @@ func writeRequiredFlag(cmd *Command, w io.Writer) error {
switch key { switch key {
case BashCompOneRequiredFlag: case BashCompOneRequiredFlag:
format := " must_have_one_flag+=(\"--%s" format := " must_have_one_flag+=(\"--%s"
b := (flag.Value.Type() == "bool") if flag.Value.Type() != "bool" {
if !b {
format += "=" format += "="
} }
format += "\")\n" format += "\")\n"
if _, err := fmt.Fprintf(w, format, flag.Name); err != nil { buf.WriteString(fmt.Sprintf(format, flag.Name))
visitErr = err
return
}
if len(flag.Shorthand) > 0 { if len(flag.Shorthand) > 0 {
if _, err := fmt.Fprintf(w, " must_have_one_flag+=(\"-%s\")\n", flag.Shorthand); err != nil { buf.WriteString(fmt.Sprintf(" must_have_one_flag+=(\"-%s\")\n", flag.Shorthand))
visitErr = err
return
}
} }
} }
} }
}) })
return visitErr
} }
func writeRequiredNouns(cmd *Command, w io.Writer) error { func writeRequiredNouns(buf *bytes.Buffer, cmd *Command) {
if _, err := fmt.Fprintf(w, " must_have_one_noun=()\n"); err != nil { buf.WriteString(" must_have_one_noun=()\n")
return err
}
sort.Sort(sort.StringSlice(cmd.ValidArgs)) sort.Sort(sort.StringSlice(cmd.ValidArgs))
for _, value := range cmd.ValidArgs { for _, value := range cmd.ValidArgs {
if _, err := fmt.Fprintf(w, " must_have_one_noun+=(%q)\n", value); err != nil { buf.WriteString(fmt.Sprintf(" must_have_one_noun+=(%q)\n", value))
return err
} }
} }
return nil
}
func writeArgAliases(cmd *Command, w io.Writer) error { func writeArgAliases(buf *bytes.Buffer, cmd *Command) {
if _, err := fmt.Fprintf(w, " noun_aliases=()\n"); err != nil { buf.WriteString(" noun_aliases=()\n")
return err
}
sort.Sort(sort.StringSlice(cmd.ArgAliases)) sort.Sort(sort.StringSlice(cmd.ArgAliases))
for _, value := range cmd.ArgAliases { for _, value := range cmd.ArgAliases {
if _, err := fmt.Fprintf(w, " noun_aliases+=(%q)\n", value); err != nil { buf.WriteString(fmt.Sprintf(" noun_aliases+=(%q)\n", value))
return err
} }
} }
return nil
}
func gen(cmd *Command, w io.Writer) error { func gen(buf *bytes.Buffer, cmd *Command) {
for _, c := range cmd.Commands() { for _, c := range cmd.Commands() {
if !c.IsAvailableCommand() || c == cmd.helpCommand { if !c.IsAvailableCommand() || c == cmd.helpCommand {
continue continue
} }
if err := gen(c, w); err != nil { gen(buf, c)
return err
}
} }
commandName := cmd.CommandPath() commandName := cmd.CommandPath()
commandName = strings.Replace(commandName, " ", "_", -1) commandName = strings.Replace(commandName, " ", "_", -1)
commandName = strings.Replace(commandName, ":", "__", -1) commandName = strings.Replace(commandName, ":", "__", -1)
if _, err := fmt.Fprintf(w, "_%s()\n{\n", commandName); err != nil { buf.WriteString(fmt.Sprintf("_%s()\n{\n", commandName))
return err buf.WriteString(fmt.Sprintf(" last_command=%q\n", commandName))
} writeCommands(buf, cmd)
if _, err := fmt.Fprintf(w, " last_command=%q\n", commandName); err != nil { writeFlags(buf, cmd)
return err writeRequiredFlag(buf, cmd)
} writeRequiredNouns(buf, cmd)
if err := writeCommands(cmd, w); err != nil { writeArgAliases(buf, cmd)
return err buf.WriteString("}\n\n")
}
if err := writeFlags(cmd, w); err != nil {
return err
}
if err := writeRequiredFlag(cmd, w); err != nil {
return err
}
if err := writeRequiredNouns(cmd, w); err != nil {
return err
}
if err := writeArgAliases(cmd, w); err != nil {
return err
}
if _, err := fmt.Fprintf(w, "}\n\n"); err != nil {
return err
}
return nil
} }
// GenBashCompletion generates bash completion file and writes to the passed writer. // GenBashCompletion generates bash completion file and writes to the passed writer.
func (cmd *Command) GenBashCompletion(w io.Writer) error { func (cmd *Command) GenBashCompletion(w io.Writer) error {
if err := preamble(w, cmd.Name()); err != nil { buf := new(bytes.Buffer)
return err writePreamble(buf, cmd.Name())
}
if len(cmd.BashCompletionFunction) > 0 { if len(cmd.BashCompletionFunction) > 0 {
if _, err := fmt.Fprintf(w, "%s\n", cmd.BashCompletionFunction); err != nil { buf.WriteString(cmd.BashCompletionFunction + "\n")
}
gen(buf, cmd)
writePostscript(buf, cmd.Name())
_, err := buf.WriteTo(w)
return err return err
} }
}
if err := gen(cmd, w); err != nil {
return err
}
return postscript(w, cmd.Name())
}
func nonCompletableFlag(flag *pflag.Flag) bool { func nonCompletableFlag(flag *pflag.Flag) bool {
return flag.Hidden || len(flag.Deprecated) > 0 return flag.Hidden || len(flag.Deprecated) > 0

View file

@ -2,16 +2,12 @@ package cobra
import ( import (
"bytes" "bytes"
"fmt"
"os" "os"
"os/exec" "os/exec"
"strings" "strings"
"testing" "testing"
) )
var _ = fmt.Println
var _ = os.Stderr
func checkOmit(t *testing.T, found, unexpected string) { func checkOmit(t *testing.T, found, unexpected string) {
if strings.Contains(found, unexpected) { if strings.Contains(found, unexpected) {
t.Errorf("Unexpected response.\nGot: %q\nBut should not have!\n", unexpected) t.Errorf("Unexpected response.\nGot: %q\nBut should not have!\n", unexpected)