// Copyright © 2013 Steve Francia .
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package cobra provides commands similar to git, go tools and other modern
// CLI tools. It is inspired by go, go-Commander, gh and subcommand.
package cobra

import (
	"bytes"
	"fmt"
	flag "github.com/spf13/pflag"
	"io"
	"strings"
)

// Command is just that, a command for your application.
// eg. 'go run' ... 'run' is the command. Cobra requires
// you to define the usage and description as part of your command
// definition to ensure usability.
type Command struct {
	// name is the command name, usually the executable's name.
	// When empty, Name() falls back to the first word of Use.
	name string
	// Use is the one-line usage message.
	Use string
	// Short is the short description shown in the 'help' output.
	Short string
	// Long is the long message shown in the 'help <this-command>' output.
	Long string
	// flags is the set of flags specific to this command.
	// It is created lazily by Flags().
	flags *flag.FlagSet
	// pflags is the set of flags children commands will inherit.
	// It is created lazily by PersistentFlags().
	pflags *flag.FlagSet
	// Run runs the command.
	// The args are the arguments after the command name.
	Run func(cmd *Command, args []string)
	// commands is the list of child commands registered through AddCommand.
	commands []*Command
	// parent is the parent Command of this command; nil for a root command.
	parent *Command
	// cmdr is the Commander this command belongs to. Commander() reads it
	// from the root of the command tree and panics when it is nil there.
	cmdr *Commander
	// flagErrorBuf is the buffer both flag sets are configured to write
	// their output to (see Flags and PersistentFlags).
	flagErrorBuf *bytes.Buffer
}

