Merge defaults when Get()'ing maps

When fetching a configuration sub-tree (either via Get() that results
in a map, or through Sub()), ensure that any defaults for keys under
that tree are propagated into the returned result.

Fixes #747

Signed-off-by: setrofim <setrofim@gmail.com>
This commit is contained in:
setrofim 2022-09-28 18:04:52 +01:00
parent f1d2c470bf
commit 1c3e461869
3 changed files with 92 additions and 35 deletions

View file

@ -46,7 +46,7 @@ func TestNestedOverrides(t *testing.T) {
deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4) deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4)
// Case 4: key:value overridden by a map // Case 4: key:value overridden by a map
v = overrideDefault(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10}) // "tom.size" is first given "4" as default value, then "tom" is overridden by map{"age":10} assert.Equal(map[string]interface{}{"size": 4, "age": 10}, v.Get("tom")) // "tom.size" is first given "4" as default value, then "tom" is set to map{"age":10}, size and age should be merged.
assert.Equal(4, v.Get("tom.size")) // "tom.size" should still be reachable assert.Equal(4, v.Get("tom.size")) // "tom.size" should still be reachable
assert.Equal(10, v.Get("tom.age")) // new value should be there assert.Equal(10, v.Get("tom.age")) // new value should be there
deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) // new value should be there deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) // new value should be there

100
viper.go
View file

