From fe9ade5260b2534e7eab30bdf472a96229f8692f Mon Sep 17 00:00:00 2001 From: Gauri Prasad Date: Mon, 11 Apr 2022 11:11:41 -0700 Subject: [PATCH] Added an option to set custom FlagErrorHandling --- command.go | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/command.go b/command.go index ee5365bc..26341359 100644 --- a/command.go +++ b/command.go @@ -31,6 +31,9 @@ import ( // FParseErrWhitelist configures Flag parse errors to be ignored type FParseErrWhitelist flag.ParseErrorsWhitelist +// FErrorHandling defines how to handle flag parsing errors +type FErrorHandling flag.ErrorHandling + // Command is just that, a command for your application. // E.g. 'go run ...' - 'run' is the command. Cobra requires // you to define the usage and description as part of your command @@ -222,6 +225,9 @@ type Command struct { // SuggestionsMinimumDistance defines minimum levenshtein distance to display suggestions. // Must be > 0. SuggestionsMinimumDistance int + + // FlagErrorHandling defines how to handle flag parsing errors. Defaults to flag.ContinueOnError. + FlagErrorHandling FErrorHandling } // Context returns underlying command context. If command was executed @@ -1460,7 +1466,7 @@ func (c *Command) GlobalNormalizationFunc() func(f *flag.FlagSet, name string) f // to this command (local and persistent declared here and by all parents). func (c *Command) Flags() *flag.FlagSet { if c.flags == nil { - c.flags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.flags = flag.NewFlagSet(c.Name(), (flag.ErrorHandling)(c.FlagErrorHandling)) if c.flagErrorBuf == nil { c.flagErrorBuf = new(bytes.Buffer) } @@ -1474,7 +1480,7 @@ func (c *Command) Flags() *flag.FlagSet { func (c *Command) LocalNonPersistentFlags() *flag.FlagSet { persistentFlags := c.PersistentFlags() - out := flag.NewFlagSet(c.Name(), flag.ContinueOnError) + out := flag.NewFlagSet(c.Name(), (flag.ErrorHandling)(c.FlagErrorHandling)) c.LocalFlags().VisitAll(func(f *flag.Flag) { if persistentFlags.Lookup(f.Name) == nil { out.AddFlag(f) @@ -1488,7 +1494,7 @@ func (c *Command) LocalFlags() *flag.FlagSet { c.mergePersistentFlags() if c.lflags == nil { - c.lflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.lflags = flag.NewFlagSet(c.Name(), (flag.ErrorHandling)(c.FlagErrorHandling)) if c.flagErrorBuf == nil { c.flagErrorBuf = new(bytes.Buffer) } @@ -1514,7 +1520,7 @@ func (c *Command) InheritedFlags() *flag.FlagSet { c.mergePersistentFlags() if c.iflags == nil { - c.iflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.iflags = flag.NewFlagSet(c.Name(), (flag.ErrorHandling)(c.FlagErrorHandling)) if c.flagErrorBuf == nil { c.flagErrorBuf = new(bytes.Buffer) } @@ -1542,7 +1548,7 @@ func (c *Command) NonInheritedFlags() *flag.FlagSet { // PersistentFlags returns the persistent FlagSet specifically set in the current command. func (c *Command) PersistentFlags() *flag.FlagSet { if c.pflags == nil { - c.pflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.pflags = flag.NewFlagSet(c.Name(), (flag.ErrorHandling)(c.FlagErrorHandling)) if c.flagErrorBuf == nil { c.flagErrorBuf = new(bytes.Buffer) } @@ -1555,9 +1561,9 @@ func (c *Command) PersistentFlags() *flag.FlagSet { func (c *Command) ResetFlags() { c.flagErrorBuf = new(bytes.Buffer) c.flagErrorBuf.Reset() - c.flags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.flags = flag.NewFlagSet(c.Name(), (flag.ErrorHandling)(c.FlagErrorHandling)) c.flags.SetOutput(c.flagErrorBuf) - c.pflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.pflags = flag.NewFlagSet(c.Name(), (flag.ErrorHandling)(c.FlagErrorHandling)) c.pflags.SetOutput(c.flagErrorBuf) c.lflags = nil @@ -1674,7 +1680,7 @@ func (c *Command) mergePersistentFlags() { // If c.parentsPflags == nil, it makes new. func (c *Command) updateParentsPflags() { if c.parentsPflags == nil { - c.parentsPflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.parentsPflags = flag.NewFlagSet(c.Name(), (flag.ErrorHandling)(c.FlagErrorHandling)) c.parentsPflags.SetOutput(c.flagErrorBuf) c.parentsPflags.SortFlags = false }