From 200e51835642e161edbba724d809676d6a42040f Mon Sep 17 00:00:00 2001 From: Jun Nishimura Date: Sat, 22 Jul 2023 14:55:03 +0900 Subject: [PATCH] parse lflags --- command.go | 89 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 88 insertions(+), 1 deletion(-) diff --git a/command.go b/command.go index a683acd6..8bd403c1 100644 --- a/command.go +++ b/command.go @@ -1767,6 +1767,91 @@ func (c *Command) persistentFlag(name string) (flag *flag.Flag) { return } +func (c *Command) parseLongArgs(s string, args []string) (passedArgs, restArgs []string) { + restArgs = args + + split := strings.SplitN(s[2:], "=", 2) + name := split[0] + searchedFlag := c.lflags.Lookup(name) + + if searchedFlag == nil { + // a flag not registered in c.lflags should be registered in c.flags + c.flags.VisitAll(func(f *flag.Flag) { + if name == f.Name { + if len(split) == 1 && f.NoOptDefVal == "" && len(args) > 0 { + // '--flag arg' + restArgs = args[1:] + } + return + } + }) + } + + passedArgs = append(passedArgs, fmt.Sprintf("--%s", name)) + if len(split) == 1 && searchedFlag.NoOptDefVal == "" && len(args) > 0 { + passedArgs = append(passedArgs, args[0]) + restArgs = args[1:] + } + + return +} + +func (c *Command) parseShortArgs(s string, args []string) (passedArgs []string, restArgs []string) { + restArgs = args + + shorthands := s[1:] + shorthand := string(s[1]) + + searchedFlag := c.lflags.ShorthandLookup(shorthand) + if searchedFlag == nil { + // a flag not registered in c.lflags should be registered in c.flags + c.flags.VisitAll(func(f *flag.Flag) { + if shorthand == f.Shorthand { + if len(shorthands) == 1 && f.NoOptDefVal == "" && len(args) > 0 { + // '-f arg' + restArgs = args[1:] + } + } + }) + return + } + + passedArgs = append(passedArgs, s) + if len(shorthands) == 1 && searchedFlag.NoOptDefVal == "" && len(args) > 0 { + // '-f arg' + passedArgs = append(passedArgs, args[0]) + restArgs = args[1:] + } + + return +} + +func (c *Command) removeParentPersistentArgs(args []string) []string { + var newArgs []string + + for len(args) > 0 { + s := args[0] + args = args[1:] + if s[0] != '-' { + newArgs = append(newArgs, s) + continue + } + + var passedArgs, restArgs []string + if s[1] == '-' { + passedArgs, restArgs = c.parseLongArgs(s, args) + } else { + passedArgs, restArgs = c.parseShortArgs(s, args) + } + if len(passedArgs) > 0 { + newArgs = append(newArgs, passedArgs...) + } + args = restArgs + } + + return newArgs +} + // ParseFlags parses persistent flag tree and local flags. func (c *Command) ParseFlags(args []string) error { if c.DisableFlagParsing { @@ -1790,7 +1875,9 @@ func (c *Command) ParseFlags(args []string) error { } // parse Local Flags - err = c.LocalFlags().Parse(args) + c.LocalFlags() // need to execute LocalFlags() to set the value in c.lflags before executing removeParentPersistentArgs + localArgs := c.removeParentPersistentArgs(args) // get only arguments related to c.lflags + err = c.lflags.Parse(localArgs) // Print warnings if they occurred (e.g. deprecated flag messages). if c.flagErrorBuf.Len()-beforeErrorBufLen > 0 && err == nil { c.Print(c.flagErrorBuf.String())