From af165be453a6c4a9e09603e121dd18badd3f7322 Mon Sep 17 00:00:00 2001 From: Victor Kareh Date: Tue, 2 Mar 2021 11:50:31 -0500 Subject: [PATCH] command: Allow overriding of flag parse function To allow setting a user-defined flag parser, we add a new SetFlagParseFunc to Command. This function, when set, will be called instead of pflag.Parse. --- command.go | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/command.go b/command.go index ce94d40f..110855e1 100644 --- a/command.go +++ b/command.go @@ -145,6 +145,9 @@ type Command struct { usageFunc func(*Command) error // usageTemplate is usage template defined by user. usageTemplate string + // flagParseFunc is func defined by user and it's called when the command is + // parsing the flags. + flagParseFunc func(*Command, []string) error // flagErrorFunc is func defined by user and it's called when the parsing of // flags returns an error. flagErrorFunc func(*Command, error) error @@ -271,6 +274,11 @@ func (c *Command) SetUsageTemplate(s string) { c.usageTemplate = s } +// SetFlagParseFunc sets a function to parse flags. +func (c *Command) SetFlagParseFunc(f func(*Command, []string) error) { + c.flagParseFunc = f +} + // SetFlagErrorFunc sets a function to generate an error when flag parsing // fails. func (c *Command) SetFlagErrorFunc(f func(*Command, error) error) { @@ -432,6 +440,22 @@ func (c *Command) UsageString() string { return bb.String() } +// FlagParseFunc returns either the function set by SetFlagParseFunc for this +// command or a parent, or it returns a function which calls the original +// flag parse function. +func (c *Command) FlagParseFunc() (f func(*Command, []string) error) { + if c.flagParseFunc != nil { + return c.flagParseFunc + } + + if c.HasParent() { + return c.parent.FlagParseFunc() + } + return func(c *Command, args []string) error { + return c.Flags().Parse(args) + } +} + // FlagErrorFunc returns either the function set by SetFlagErrorFunc for this // command or a parent, or it returns a function which returns the original // error. @@ -1626,7 +1650,7 @@ func (c *Command) ParseFlags(args []string) error { // do it here after merging all flags and just before parse c.Flags().ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist) - err := c.Flags().Parse(args) + err := c.FlagParseFunc()(c, args) // Print warnings if they occurred (e.g. deprecated flag messages). if c.flagErrorBuf.Len()-beforeErrorBufLen > 0 && err == nil { c.Print(c.flagErrorBuf.String())