@ -886,11 +886,14 @@ func Get(key string) interface{} { return v.Get(key) }
func (v *Viper) Get(key string) interface{} { func (v *Viper) Get(key string) interface{} {
lcaseKey := strings.ToLower(key) lcaseKey := strings.ToLower(key)
val := v.find(lcaseKey, true)
if val == nil { found := v.find(lcaseKey, true)
if len(found) == 0 {
return nil return nil
} }
val := v.mergeFoundMaps(lcaseKey, found)
if v.typeByDefValue { if v.typeByDefValue {
// TODO(bep) this branch isn't covered by a single test. // TODO(bep) this branch isn't covered by a single test.
valType := val valType := val
@ -1226,9 +1229,10 @@ func (v *Viper) MustBindEnv(input ...string) {
// corresponds to a flag, the flag's default value is returned. // corresponds to a flag, the flag's default value is returned.
// //
// Note: this assumes a lower-cased key given. // Note: this assumes a lower-cased key given.
func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} { func (v *Viper) find(lcaseKey string, flagDefault bool) []interface{} {
var ( var (
val interface{} val interface{}
found []interface{}
exists bool exists bool
path = strings.Split(lcaseKey, v.keyDelim) path = strings.Split(lcaseKey, v.keyDelim)
nested = len(path) > 1 nested = len(path) > 1
@ -1247,10 +1251,10 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
// Set() override first // Set() override first
val = v.searchMap(v.override, path) val = v.searchMap(v.override, path)
if val != nil { if val != nil {
return val found = append(found, val)
} }
if nested && v.isPathShadowedInDeepMap(path, v.override) != "" { if nested && v.isPathShadowedInDeepMap(path, v.override) != "" {
return nil return found
} }
// PFlag override next // PFlag override next
@ -1258,27 +1262,27 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
if exists && flag.HasChanged() { if exists && flag.HasChanged() {
switch flag.ValueType() { switch flag.ValueType() {
case "int", "int8", "int16", "int32", "int64": case "int", "int8", "int16", "int32", "int64":
return cast.ToInt(flag.ValueString()) found = append(found, cast.ToInt(flag.ValueString()))
case "bool": case "bool":
return cast.ToBool(flag.ValueString()) found = append(found, cast.ToBool(flag.ValueString()))
case "stringSlice", "stringArray": case "stringSlice", "stringArray":
s := strings.TrimPrefix(flag.ValueString(), "[") s := strings.TrimPrefix(flag.ValueString(), "[")
s = strings.TrimSuffix(s, "]") s = strings.TrimSuffix(s, "]")
res, _ := readAsCSV(s) res, _ := readAsCSV(s)
return res found = append(found, res)
case "intSlice": case "intSlice":
s := strings.TrimPrefix(flag.ValueString(), "[") s := strings.TrimPrefix(flag.ValueString(), "[")
s = strings.TrimSuffix(s, "]") s = strings.TrimSuffix(s, "]")
res, _ := readAsCSV(s) res, _ := readAsCSV(s)
return cast.ToIntSlice(res) found = append(found, cast.ToIntSlice(res))
case "stringToString": case "stringToString":
return stringToStringConv(flag.ValueString()) found = append(found, stringToStringConv(flag.ValueString()))
default: default:
return flag.ValueString() found = append(found, flag.ValueString())
} }
} }
if nested && v.isPathShadowedInFlatMap(path, v.pflags) != "" { if nested && v.isPathShadowedInFlatMap(path, v.pflags) != "" {
return nil return found
} }
// Env override next // Env override next
@ -1286,49 +1290,50 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
// even if it hasn't been registered, if automaticEnv is used, // even if it hasn't been registered, if automaticEnv is used,
// check any Get request // check any Get request
if val, ok := v.getEnv(v.mergeWithEnvPrefix(lcaseKey)); ok { if val, ok := v.getEnv(v.mergeWithEnvPrefix(lcaseKey)); ok {
return val found = append(found, val)
} }
if nested && v.isPathShadowedInAutoEnv(path) != "" { if nested && v.isPathShadowedInAutoEnv(path) != "" {
return nil return found
} }
} }
envkeys, exists := v.env[lcaseKey] envkeys, exists := v.env[lcaseKey]
if exists { if exists {
for _, envkey := range envkeys { for _, envkey := range envkeys {
if val, ok := v.getEnv(envkey); ok { if val, ok := v.getEnv(envkey); ok {
return val found = append(found, val)
break
} }
} }
} }
if nested && v.isPathShadowedInFlatMap(path, v.env) != "" { if nested && v.isPathShadowedInFlatMap(path, v.env) != "" {
return nil return found
} }
// Config file next // Config file next
val = v.searchIndexableWithPathPrefixes(v.config, path) val = v.searchIndexableWithPathPrefixes(v.config, path)
if val != nil { if val != nil {
return val found = append(found, val)
} }
if nested && v.isPathShadowedInDeepMap(path, v.config) != "" { if nested && v.isPathShadowedInDeepMap(path, v.config) != "" {
return nil return found
} }
// K/V store next // K/V store next
val = v.searchMap(v.kvstore, path) val = v.searchMap(v.kvstore, path)
if val != nil { if val != nil {
return val found = append(found, val)
} }
if nested && v.isPathShadowedInDeepMap(path, v.kvstore) != "" { if nested && v.isPathShadowedInDeepMap(path, v.kvstore) != "" {
return nil return found
} }
// Default next // Default next
val = v.searchMap(v.defaults, path) val = v.searchMap(v.defaults, path)
if val != nil { if val != nil {
return val found = append(found, val)
} }
if nested && v.isPathShadowedInDeepMap(path, v.defaults) != "" { if nested && v.isPathShadowedInDeepMap(path, v.defaults) != "" {
return nil return found
} }
if flagDefault { if flagDefault {
@ -1337,31 +1342,66 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
if flag, exists := v.pflags[lcaseKey]; exists { if flag, exists := v.pflags[lcaseKey]; exists {
switch flag.ValueType() { switch flag.ValueType() {
case "int", "int8", "int16", "int32", "int64": case "int", "int8", "int16", "int32", "int64":
return cast.ToInt(flag.ValueString()) found = append(found, cast.ToInt(flag.ValueString()))
case "bool": case "bool":
return cast.ToBool(flag.ValueString()) found = append(found, cast.ToBool(flag.ValueString()))
case "stringSlice", "stringArray": case "stringSlice", "stringArray":
s := strings.TrimPrefix(flag.ValueString(), "[") s := strings.TrimPrefix(flag.ValueString(), "[")
s = strings.TrimSuffix(s, "]") s = strings.TrimSuffix(s, "]")
res, _ := readAsCSV(s) res, _ := readAsCSV(s)
return res found = append(found, res)
case "intSlice": case "intSlice":
s := strings.TrimPrefix(flag.ValueString(), "[") s := strings.TrimPrefix(flag.ValueString(), "[")
s = strings.TrimSuffix(s, "]") s = strings.TrimSuffix(s, "]")
res, _ := readAsCSV(s) res, _ := readAsCSV(s)
return cast.ToIntSlice(res) found = append(found, cast.ToIntSlice(res))
case "stringToString": case "stringToString":
return stringToStringConv(flag.ValueString()) found = append(found, stringToStringConv(flag.ValueString()))
default: default:
return flag.ValueString() found = append(found, flag.ValueString())
} }
} }
// last item, no need to check shadowing // last item, no need to check shadowing
} }
return found
}
// found is a slice of values found across layers for lcaseKey in priority
// order (i.e. override, flag, env, config file, key/value store, default). If
// the highest priority value found is a map, traverse back across found
// values, and, while they are "mappable", merged them into the overall result.
func (v *Viper) mergeFoundMaps(lcaseKey string, found []interface{}) interface{} {
if len(found) == 0 {
return nil return nil
} }
var foundMaps []map[string]interface{}
for _, fv := range found {
fm, err := cast.ToStringMapE(fv)
if err != nil {
// A non-map value found. This shadows everything else
// further down the list.
break
}
foundMaps = append(foundMaps, fm)
}
// No non-shadowed maps found. Return the highest priority value.
if len(foundMaps) == 0 {
return found[0]
}
mergedMap := map[string]interface{}{}
// merge in reversed order (so that higher priority sources overwrite lower priority ones)
for i := len(foundMaps) - 1; i >= 0; i = i - 1 {
mergeMaps(foundMaps[i], mergedMap, nil)
}
return mergedMap
}
func readAsCSV(val string) ([]string, error) { func readAsCSV(val string) ([]string, error) {
if val == "" { if val == "" {
return []string{}, nil return []string{}, nil
@ -1401,8 +1441,8 @@ func IsSet(key string) bool { return v.IsSet(key) }
func (v *Viper) IsSet(key string) bool { func (v *Viper) IsSet(key string) bool {
lcaseKey := strings.ToLower(key) lcaseKey := strings.ToLower(key)
val := v.find(lcaseKey, false) vals := v.find(lcaseKey, false)
return val != nil return len(vals) != 0
} }
// AutomaticEnv makes Viper check if environment variables match any of the existing keys // AutomaticEnv makes Viper check if environment variables match any of the existing keys

View file

@ -2569,6 +2569,23 @@ func TestSliceIndexAccess(t *testing.T) {
assert.Equal(t, "Static", v.GetString("tv.0.episodes.1.2")) assert.Equal(t, "Static", v.GetString("tv.0.episodes.1.2"))
} }
func TestMapWithDefaults(t *testing.T) {
Set("config.value1", 1)
SetDefault("config.value2.internal", 3)
m, ok := Get("config").(map[string]interface{})
require.True(t, ok)
assert.Equal(t, map[string]interface{}{"internal": 3}, m["value2"])
}
func TestSubWithDefaults(t *testing.T) {
Set("config.value1", 1)
SetDefault("config.value2.internal", 3)
sub := Sub("config")
assert.Equal(t, 3, sub.Get("value2.internal"))
}
func BenchmarkGetBool(b *testing.B) { func BenchmarkGetBool(b *testing.B) {
key := "BenchmarkGetBool" key := "BenchmarkGetBool"
v = New() v = New()