mirror of
https://github.com/spf13/cobra
synced 2024-11-24 14:47:12 +00:00
Add basic zsh completion (command hierarchy only)
Partially fixes #107 See PR #497
This commit is contained in:
parent
9e024b655b
commit
d7ba19510d
2 changed files with 202 additions and 0 deletions
114
zsh_completions.go
Normal file
114
zsh_completions.go
Normal file
|
@ -0,0 +1,114 @@
|
||||||
|
package cobra
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenZshCompletion generates a zsh completion file and writes to the passed writer.
|
||||||
|
func (cmd *Command) GenZshCompletion(w io.Writer) error {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
|
||||||
|
writeHeader(buf, cmd)
|
||||||
|
maxDepth := maxDepth(cmd)
|
||||||
|
writeLevelMapping(buf, maxDepth)
|
||||||
|
writeLevelCases(buf, maxDepth, cmd)
|
||||||
|
|
||||||
|
_, err := buf.WriteTo(w)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeHeader(w io.Writer, cmd *Command) {
|
||||||
|
fmt.Fprintf(w, "#compdef %s\n\n", cmd.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
func maxDepth(c *Command) int {
|
||||||
|
if len(c.Commands()) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
maxDepthSub := 0
|
||||||
|
for _, s := range c.Commands() {
|
||||||
|
subDepth := maxDepth(s)
|
||||||
|
if subDepth > maxDepthSub {
|
||||||
|
maxDepthSub = subDepth
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 1 + maxDepthSub
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeLevelMapping(w io.Writer, numLevels int) {
|
||||||
|
fmt.Fprintln(w, `_arguments \`)
|
||||||
|
for i := 1; i <= numLevels; i++ {
|
||||||
|
fmt.Fprintf(w, ` '%d: :->level%d' \`, i, i)
|
||||||
|
fmt.Fprintln(w)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(w, ` '%d: :%s'`, numLevels+1, "_files")
|
||||||
|
fmt.Fprintln(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeLevelCases(w io.Writer, maxDepth int, root *Command) {
|
||||||
|
fmt.Fprintln(w, "case $state in")
|
||||||
|
defer fmt.Fprintln(w, "esac")
|
||||||
|
|
||||||
|
for i := 1; i <= maxDepth; i++ {
|
||||||
|
fmt.Fprintf(w, " level%d)\n", i)
|
||||||
|
writeLevel(w, root, i)
|
||||||
|
fmt.Fprintln(w, " ;;")
|
||||||
|
}
|
||||||
|
fmt.Fprintln(w, " *)")
|
||||||
|
fmt.Fprintln(w, " _arguments '*: :_files'")
|
||||||
|
fmt.Fprintln(w, " ;;")
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeLevel(w io.Writer, root *Command, i int) {
|
||||||
|
fmt.Fprintf(w, " case $words[%d] in\n", i)
|
||||||
|
defer fmt.Fprintln(w, " esac")
|
||||||
|
|
||||||
|
commands := filterByLevel(root, i)
|
||||||
|
byParent := groupByParent(commands)
|
||||||
|
|
||||||
|
for p, c := range byParent {
|
||||||
|
names := names(c)
|
||||||
|
fmt.Fprintf(w, " %s)\n", p)
|
||||||
|
fmt.Fprintf(w, " _arguments '%d: :(%s)'\n", i, strings.Join(names, " "))
|
||||||
|
fmt.Fprintln(w, " ;;")
|
||||||
|
}
|
||||||
|
fmt.Fprintln(w, " *)")
|
||||||
|
fmt.Fprintln(w, " _arguments '*: :_files'")
|
||||||
|
fmt.Fprintln(w, " ;;")
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterByLevel(c *Command, l int) []*Command {
|
||||||
|
cs := make([]*Command, 0)
|
||||||
|
if l == 0 {
|
||||||
|
cs = append(cs, c)
|
||||||
|
return cs
|
||||||
|
}
|
||||||
|
for _, s := range c.Commands() {
|
||||||
|
cs = append(cs, filterByLevel(s, l-1)...)
|
||||||
|
}
|
||||||
|
return cs
|
||||||
|
}
|
||||||
|
|
||||||
|
func groupByParent(commands []*Command) map[string][]*Command {
|
||||||
|
m := make(map[string][]*Command)
|
||||||
|
for _, c := range commands {
|
||||||
|
parent := c.Parent()
|
||||||
|
if parent == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m[parent.Name()] = append(m[parent.Name()], c)
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func names(commands []*Command) []string {
|
||||||
|
ns := make([]string, len(commands))
|
||||||
|
for i, c := range commands {
|
||||||
|
ns[i] = c.Name()
|
||||||
|
}
|
||||||
|
return ns
|
||||||
|
}
|
88
zsh_completions_test.go
Normal file
88
zsh_completions_test.go
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
package cobra
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestZshCompletion(t *testing.T) {
|
||||||
|
tcs := []struct {
|
||||||
|
name string
|
||||||
|
root *Command
|
||||||
|
expectedExpressions []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "trivial",
|
||||||
|
root: &Command{Use: "trivialapp"},
|
||||||
|
expectedExpressions: []string{"#compdef trivial"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "linear",
|
||||||
|
root: func() *Command {
|
||||||
|
r := &Command{Use: "linear"}
|
||||||
|
|
||||||
|
sub1 := &Command{Use: "sub1"}
|
||||||
|
r.AddCommand(sub1)
|
||||||
|
|
||||||
|
sub2 := &Command{Use: "sub2"}
|
||||||
|
sub1.AddCommand(sub2)
|
||||||
|
|
||||||
|
sub3 := &Command{Use: "sub3"}
|
||||||
|
sub2.AddCommand(sub3)
|
||||||
|
return r
|
||||||
|
}(),
|
||||||
|
expectedExpressions: []string{"sub1", "sub2", "sub3"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "flat",
|
||||||
|
root: func() *Command {
|
||||||
|
r := &Command{Use: "flat"}
|
||||||
|
r.AddCommand(&Command{Use: "c1"})
|
||||||
|
r.AddCommand(&Command{Use: "c2"})
|
||||||
|
return r
|
||||||
|
}(),
|
||||||
|
expectedExpressions: []string{"(c1 c2)"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tree",
|
||||||
|
root: func() *Command {
|
||||||
|
r := &Command{Use: "tree"}
|
||||||
|
|
||||||
|
sub1 := &Command{Use: "sub1"}
|
||||||
|
r.AddCommand(sub1)
|
||||||
|
|
||||||
|
sub11 := &Command{Use: "sub11"}
|
||||||
|
sub12 := &Command{Use: "sub12"}
|
||||||
|
|
||||||
|
sub1.AddCommand(sub11)
|
||||||
|
sub1.AddCommand(sub12)
|
||||||
|
|
||||||
|
sub2 := &Command{Use: "sub2"}
|
||||||
|
r.AddCommand(sub2)
|
||||||
|
|
||||||
|
sub21 := &Command{Use: "sub21"}
|
||||||
|
sub22 := &Command{Use: "sub22"}
|
||||||
|
|
||||||
|
sub2.AddCommand(sub21)
|
||||||
|
sub2.AddCommand(sub22)
|
||||||
|
|
||||||
|
return r
|
||||||
|
}(),
|
||||||
|
expectedExpressions: []string{"(sub11 sub12)", "(sub21 sub22)"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tcs {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
tc.root.GenZshCompletion(buf)
|
||||||
|
completion := buf.String()
|
||||||
|
for _, expectedExpression := range tc.expectedExpressions {
|
||||||
|
if !strings.Contains(completion, expectedExpression) {
|
||||||
|
t.Errorf("expected completion to contain '%v' somewhere; got '%v'", expectedExpression, completion)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue