Copy parsed flag values when constructing LocalFlags()

Fixes #1019
This commit is contained in:
Benjamin Congdon 2020-01-15 19:01:26 -08:00
parent 89c7ffb512
commit 24481f04da
No known key found for this signature in database
GPG key ID: 148FEE4C2C805D4A
2 changed files with 48 additions and 0 deletions

View file

@ -1418,6 +1418,13 @@ func (c *Command) LocalFlags() *flag.FlagSet {
addToLocal := func(f *flag.Flag) { addToLocal := func(f *flag.Flag) {
if c.lflags.Lookup(f.Name) == nil && c.parentsPflags.Lookup(f.Name) == nil { if c.lflags.Lookup(f.Name) == nil && c.parentsPflags.Lookup(f.Name) == nil {
c.lflags.AddFlag(f) c.lflags.AddFlag(f)
// If `f` has already been changed, re-add its value to the flag set,
// otherwise `c.lflags` has no record of the flags value.
if f.Changed {
f.Changed = false
c.lflags.Set(f.Name, f.Value.String())
}
} }
} }
c.Flags().VisitAll(addToLocal) c.Flags().VisitAll(addToLocal)

View file

@ -1773,3 +1773,44 @@ func TestFParseErrWhitelistSiblingCommand(t *testing.T) {
} }
checkStringContains(t, output, "unknown flag: --unknown") checkStringContains(t, output, "unknown flag: --unknown")
} }
func TestLocalFlagsInChildRun(t *testing.T) {
root := &Command{
Use: "root",
Run: emptyRun,
}
root.Flags().BoolP("boola", "a", false, "a boolean flag")
var setFlags []string
var allFlags []string
c := &Command{
Use: "child",
Run: func(cmd *Command, args []string) {
cmd.LocalFlags().Visit(func(f *pflag.Flag) {
setFlags = append(setFlags, f.Name)
})
cmd.LocalFlags().VisitAll(func(f *pflag.Flag) {
allFlags = append(allFlags, f.Name)
})
},
}
c.Flags().BoolP("boolb", "b", false, "a boolean flag")
c.Flags().BoolP("boolc", "c", false, "a boolean flag")
root.AddCommand(c)
_, err := executeCommand(root, "child", "--boolb")
if err != nil {
t.Fatal("unexpected error: ", err.Error())
}
if len(setFlags) != 1 || setFlags[0] != "boolb" {
t.Errorf(`expected setFlags to be ["boolb"], but was: %v`, setFlags)
}
expectedAllFlags := []string{"boolb", "boolc", "help"}
for i, f := range expectedAllFlags {
if allFlags[i] != f {
t.Errorf("Expected: %s, got: %s", f, allFlags[i])
}
}
}