Allow flag completion as well

This commit is contained in:
Brandon Roehl 2018-06-27 10:48:09 -05:00
parent 14efddf125
commit cbe65494b7

View file

@ -5,8 +5,6 @@ import (
"io" "io"
"bytes" "bytes"
"fmt" "fmt"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"strings" "strings"
) )
@ -53,7 +51,9 @@ func maxDepth(c *Command) int {
} }
func writeLevelMapping(w io.Writer, numLevels int) { func writeLevelMapping(w io.Writer, numLevels int) {
fmt.Fprintln(w, `_arguments \`) fmt.Fprintln(w, "local -a rails_options")
fmt.Fprintln(w, `_arguments -C \`)
fmt.Fprintln(w, ` $jamf_pro_options \`)
for i := 1; i <= numLevels; i++ { for i := 1; i <= numLevels; i++ {
fmt.Fprintf(w, ` '%d: :->level%d' \`, i, i) fmt.Fprintf(w, ` '%d: :->level%d' \`, i, i)
fmt.Fprintln(w) fmt.Fprintln(w)
@ -64,64 +64,64 @@ func writeLevelMapping(w io.Writer, numLevels int) {
func writeLevelCases(w io.Writer, maxDepth int, root *Command) { func writeLevelCases(w io.Writer, maxDepth int, root *Command) {
fmt.Fprintln(w, "case $state in") fmt.Fprintln(w, "case $state in")
defer fmt.Fprintln(w, "esac")
for i := 1; i <= maxDepth; i++ { for i := 1; i <= maxDepth; i++ {
fmt.Fprintf(w, " level%d)\n", i)
writeLevel(w, root, i) writeLevel(w, root, i)
fmt.Fprintln(w, " ;;")
} }
fmt.Fprintln(w, " *)") fmt.Fprintln(w, " *)")
fmt.Fprintln(w, " _arguments '*: :_files'") fmt.Fprintln(w, " _arguments '*: :_files'")
fmt.Fprintln(w, " ;;") fmt.Fprintln(w, " ;;")
fmt.Fprintln(w, "esac")
} }
func writeLevel(w io.Writer, root *Command, i int) { func writeLevel(w io.Writer, root *Command, level int) {
fmt.Fprintf(w, " case $words[%d] in\n", i) fmt.Fprintf(w, " level%d)\n", level)
defer fmt.Fprintln(w, " esac") fmt.Fprintf(w, " case $words[%d] in\n", level)
for _, c := range filterByLevel(root, level) {
commands := filterByLevel(root, i) writeCommandArgsBlock(w, c)
byParent := groupByParent(commands)
for p, c := range byParent {
names := names(c)
//flags := flags(p)
fmt.Fprintf(w, " %s)\n", p.Name())
fmt.Fprintf(w, " _values '%s command' ", p.Name())
fmt.Fprintf(w, "'%s'\n", strings.Join(names, "' '"))
fmt.Fprintln(w, " ;;")
} }
fmt.Fprintln(w, " *)") fmt.Fprintln(w, " *)")
fmt.Fprintln(w, " _arguments '*: :_files'") fmt.Fprintln(w, " _arguments '*: :_files'")
fmt.Fprintln(w, " ;;") fmt.Fprintln(w, " ;;")
fmt.Fprintln(w, " esac")
fmt.Fprintln(w, " ;;")
}
func writeCommandArgsBlock(w io.Writer, c *Command) {
names := commandNames(c)
flags := commandFlags(c)
if len(names) > 0 || len(flags) > 0 {
fmt.Fprintf(w, " %s)\n", c.Name())
defer fmt.Fprintln(w, " ;;")
}
if len(flags) > 0 {
fmt.Fprintln(w, " jamf_pro_options=(")
for _, flag := range flags {
fmt.Fprintf(w, " %s\n", flag)
}
fmt.Fprintln(w, "\n )")
}
if len(names) > 0 {
fmt.Fprintf(w, " _values 'command' '%s'\n", strings.Join(names, "' '"))
}
} }
func filterByLevel(c *Command, l int) []*Command { func filterByLevel(c *Command, l int) []*Command {
cs := make([]*Command, 0) commands := []*Command{c}
if l == 0 { for i := 1; i < l; i++ {
cs = append(cs, c) var nextLevel []*Command
return cs
}
for _, s := range c.Commands() {
cs = append(cs, filterByLevel(s, l-1)...)
}
return cs
}
func groupByParent(commands []*Command) map[*Command][]*Command {
m := make(map[*Command][]*Command)
for _, c := range commands { for _, c := range commands {
parent := c.Parent() if c.HasSubCommands() {
if parent == nil { nextLevel = append(nextLevel, c.Commands()...)
continue
} }
m[parent] = append(m[parent], c)
} }
return m commands = nextLevel
} }
func names(commands []*Command) []string { return commands
}
func commandNames(command *Command) []string {
commands := command.Commands()
ns := make([]string, len(commands)) ns := make([]string, len(commands))
for i, c := range commands { for i, c := range commands {
ns[i] = fmt.Sprintf("%s[%s]", c.Name(), c.Short) ns[i] = fmt.Sprintf("%s[%s]", c.Name(), c.Short)
@ -129,15 +129,15 @@ func names(commands []*Command) []string {
return ns return ns
} }
func flags(command *Command) []string { func commandFlags(command *Command) []string {
flags := command.Flags() flags := command.Flags()
ns := make([]string, flags.NArg()) ns := make([]string, 0)
if flags.NArg() > 0 { flags.VisitAll(func(flag *pflag.Flag) {
i := 0 if len(flag.Shorthand) > 0 {
flags.VisitAll(func(f *pflag.Flag) { ns = append(ns, fmt.Sprintf("{-%s,--%s}'[%s]'", flag.Shorthand, flag.Name, flag.Usage))
ns[i] = fmt.Sprintf("%s[%s]", f.Name, f.Usage) } else {
i += 1 ns = append(ns, fmt.Sprintf("--%s'[%s]'", flag.Name, flag.Usage))
})
} }
})
return ns return ns
} }