spf13--cobra/zsh_completions.go

163 lines
3.8 KiB
Go
Raw Normal View History

2018-06-27 16:10:57 +00:00
package cobra
import (
"bytes"
"fmt"
2018-06-27 15:53:11 +00:00
"io"
"os"
"strings"
2018-06-27 16:10:57 +00:00
flag "github.com/spf13/pflag"
)
// GenZshCompletionFile generates a zsh completion script for c and writes it
// to filename. The file is created with os.Create, so an existing file is
// truncated.
//
// If generation succeeds but closing the file fails, the close error is
// returned: a failed Close can mean the written data did not reach its
// destination, and silently dropping it would report success for a bad file.
func (c *Command) GenZshCompletionFile(filename string) (err error) {
	outFile, err := os.Create(filename)
	if err != nil {
		return err
	}
	defer func() {
		// Only report the close error when nothing earlier went wrong, so the
		// more informative generation error is never masked.
		if cerr := outFile.Close(); err == nil {
			err = cerr
		}
	}()
	return c.GenZshCompletion(outFile)
}
2018-06-27 17:46:06 +00:00
// argName returns the name of the zsh array that carries flag specs in the
// generated script. The name is always derived from the root of cmd's command
// tree ("<root>_cmd_args"), with every '-' turned into '_' because the result
// is used as a zsh parameter name (see `$name` and `name=(` in the script).
func argName(cmd *Command) string {
	root := cmd
	for root.HasParent() {
		root = root.Parent()
	}
	return strings.Replace(root.Name()+"_cmd_args", "-", "_", -1)
}
// GenZshCompletion generates a zsh completion file and writes to the passed writer.
func (c *Command) GenZshCompletion(w io.Writer) error {
buf := new(bytes.Buffer)
writeHeader(buf, c)
maxDepth := maxDepth(c)
2018-06-27 17:46:06 +00:00
fmt.Fprintf(buf, "_%s() {\n", c.Name())
writeLevelMapping(buf, maxDepth, c)
writeLevelCases(buf, maxDepth, c)
2018-06-27 17:46:06 +00:00
fmt.Fprintf(buf, "}\n_%s \"$@\"\n", c.Name())
_, err := buf.WriteTo(w)
return err
}
// writeHeader writes the "#compdef <name>" line that opens the completion
// script, followed by an empty line. Write errors are ignored.
func writeHeader(w io.Writer, cmd *Command) {
	io.WriteString(w, "#compdef "+cmd.Name()+"\n\n")
}
// maxDepth reports how many levels of subcommands lie beneath c. It returns 0
// for a command with no subcommands, and otherwise one more than the
// deepest subcommand's own depth.
func maxDepth(c *Command) int {
	deepest := -1 // sentinel: no subcommand seen yet
	for _, sub := range c.Commands() {
		if d := maxDepth(sub); d > deepest {
			deepest = d
		}
	}
	return deepest + 1
}
2018-06-27 17:46:06 +00:00
func writeLevelMapping(w io.Writer, numLevels int, root *Command) {
fmt.Fprintln(w, `local context curcontext="$curcontext" state line`)
fmt.Fprintln(w, `typeset -A opt_args`)
2018-06-27 15:48:09 +00:00
fmt.Fprintln(w, `_arguments -C \`)
for i := 1; i <= numLevels; i++ {
2018-06-27 17:46:06 +00:00
fmt.Fprintf(w, " '%d: :->level%d' \\\n", i, i)
}
2018-06-27 17:46:06 +00:00
fmt.Fprintf(w, " $%s \\\n", argName(root))
fmt.Fprintln(w, ` '*: :_files'`)
}
func writeLevelCases(w io.Writer, maxDepth int, root *Command) {
fmt.Fprintln(w, "case $state in")
for i := 1; i <= maxDepth; i++ {
writeLevel(w, root, i)
}
fmt.Fprintln(w, " *)")
fmt.Fprintln(w, " _arguments '*: :_files'")
fmt.Fprintln(w, " ;;")
2018-06-27 15:48:09 +00:00
fmt.Fprintln(w, "esac")
}
2018-06-27 17:46:06 +00:00
func writeLevel(w io.Writer, root *Command, l int) {
fmt.Fprintf(w, " level%d)\n", l)
fmt.Fprintf(w, " case $words[%d] in\n", l)
for _, c := range filterByLevel(root, l) {
2018-06-27 15:48:09 +00:00
writeCommandArgsBlock(w, c)
}
fmt.Fprintln(w, " *)")
fmt.Fprintln(w, " _arguments '*: :_files'")
fmt.Fprintln(w, " ;;")
2018-06-27 15:48:09 +00:00
fmt.Fprintln(w, " esac")
fmt.Fprintln(w, " ;;")
}
2018-06-27 15:48:09 +00:00
func writeCommandArgsBlock(w io.Writer, c *Command) {
names := commandNames(c)
flags := commandFlags(c)
if len(names) > 0 || len(flags) > 0 {
fmt.Fprintf(w, " %s)\n", c.Name())
defer fmt.Fprintln(w, " ;;")
}
2018-06-27 15:48:09 +00:00
if len(flags) > 0 {
2018-06-27 17:46:06 +00:00
fmt.Fprintf(w, " %s=(\n", argName(c))
2018-06-27 15:48:09 +00:00
for _, flag := range flags {
fmt.Fprintf(w, " %s\n", flag)
}
2018-06-27 17:46:06 +00:00
fmt.Fprintln(w, " )")
2018-06-27 15:48:09 +00:00
}
if len(names) > 0 {
fmt.Fprintf(w, " _values 'command' '%s'\n", strings.Join(names, "' '"))
}
}
2018-06-27 15:48:09 +00:00
func filterByLevel(c *Command, l int) []*Command {
commands := []*Command{c}
for i := 1; i < l; i++ {
var nextLevel []*Command
for _, c := range commands {
if c.HasSubCommands() {
nextLevel = append(nextLevel, c.Commands()...)
}
}
2018-06-27 15:48:09 +00:00
commands = nextLevel
}
2018-06-27 15:48:09 +00:00
return commands
}
2018-06-27 15:48:09 +00:00
func commandNames(command *Command) []string {
commands := command.Commands()
ns := make([]string, len(commands))
for i, c := range commands {
2018-06-27 16:10:57 +00:00
commandMsg := c.Name()
if len(c.Short) > 0 {
commandMsg += fmt.Sprintf("[%s]", c.Short)
}
ns[i] = commandMsg
}
return ns
}
2018-06-27 15:48:09 +00:00
func commandFlags(command *Command) []string {
flags := command.Flags()
2018-06-27 15:48:09 +00:00
ns := make([]string, 0)
2018-06-27 16:10:57 +00:00
flags.VisitAll(func(flag *flag.Flag) {
var flagMsg string
2018-06-27 15:48:09 +00:00
if len(flag.Shorthand) > 0 {
2018-06-27 16:10:57 +00:00
flagMsg = fmt.Sprintf("{-%s,--%s}", flag.Shorthand, flag.Name)
2018-06-27 15:48:09 +00:00
} else {
2018-06-27 16:10:57 +00:00
flagMsg = fmt.Sprintf("--%s", flag.Name)
2018-06-27 15:48:09 +00:00
}
2018-06-27 16:10:57 +00:00
if len(flag.Usage) > 0 {
flagMsg += fmt.Sprintf("'[%s]'", flag.Usage)
}
ns = append(ns, flagMsg)
2018-06-27 15:48:09 +00:00
})
return ns
}