diff --git a/zsh_completions.go b/zsh_completions.go index 9e8ff087..4651cf2f 100644 --- a/zsh_completions.go +++ b/zsh_completions.go @@ -5,8 +5,6 @@ import ( "io" "bytes" "fmt" - "github.com/spf13/cobra" -"github.com/spf13/pflag" "strings" ) @@ -53,7 +51,9 @@ func maxDepth(c *Command) int { } func writeLevelMapping(w io.Writer, numLevels int) { - fmt.Fprintln(w, `_arguments \`) + fmt.Fprintln(w, "local -a rails_options") + fmt.Fprintln(w, `_arguments -C \`) + fmt.Fprintln(w, ` $jamf_pro_options \`) for i := 1; i <= numLevels; i++ { fmt.Fprintf(w, ` '%d: :->level%d' \`, i, i) fmt.Fprintln(w) @@ -64,64 +64,64 @@ func writeLevelMapping(w io.Writer, numLevels int) { func writeLevelCases(w io.Writer, maxDepth int, root *Command) { fmt.Fprintln(w, "case $state in") - defer fmt.Fprintln(w, "esac") - for i := 1; i <= maxDepth; i++ { - fmt.Fprintf(w, " level%d)\n", i) writeLevel(w, root, i) - fmt.Fprintln(w, " ;;") } fmt.Fprintln(w, " *)") fmt.Fprintln(w, " _arguments '*: :_files'") fmt.Fprintln(w, " ;;") + fmt.Fprintln(w, "esac") } -func writeLevel(w io.Writer, root *Command, i int) { - fmt.Fprintf(w, " case $words[%d] in\n", i) - defer fmt.Fprintln(w, " esac") - - commands := filterByLevel(root, i) - byParent := groupByParent(commands) - - for p, c := range byParent { - names := names(c) - //flags := flags(p) - fmt.Fprintf(w, " %s)\n", p.Name()) - fmt.Fprintf(w, " _values '%s command' ", p.Name()) - fmt.Fprintf(w, "'%s'\n", strings.Join(names, "' '")) - fmt.Fprintln(w, " ;;") +func writeLevel(w io.Writer, root *Command, level int) { + fmt.Fprintf(w, " level%d)\n", level) + fmt.Fprintf(w, " case $words[%d] in\n", level) + for _, c := range filterByLevel(root, level) { + writeCommandArgsBlock(w, c) } fmt.Fprintln(w, " *)") fmt.Fprintln(w, " _arguments '*: :_files'") fmt.Fprintln(w, " ;;") + fmt.Fprintln(w, " esac") + fmt.Fprintln(w, " ;;") +} +func writeCommandArgsBlock(w io.Writer, c *Command) { + names := commandNames(c) + flags := commandFlags(c) + if len(names) > 0 || len(flags) > 0 { + fmt.Fprintf(w, " %s)\n", c.Name()) + defer fmt.Fprintln(w, " ;;") + } + if len(flags) > 0 { + fmt.Fprintln(w, " jamf_pro_options=(") + for _, flag := range flags { + fmt.Fprintf(w, " %s\n", flag) + } + fmt.Fprintln(w, "\n )") + } + if len(names) > 0 { + fmt.Fprintf(w, " _values 'command' '%s'\n", strings.Join(names, "' '")) + } } func filterByLevel(c *Command, l int) []*Command { - cs := make([]*Command, 0) - if l == 0 { - cs = append(cs, c) - return cs - } - for _, s := range c.Commands() { - cs = append(cs, filterByLevel(s, l-1)...) - } - return cs -} - -func groupByParent(commands []*Command) map[*Command][]*Command { - m := make(map[*Command][]*Command) - for _, c := range commands { - parent := c.Parent() - if parent == nil { - continue + commands := []*Command{c} + for i := 1; i < l; i++ { + var nextLevel []*Command + for _, c := range commands { + if c.HasSubCommands() { + nextLevel = append(nextLevel, c.Commands()...) + } } - m[parent] = append(m[parent], c) + commands = nextLevel } - return m + + return commands } -func names(commands []*Command) []string { +func commandNames(command *Command) []string { + commands := command.Commands() ns := make([]string, len(commands)) for i, c := range commands { ns[i] = fmt.Sprintf("%s[%s]", c.Name(), c.Short) @@ -129,15 +129,15 @@ func names(commands []*Command) []string { return ns } -func flags(command *Command) []string { +func commandFlags(command *Command) []string { flags := command.Flags() - ns := make([]string, flags.NArg()) - if flags.NArg() > 0 { - i := 0 - flags.VisitAll(func(f *pflag.Flag) { - ns[i] = fmt.Sprintf("%s[%s]", f.Name, f.Usage) - i += 1 - }) - } + ns := make([]string, 0) + flags.VisitAll(func(flag *pflag.Flag) { + if len(flag.Shorthand) > 0 { + ns = append(ns, fmt.Sprintf("{-%s,--%s}'[%s]'", flag.Shorthand, flag.Name, flag.Usage)) + } else { + ns = append(ns, fmt.Sprintf("--%s'[%s]'", flag.Name, flag.Usage)) + } + }) return ns }