From 1c3e461869621f353416e41ba3b984b110d888ee Mon Sep 17 00:00:00 2001 From: setrofim Date: Wed, 28 Sep 2022 18:04:52 +0100 Subject: [PATCH] Merge defaults when Get()'ing maps When fetching a configuration sub-tree (either via Get() that results in a map, or through Sub()), ensure that any defaults for keys under that tree are propagated into the returned result. Fixes #747 Signed-off-by: setrofim --- overrides_test.go | 8 ++-- viper.go | 102 ++++++++++++++++++++++++++++++++-------------- viper_test.go | 17 ++++++++ 3 files changed, 92 insertions(+), 35 deletions(-) diff --git a/overrides_test.go b/overrides_test.go index 8048204..5266e5f 100644 --- a/overrides_test.go +++ b/overrides_test.go @@ -46,10 +46,10 @@ func TestNestedOverrides(t *testing.T) { deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4) // Case 4: key:value overridden by a map - v = overrideDefault(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10}) // "tom.size" is first given "4" as default value, then "tom" is overridden by map{"age":10} - assert.Equal(4, v.Get("tom.size")) // "tom.size" should still be reachable - assert.Equal(10, v.Get("tom.age")) // new value should be there - deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) // new value should be there + assert.Equal(map[string]interface{}{"size": 4, "age": 10}, v.Get("tom")) // "tom.size" is first given "4" as default value, then "tom" is set to map{"age":10}, size and age should be merged. + assert.Equal(4, v.Get("tom.size")) // "tom.size" should still be reachable + assert.Equal(10, v.Get("tom.age")) // new value should be there + deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) // new value should be there v = override(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10}) assert.Nil(v.Get("tom.size")) assert.Equal(10, v.Get("tom.age")) diff --git a/viper.go b/viper.go index 5f76cc0..f942c83 100644 --- a/viper.go +++ b/viper.go @@ -886,11 +886,14 @@ func Get(key string) interface{} { return v.Get(key) } func (v *Viper) Get(key string) interface{} { lcaseKey := strings.ToLower(key) - val := v.find(lcaseKey, true) - if val == nil { + + found := v.find(lcaseKey, true) + if len(found) == 0 { return nil } + val := v.mergeFoundMaps(lcaseKey, found) + if v.typeByDefValue { // TODO(bep) this branch isn't covered by a single test. valType := val @@ -1226,9 +1229,10 @@ func (v *Viper) MustBindEnv(input ...string) { // corresponds to a flag, the flag's default value is returned. // // Note: this assumes a lower-cased key given. -func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} { +func (v *Viper) find(lcaseKey string, flagDefault bool) []interface{} { var ( val interface{} + found []interface{} exists bool path = strings.Split(lcaseKey, v.keyDelim) nested = len(path) > 1 @@ -1247,10 +1251,10 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} { // Set() override first val = v.searchMap(v.override, path) if val != nil { - return val + found = append(found, val) } if nested && v.isPathShadowedInDeepMap(path, v.override) != "" { - return nil + return found } // PFlag override next @@ -1258,27 +1262,27 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} { if exists && flag.HasChanged() { switch flag.ValueType() { case "int", "int8", "int16", "int32", "int64": - return cast.ToInt(flag.ValueString()) + found = append(found, cast.ToInt(flag.ValueString())) case "bool": - return cast.ToBool(flag.ValueString()) + found = append(found, cast.ToBool(flag.ValueString())) case "stringSlice", "stringArray": s := strings.TrimPrefix(flag.ValueString(), "[") s = strings.TrimSuffix(s, "]") res, _ := readAsCSV(s) - return res + found = append(found, res) case "intSlice": s := strings.TrimPrefix(flag.ValueString(), "[") s = strings.TrimSuffix(s, "]") res, _ := readAsCSV(s) - return cast.ToIntSlice(res) + found = append(found, cast.ToIntSlice(res)) case "stringToString": - return stringToStringConv(flag.ValueString()) + found = append(found, stringToStringConv(flag.ValueString())) default: - return flag.ValueString() + found = append(found, flag.ValueString()) } } if nested && v.isPathShadowedInFlatMap(path, v.pflags) != "" { - return nil + return found } // Env override next @@ -1286,49 +1290,50 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} { // even if it hasn't been registered, if automaticEnv is used, // check any Get request if val, ok := v.getEnv(v.mergeWithEnvPrefix(lcaseKey)); ok { - return val + found = append(found, val) } if nested && v.isPathShadowedInAutoEnv(path) != "" { - return nil + return found } } envkeys, exists := v.env[lcaseKey] if exists { for _, envkey := range envkeys { if val, ok := v.getEnv(envkey); ok { - return val + found = append(found, val) + break } } } if nested && v.isPathShadowedInFlatMap(path, v.env) != "" { - return nil + return found } // Config file next val = v.searchIndexableWithPathPrefixes(v.config, path) if val != nil { - return val + found = append(found, val) } if nested && v.isPathShadowedInDeepMap(path, v.config) != "" { - return nil + return found } // K/V store next val = v.searchMap(v.kvstore, path) if val != nil { - return val + found = append(found, val) } if nested && v.isPathShadowedInDeepMap(path, v.kvstore) != "" { - return nil + return found } // Default next val = v.searchMap(v.defaults, path) if val != nil { - return val + found = append(found, val) } if nested && v.isPathShadowedInDeepMap(path, v.defaults) != "" { - return nil + return found } if flagDefault { @@ -1337,29 +1342,64 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} { if flag, exists := v.pflags[lcaseKey]; exists { switch flag.ValueType() { case "int", "int8", "int16", "int32", "int64": - return cast.ToInt(flag.ValueString()) + found = append(found, cast.ToInt(flag.ValueString())) case "bool": - return cast.ToBool(flag.ValueString()) + found = append(found, cast.ToBool(flag.ValueString())) case "stringSlice", "stringArray": s := strings.TrimPrefix(flag.ValueString(), "[") s = strings.TrimSuffix(s, "]") res, _ := readAsCSV(s) - return res + found = append(found, res) case "intSlice": s := strings.TrimPrefix(flag.ValueString(), "[") s = strings.TrimSuffix(s, "]") res, _ := readAsCSV(s) - return cast.ToIntSlice(res) + found = append(found, cast.ToIntSlice(res)) case "stringToString": - return stringToStringConv(flag.ValueString()) + found = append(found, stringToStringConv(flag.ValueString())) default: - return flag.ValueString() + found = append(found, flag.ValueString()) } } // last item, no need to check shadowing } - return nil + return found +} + +// found is a slice of values found across layers for lcaseKey in priority +// order (i.e. override, flag, env, config file, key/value store, default). If +// the highest priority value found is a map, traverse back across found +// values, and, while they are "mappable", merged them into the overall result. +func (v *Viper) mergeFoundMaps(lcaseKey string, found []interface{}) interface{} { + if len(found) == 0 { + return nil + } + + var foundMaps []map[string]interface{} + for _, fv := range found { + fm, err := cast.ToStringMapE(fv) + if err != nil { + // A non-map value found. This shadows everything else + // further down the list. + break + } + foundMaps = append(foundMaps, fm) + } + + // No non-shadowed maps found. Return the highest priority value. + if len(foundMaps) == 0 { + return found[0] + } + + mergedMap := map[string]interface{}{} + + // merge in reversed order (so that higher priority sources overwrite lower priority ones) + for i := len(foundMaps) - 1; i >= 0; i = i - 1 { + mergeMaps(foundMaps[i], mergedMap, nil) + } + + return mergedMap } func readAsCSV(val string) ([]string, error) { @@ -1401,8 +1441,8 @@ func IsSet(key string) bool { return v.IsSet(key) } func (v *Viper) IsSet(key string) bool { lcaseKey := strings.ToLower(key) - val := v.find(lcaseKey, false) - return val != nil + vals := v.find(lcaseKey, false) + return len(vals) != 0 } // AutomaticEnv makes Viper check if environment variables match any of the existing keys diff --git a/viper_test.go b/viper_test.go index 926ffc2..57be581 100644 --- a/viper_test.go +++ b/viper_test.go @@ -2569,6 +2569,23 @@ func TestSliceIndexAccess(t *testing.T) { assert.Equal(t, "Static", v.GetString("tv.0.episodes.1.2")) } +func TestMapWithDefaults(t *testing.T) { + Set("config.value1", 1) + SetDefault("config.value2.internal", 3) + + m, ok := Get("config").(map[string]interface{}) + require.True(t, ok) + assert.Equal(t, map[string]interface{}{"internal": 3}, m["value2"]) +} + +func TestSubWithDefaults(t *testing.T) { + Set("config.value1", 1) + SetDefault("config.value2.internal", 3) + + sub := Sub("config") + assert.Equal(t, 3, sub.Get("value2.internal")) +} + func BenchmarkGetBool(b *testing.B) { key := "BenchmarkGetBool" v = New()