// Find locates the target command given the args and command tree.
// Meant to be run on the highest node. Only searches down.
// It returns the matched command together with the arguments left over after
// the matched command names. When no runnable command matches, every result
// (including the error) is nil.
func (c *Command) Find(args []string) (cmd *Command, a []string, err error) { if c == nil { return nil, nil, fmt.Errorf("Called find() on a nil Command") } validSubCommand := false if len(args) > 0 && c.HasSubCommands() { for _, cmd := range c.commands { if cmd.Name() == args[0] { validSubCommand = true return cmd.Find(args[1:]) } } } if !validSubCommand && c.Runnable() { return c, args, nil } return nil, nil, nil } func (c *Command) Root() *Command { var findRoot func(*Command) *Command findRoot = func(x *Command) *Command { if x.HasParent() { return findRoot(x.parent) } else { return x } } return findRoot(c) } func (c *Command) Commander() *Commander { cmdr := c.Root() if cmdr.cmdr != nil { return cmdr.cmdr } else { panic("commander not found") } } func (c *Command) Out() io.Writer { return c.Commander().out() } // execute the command determined by args and the command tree func (c *Command) execute(args []string) (err error) { err = fmt.Errorf("unknown subcommand %q\nRun 'help' for usage.\n", args[0]) if c == nil { return fmt.Errorf("Called Execute() on a nil Command") } cmd, a, e := c.Find(args) if e == nil { err = cmd.ParseFlags(a) if err != nil { // report flag parsing error c.Println(strings.Split(err.Error(), "\n")[0]) erx := cmd.Usage() if erx != nil { return erx } return err } else { argWoFlags := cmd.Flags().Args() cmd.Run(cmd, argWoFlags) return nil } } err = e return err } // Used for testing func (c *Command) ResetCommands() { c.commands = nil } func (c *Command) Commands() []*Command { return c.commands } // Add one or many commands as children of this func (c *Command) AddCommand(cmds ...*Command) { for i, x := range cmds { if cmds[i] == c { panic("Command can't be a child of itself") } cmds[i].parent = c c.commands = append(c.commands, x) } } // Convenience method to Print to the defined output func (c *Command) Print(i ...interface{}) { c.Commander().PrintOut(i...) 
} // Convenience method to Println to the defined output func (c *Command) Println(i ...interface{}) { str := fmt.Sprintln(i...) c.Print(str) } // Convenience method to Printf to the defined output func (c *Command) Printf(format string, i ...interface{}) { str := fmt.Sprintf(format, i...) c.Print(str) } // Output the usage for the command // Used when a user provides invalid input // Can be defined by user by overriding Commander.UsageFunc func (c *Command) Usage() error { err := c.Commander().UsageFunc(c) return err } // The full path to this command func (c *Command) CommandPath() string { str := c.Name() x := c for x.HasParent() { str = x.parent.Name() + " " + str x = x.parent } return str } //The full usage for a given command (including parents) func (c *Command) UseLine() string { str := "" if c.HasParent() { str = c.parent.CommandPath() + " " } return str + c.Use } // For use in determining which flags have been assigned to which commands // and which persist func (c *Command) DebugFlags() { c.Println("DebugFlags called on", c.Name()) var debugflags func(*Command) debugflags = func(x *Command) { if x.HasFlags() || x.HasPersistentFlags() { c.Println(x.Name()) } if x.HasFlags() { x.flags.VisitAll(func(f *flag.Flag) { if x.HasPersistentFlags() { if x.persistentFlag(f.Name) == nil { c.Println(" -"+f.Shorthand+",", "--"+f.Name, "["+f.DefValue+"]", "", f.Value, " [L]") } else { c.Println(" -"+f.Shorthand+",", "--"+f.Name, "["+f.DefValue+"]", "", f.Value, " [LP]") } } else { c.Println(" -"+f.Shorthand+",", "--"+f.Name, "["+f.DefValue+"]", "", f.Value, " [L]") } }) } if x.HasPersistentFlags() { x.pflags.VisitAll(func(f *flag.Flag) { if x.HasFlags() { if x.flags.Lookup(f.Name) == nil { c.Println(" -"+f.Shorthand+",", "--"+f.Name, "["+f.DefValue+"]", "", f.Value, " [P]") } } else { c.Println(" -"+f.Shorthand+",", "--"+f.Name, "["+f.DefValue+"]", "", f.Value, " [P]") } }) } c.Println(x.flagErrorBuf) if x.HasSubCommands() { for _, y := range x.commands { debugflags(y) 
} } } debugflags(c) } // Name returns the command's name: the first word in the use line. func (c *Command) Name() string { if c.name != "" { return c.name } name := c.Use i := strings.Index(name, " ") if i >= 0 { name = name[:i] } return name } // Determine if the command is itself runnable func (c *Command) Runnable() bool { return c.Run != nil } // Determine if the command has children commands func (c *Command) HasSubCommands() bool { return len(c.commands) > 0 } // Determine if the command is a child command func (c *Command) HasParent() bool { return c.parent != nil } // Get the Commands FlagSet func (c *Command) Flags() *flag.FlagSet { if c.flags == nil { c.flags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) if c.flagErrorBuf == nil { c.flagErrorBuf = new(bytes.Buffer) } c.flags.SetOutput(c.flagErrorBuf) } return c.flags } // Get the Commands Persistent FlagSet func (c *Command) PersistentFlags() *flag.FlagSet { if c.pflags == nil { c.pflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) if c.flagErrorBuf == nil { c.flagErrorBuf = new(bytes.Buffer) } c.pflags.SetOutput(c.flagErrorBuf) } return c.pflags } // For use in testing func (c *Command) ResetFlags() { c.flagErrorBuf = new(bytes.Buffer) c.flagErrorBuf.Reset() c.flags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) c.flags.SetOutput(c.flagErrorBuf) c.pflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) c.pflags.SetOutput(c.flagErrorBuf) } // Does the command contain flags (local not persistent) func (c *Command) HasFlags() bool { return c.Flags().HasFlags() } // Does the command contain persistent flags func (c *Command) HasPersistentFlags() bool { return c.PersistentFlags().HasFlags() } // Climbs up the command tree looking for matching flag func (c *Command) Flag(name string) (flag *flag.Flag) { flag = c.Flags().Lookup(name) if flag == nil { flag = c.persistentFlag(name) } return } // recursively find matching persistent flag func (c *Command) persistentFlag(name string) (flag *flag.Flag) 
{ if c.HasPersistentFlags() { flag = c.PersistentFlags().Lookup(name) } if flag == nil && c.HasParent() { flag = c.parent.persistentFlag(name) } return } // Parses persistent flag tree & local flags func (c *Command) ParseFlags(args []string) (err error) { c.mergePersistentFlags() err = c.Flags().Parse(args) if err != nil { return err } // The upstream library adds spaces to the error // response regardless of success. // Handling it here until fixing upstream if len(strings.TrimSpace(c.flagErrorBuf.String())) > 1 { return fmt.Errorf("%s", c.flagErrorBuf.String()) } return nil } func (c *Command) mergePersistentFlags() { var rmerge func(x *Command) rmerge = func(x *Command) { if x.HasPersistentFlags() { x.PersistentFlags().VisitAll(func(f *flag.Flag) { if c.Flags().Lookup(f.Name) == nil { c.Flags().AddFlag(f) } }) } if x.HasParent() { rmerge(x.parent) } } rmerge(c) } func (c *Command) Parent() *Command { return c.parent }