diff --git a/viper.go b/viper.go index 7b12b36..bbfe6b1 100644 --- a/viper.go +++ b/viper.go @@ -780,11 +780,21 @@ func (v *Viper) Sub(key string) *Viper { return nil } - if reflect.TypeOf(data).Kind() == reflect.Map { - subv.config = cast.ToStringMap(data) - return subv + if !(reflect.TypeOf(data).Kind() == reflect.Map) { + return nil } - return nil + subv.config = cast.ToStringMap(data) + subPFlags := make(map[string]FlagValue) + for flagName, flagValue := range v.pflags { + keyPrefix := key + "." + if !strings.HasPrefix(flagName, keyPrefix) { + continue + } + newFlagName := flagName[len(keyPrefix):] + subPFlags[newFlagName] = flagValue + } + subv.pflags = subPFlags + return subv } // GetString returns the value associated with the key as a string. diff --git a/viper_test.go b/viper_test.go index b8ceccb..5dc70c3 100644 --- a/viper_test.go +++ b/viper_test.go @@ -1253,6 +1253,32 @@ func TestSub(t *testing.T) { assert.Equal(t, (*Viper)(nil), subv) } +func TestSubWithBoundedFlags(t *testing.T) { + v := New() + v.SetConfigType("yaml") + v.ReadConfig(bytes.NewBuffer(yamlExample)) + + overridenValue := "notSoLarge" + flags := pflag.NewFlagSet("test", pflag.ContinueOnError) + flags.String("clothing.pants.size", "", "") + assert.NoError(t, flags.Parse([]string{"--clothing.pants.size=" + overridenValue})) + + v.BindPFlags(flags) + subv := v.Sub("clothing") + assert.Equal(t, overridenValue, v.Get("clothing.pants.size")) + assert.Equal(t, overridenValue, subv.Get("pants.size")) + assert.Equal(t, v.Get("clothing.pants.size"), subv.Get("pants.size")) + + subv = v.Sub("clothing.pants") + assert.Equal(t, v.Get("clothing.pants.size"), subv.Get("size")) + + subv = v.Sub("clothing.pants.size") + assert.Equal(t, (*Viper)(nil), subv) + + subv = v.Sub("missing.key") + assert.Equal(t, (*Viper)(nil), subv) +} + var hclWriteExpected = []byte(`"foos" = { "foo" = { "key" = 1