From 0dacccfbaabc71b872087c1719c5380d3e185173 Mon Sep 17 00:00:00 2001 From: Diego Becciolini Date: Mon, 2 Oct 2017 11:00:25 +0100 Subject: [PATCH] Improve consistency of flags when using SetGlobalNormalizationFunc (#522) Fix #521 --- cobra_test.go | 74 +++++++++++++++++++++++++++++++++++++++++++++++++-- command.go | 16 +++++++++++ 2 files changed, 88 insertions(+), 2 deletions(-) diff --git a/cobra_test.go b/cobra_test.go index d5df951e..8192b526 100644 --- a/cobra_test.go +++ b/cobra_test.go @@ -190,6 +190,7 @@ func flagInit() { cmdTimes.Flags().IntVarP(&flagi2, "inttwo", "j", 234, "help message for flag inttwo") cmdTimes.Flags().StringVarP(&flags2b, "strtwo", "t", "2", strtwoChildHelp) cmdTimes.PersistentFlags().StringVarP(&flags2b, "strtwo", "t", "2", strtwoChildHelp) + cmdTimes.LocalFlags() // populate lflags before parent is set cmdPrint.Flags().BoolVarP(&flagb3, "boolthree", "b", true, "help message for flag boolthree") cmdPrint.PersistentFlags().StringVarP(&flags3, "strthree", "s", "three", "help message for flag strthree") } @@ -210,8 +211,8 @@ func initialize() *Command { rootPersPre, echoPre, echoPersPre, timesPersPre = nil, nil, nil, nil var c = cmdRootNoRun - flagInit() commandInit() + flagInit() return c } @@ -219,8 +220,8 @@ func initializeWithSameName() *Command { tt, tp, te = nil, nil, nil rootPersPre, echoPre, echoPersPre, timesPersPre = nil, nil, nil, nil var c = cmdRootSameName - flagInit() commandInit() + flagInit() return c } @@ -910,6 +911,7 @@ func TestRootHelp(t *testing.T) { func TestFlagAccess(t *testing.T) { initialize() + cmdEcho.AddCommand(cmdTimes) local := cmdTimes.LocalFlags() inherited := cmdTimes.InheritedFlags() @@ -1165,11 +1167,18 @@ func TestGlobalNormFuncPropagation(t *testing.T) { } rootCmd := initialize() + rootCmd.AddCommand(cmdEcho) + rootCmd.SetGlobalNormalizationFunc(normFunc) if reflect.ValueOf(normFunc).Pointer() != reflect.ValueOf(rootCmd.GlobalNormalizationFunc()).Pointer() { t.Error("rootCmd seems to have a wrong normalization function") } + // Also check it propagates retroactively + if reflect.ValueOf(normFunc).Pointer() != reflect.ValueOf(cmdEcho.GlobalNormalizationFunc()).Pointer() { + t.Error("cmdEcho should have had the normalization function of rootCmd") + } + // First add the cmdEchoSub to cmdPrint cmdPrint.AddCommand(cmdEchoSub) if cmdPrint.GlobalNormalizationFunc() != nil && cmdEchoSub.GlobalNormalizationFunc() != nil { @@ -1184,6 +1193,67 @@ func TestGlobalNormFuncPropagation(t *testing.T) { } } +func TestNormPassedOnLocal(t *testing.T) { + n := func(f *pflag.FlagSet, name string) pflag.NormalizedName { + return pflag.NormalizedName(strings.ToUpper(name)) + } + + cmd := &Command{} + flagVal := false + + cmd.Flags().BoolVar(&flagVal, "flagname", true, "this is a dummy flag") + cmd.SetGlobalNormalizationFunc(n) + if cmd.LocalFlags().Lookup("flagname") != cmd.LocalFlags().Lookup("FLAGNAME") { + t.Error("Normalization function should be passed on to Local flag set") + } +} + +func TestNormPassedOnInherited(t *testing.T) { + n := func(f *pflag.FlagSet, name string) pflag.NormalizedName { + return pflag.NormalizedName(strings.ToUpper(name)) + } + + cmd, childBefore, childAfter := &Command{}, &Command{}, &Command{} + flagVal := false + cmd.AddCommand(childBefore) + + cmd.PersistentFlags().BoolVar(&flagVal, "flagname", true, "this is a dummy flag") + cmd.SetGlobalNormalizationFunc(n) + + cmd.AddCommand(childAfter) + + if f := childBefore.InheritedFlags(); f.Lookup("flagname") == nil || f.Lookup("flagname") != f.Lookup("FLAGNAME") { + t.Error("Normalization function should be passed on to inherited flag set in command added before flag") + } + if f := childAfter.InheritedFlags(); f.Lookup("flagname") == nil || f.Lookup("flagname") != f.Lookup("FLAGNAME") { + t.Error("Normalization function should be passed on to inherited flag set in command added after flag") + } +} + +// Related to https://github.com/spf13/cobra/issues/521. +func TestNormConsistent(t *testing.T) { + n := func(f *pflag.FlagSet, name string) pflag.NormalizedName { + return pflag.NormalizedName(strings.ToUpper(name)) + } + id := func(f *pflag.FlagSet, name string) pflag.NormalizedName { + return pflag.NormalizedName(name) + } + + cmd := &Command{} + flagVal := false + + cmd.Flags().BoolVar(&flagVal, "flagname", true, "this is a dummy flag") + // Build local flag set + cmd.LocalFlags() + + cmd.SetGlobalNormalizationFunc(n) + cmd.SetGlobalNormalizationFunc(id) + + if cmd.LocalFlags().Lookup("flagname") == cmd.LocalFlags().Lookup("FLAGNAME") { + t.Error("Normalizing flag names should not result in duplicate flags") + } +} + func TestFlagOnPflagCommandLine(t *testing.T) { flagName := "flagOnCommandLine" pflag.String(flagName, "", "about my flag") diff --git a/command.go b/command.go index 185e4526..a0a633d1 100644 --- a/command.go +++ b/command.go @@ -806,6 +806,7 @@ Simply type ` + c.Name() + ` help [path to command] for full details.`, // ResetCommands used for testing. func (c *Command) ResetCommands() { + c.parent = nil c.commands = nil c.helpCommand = nil c.parentsPflags = nil @@ -1132,6 +1133,9 @@ func (c *Command) LocalFlags() *flag.FlagSet { c.lflags.SetOutput(c.flagErrorBuf) } c.lflags.SortFlags = c.Flags().SortFlags + if c.globNormFunc != nil { + c.lflags.SetNormalizeFunc(c.globNormFunc) + } addToLocal := func(f *flag.Flag) { if c.lflags.Lookup(f.Name) == nil && c.parentsPflags.Lookup(f.Name) == nil { @@ -1156,6 +1160,10 @@ func (c *Command) InheritedFlags() *flag.FlagSet { } local := c.LocalFlags() + if c.globNormFunc != nil { + c.iflags.SetNormalizeFunc(c.globNormFunc) + } + c.parentsPflags.VisitAll(func(f *flag.Flag) { if c.iflags.Lookup(f.Name) == nil && local.Lookup(f.Name) == nil { c.iflags.AddFlag(f) @@ -1189,6 +1197,10 @@ func (c *Command) ResetFlags() { c.flags.SetOutput(c.flagErrorBuf) c.pflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) c.pflags.SetOutput(c.flagErrorBuf) + + c.lflags = nil + c.iflags = nil + c.parentsPflags = nil } // HasFlags checks if the command contains any flags (local plus persistent from the entire structure). @@ -1298,6 +1310,10 @@ func (c *Command) updateParentsPflags() { c.parentsPflags.SortFlags = false } + if c.globNormFunc != nil { + c.parentsPflags.SetNormalizeFunc(c.globNormFunc) + } + c.Root().PersistentFlags().AddFlagSet(flag.CommandLine) c.VisitParents(func(parent *Command) {