From 24481f04da779874d1dd067f56c928b2590a52a7 Mon Sep 17 00:00:00 2001 From: Benjamin Congdon Date: Wed, 15 Jan 2020 19:01:26 -0800 Subject: [PATCH] Copy parsed flag values when constructing `LocalFlags()` Fixes #1019 --- command.go | 7 +++++++ command_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/command.go b/command.go index ab3cf69a..bed18846 100644 --- a/command.go +++ b/command.go @@ -1418,6 +1418,13 @@ func (c *Command) LocalFlags() *flag.FlagSet { addToLocal := func(f *flag.Flag) { if c.lflags.Lookup(f.Name) == nil && c.parentsPflags.Lookup(f.Name) == nil { c.lflags.AddFlag(f) + + // If `f` has already been changed, re-add its value to the flag set, + // otherwise `c.lflags` has no record of the flags value. + if f.Changed { + f.Changed = false + c.lflags.Set(f.Name, f.Value.String()) + } } } c.Flags().VisitAll(addToLocal) diff --git a/command_test.go b/command_test.go index b26bd4ab..f612511d 100644 --- a/command_test.go +++ b/command_test.go @@ -1773,3 +1773,44 @@ func TestFParseErrWhitelistSiblingCommand(t *testing.T) { } checkStringContains(t, output, "unknown flag: --unknown") } + +func TestLocalFlagsInChildRun(t *testing.T) { + root := &Command{ + Use: "root", + Run: emptyRun, + } + root.Flags().BoolP("boola", "a", false, "a boolean flag") + + var setFlags []string + var allFlags []string + c := &Command{ + Use: "child", + Run: func(cmd *Command, args []string) { + cmd.LocalFlags().Visit(func(f *pflag.Flag) { + setFlags = append(setFlags, f.Name) + }) + cmd.LocalFlags().VisitAll(func(f *pflag.Flag) { + allFlags = append(allFlags, f.Name) + }) + }, + } + c.Flags().BoolP("boolb", "b", false, "a boolean flag") + c.Flags().BoolP("boolc", "c", false, "a boolean flag") + + root.AddCommand(c) + + _, err := executeCommand(root, "child", "--boolb") + if err != nil { + t.Fatal("unexpected error: ", err.Error()) + } + + if len(setFlags) != 1 || setFlags[0] != "boolb" { + t.Errorf(`expected setFlags to be ["boolb"], but was: %v`, setFlags) + } + expectedAllFlags := []string{"boolb", "boolc", "help"} + for i, f := range expectedAllFlags { + if allFlags[i] != f { + t.Errorf("Expected: %s, got: %s", f, allFlags[i]) + } + } +}