diff --git a/command.go b/command.go index a71948b4..5274082d 100644 --- a/command.go +++ b/command.go @@ -66,14 +66,24 @@ type Command struct { // All functions get the same args, the arguments after the command name // PersistentPreRun: children of this command will inherit and execute PersistentPreRun func(cmd *Command, args []string) + // PersistentPreRunE: PersistentPreRun but returns an error + PersistentPreRunE func(cmd *Command, args []string) error // PreRun: children of this command will not inherit. PreRun func(cmd *Command, args []string) + // PreRunE: PreRun but returns an error + PreRunE func(cmd *Command, args []string) error // Run: Typically the actual work function. Most commands will only implement this Run func(cmd *Command, args []string) + // RunE: Run but returns an error + RunE func(cmd *Command, args []string) error // PostRun: run after the Run command. PostRun func(cmd *Command, args []string) + // PostRunE: PostRun but returns an error + PostRunE func(cmd *Command, args []string) error // PersistentPostRun: children of this command will inherit and execute after PostRun PersistentPostRun func(cmd *Command, args []string) + // PersistentPostRunE: PersistentPostRun but returns an error + PersistentPostRunE func(cmd *Command, args []string) error // Commands is the list of commands supported by this program. commands []*Command // Parent Command for this command @@ -474,22 +484,45 @@ func (c *Command) execute(a []string) (err error) { argWoFlags := c.Flags().Args() for p := c; p != nil; p = p.Parent() { - if p.PersistentPreRun != nil { + if p.PersistentPreRunE != nil { + if err := p.PersistentPostRunE(c, argWoFlags); err != nil { + return err + } + break + } else if p.PersistentPreRun != nil { p.PersistentPreRun(c, argWoFlags) break } } - if c.PreRun != nil { + if c.PreRunE != nil { + if err := c.PreRunE(c, argWoFlags); err != nil { + return err + } + } else if c.PreRun != nil { c.PreRun(c, argWoFlags) } - c.Run(c, argWoFlags) - - if c.PostRun != nil { + if c.RunE != nil { + if err := c.RunE(c, argWoFlags); err != nil { + return err + } + } else { + c.Run(c, argWoFlags) + } + if c.PostRunE != nil { + if err := c.PostRunE(c, argWoFlags); err != nil { + return err + } + } else if c.PostRun != nil { c.PostRun(c, argWoFlags) } for p := c; p != nil; p = p.Parent() { - if p.PersistentPostRun != nil { + if p.PersistentPostRunE != nil { + if err := p.PersistentPostRunE(c, argWoFlags); err != nil { + return err + } + break + } else if p.PersistentPostRun != nil { p.PersistentPostRun(c, argWoFlags) break }