diff --git a/command.go b/command.go index 15371289..90f98c3f 100644 --- a/command.go +++ b/command.go @@ -1770,7 +1770,7 @@ func (c *Command) persistentFlag(name string) (flag *flag.Flag) { return } -func (c *Command) parseLongArgs(s string, args []string) (passedArgs, restArgs []string) { +func (c *Command) parseLongArgs(s string, args []string, flags *flag.FlagSet) (passedArgs, restArgs []string) { restArgs = args name := s[2:] if len(name) == 0 { @@ -1780,11 +1780,10 @@ func (c *Command) parseLongArgs(s string, args []string) (passedArgs, restArgs [ split := strings.SplitN(s[2:], "=", 2) name = split[0] - searchedFlag := c.lflags.Lookup(name) - + searchedFlag := flags.Lookup(name) if searchedFlag == nil { - // ignore the flag that is not registered in c.lflags but is registered in c.flags - c.flags.VisitAll(func(f *flag.Flag) { + // ignore the flag that is not registered in passed flags but is registered in c.parentsPflags + c.parentsPflags.VisitAll(func(f *flag.Flag) { if name == f.Name { if len(split) == 1 && f.NoOptDefVal == "" && len(args) > 0 { // '--flag arg' @@ -1804,16 +1803,16 @@ func (c *Command) parseLongArgs(s string, args []string) (passedArgs, restArgs [ return } -func (c *Command) parseShortArgs(s string, args []string) (passedArgs []string, restArgs []string) { +func (c *Command) parseShortArgs(s string, args []string, flags *flag.FlagSet) (passedArgs []string, restArgs []string) { restArgs = args shorthands := s[1:] shorthand := string(s[1]) - searchedFlag := c.lflags.ShorthandLookup(shorthand) + searchedFlag := flags.ShorthandLookup(shorthand) if searchedFlag == nil { - // ignore the flag that is not registered in c.lflags but is registered in c.flags - c.flags.VisitAll(func(f *flag.Flag) { + // ignore the flag that is not registered in passed flags but is registered in c.parentsPflags + c.parentsPflags.VisitAll(func(f *flag.Flag) { if shorthand == f.Shorthand { if len(shorthands) == 1 && f.NoOptDefVal == "" && len(args) > 0 { // '-f arg' @@ -1834,7 +1833,7 @@ func (c *Command) parseShortArgs(s string, args []string) (passedArgs []string, return } -func (c *Command) removeParentPersistentArgs(args []string) (newArgs []string) { +func (c *Command) removeParentPersistentArgs(args []string, flags *flag.FlagSet) (newArgs []string) { for len(args) > 0 { s := args[0] args = args[1:] @@ -1845,9 +1844,9 @@ func (c *Command) removeParentPersistentArgs(args []string) (newArgs []string) { var passedArgs, restArgs []string if s[1] == '-' { - passedArgs, restArgs = c.parseLongArgs(s, args) + passedArgs, restArgs = c.parseLongArgs(s, args, flags) } else { - passedArgs, restArgs = c.parseShortArgs(s, args) + passedArgs, restArgs = c.parseShortArgs(s, args, flags) } if len(passedArgs) > 0 { newArgs = append(newArgs, passedArgs...) @@ -1884,12 +1883,25 @@ func (c *Command) ParseFlags(args []string) error { // parse Local Flags c.LocalFlags() // need to execute LocalFlags() to set the value in c.lflags before executing removeParentPersistentArgs c.lflags.ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist) - localArgs := c.removeParentPersistentArgs(args) // get only arguments related to c.lflags + localArgs := c.removeParentPersistentArgs(args, c.lflags) // get only arguments related to c.lflags err = c.lflags.Parse(localArgs) // Print warnings if they occurred (e.g. deprecated flag messages). if c.flagErrorBuf.Len()-beforeErrorBufLen > 0 && err == nil { c.Print(c.flagErrorBuf.String()) } + if err != nil { + return err + } + + // parse local non persistent flags + c.LocalNonPersistentFlags() // need to execute LocalNonPersistentFlags() to set the value in c.lnpflags before executing removeParentPersistentArgs + c.lnpflags.ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist) + localNonPersistentArgs := c.removeParentPersistentArgs(args, c.lnpflags) + err = c.lnpflags.Parse(localNonPersistentArgs) + // Print warnings if they occurred (e.g. deprecated flag messages). + if c.flagErrorBuf.Len()-beforeErrorBufLen > 0 && err == nil { + c.Print(c.flagErrorBuf.String()) + } return err }