Merge defaults when Get()'ing maps

When fetching a configuration sub-tree (either via Get() that results
in a map, or through Sub()), ensure that any defaults for keys under
that tree are propagated into the returned result.

Fixes #747

Signed-off-by: setrofim <setrofim@gmail.com>
This commit is contained in:
setrofim 2022-09-28 18:04:52 +01:00
parent f1d2c470bf
commit 1c3e461869
3 changed files with 92 additions and 35 deletions

View file

@ -46,7 +46,7 @@ func TestNestedOverrides(t *testing.T) {
deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4)
// Case 4: key:value overridden by a map
v = overrideDefault(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10}) // "tom.size" is first given "4" as default value, then "tom" is overridden by map{"age":10}
assert.Equal(map[string]interface{}{"size": 4, "age": 10}, v.Get("tom")) // "tom.size" is first given "4" as default value, then "tom" is set to map{"age":10}, size and age should be merged.
assert.Equal(4, v.Get("tom.size")) // "tom.size" should still be reachable
assert.Equal(10, v.Get("tom.age")) // new value should be there
deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) // new value should be there

100
viper.go
View file

@ -886,11 +886,14 @@ func Get(key string) interface{} { return v.Get(key) }
func (v *Viper) Get(key string) interface{} {
lcaseKey := strings.ToLower(key)
val := v.find(lcaseKey, true)
if val == nil {
found := v.find(lcaseKey, true)
if len(found) == 0 {
return nil
}
val := v.mergeFoundMaps(lcaseKey, found)
if v.typeByDefValue {
// TODO(bep) this branch isn't covered by a single test.
valType := val
@ -1226,9 +1229,10 @@ func (v *Viper) MustBindEnv(input ...string) {
// corresponds to a flag, the flag's default value is returned.
//
// Note: this assumes a lower-cased key given.
func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
func (v *Viper) find(lcaseKey string, flagDefault bool) []interface{} {
var (
val interface{}
found []interface{}
exists bool
path = strings.Split(lcaseKey, v.keyDelim)
nested = len(path) > 1
@ -1247,10 +1251,10 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
// Set() override first
val = v.searchMap(v.override, path)
if val != nil {
return val
found = append(found, val)
}
if nested && v.isPathShadowedInDeepMap(path, v.override) != "" {
return nil
return found
}
// PFlag override next
@ -1258,27 +1262,27 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
if exists && flag.HasChanged() {
switch flag.ValueType() {
case "int", "int8", "int16", "int32", "int64":
return cast.ToInt(flag.ValueString())
found = append(found, cast.ToInt(flag.ValueString()))
case "bool":
return cast.ToBool(flag.ValueString())
found = append(found, cast.ToBool(flag.ValueString()))
case "stringSlice", "stringArray":
s := strings.TrimPrefix(flag.ValueString(), "[")
s = strings.TrimSuffix(s, "]")
res, _ := readAsCSV(s)
return res
found = append(found, res)
case "intSlice":
s := strings.TrimPrefix(flag.ValueString(), "[")
s = strings.TrimSuffix(s, "]")
res, _ := readAsCSV(s)
return cast.ToIntSlice(res)
found = append(found, cast.ToIntSlice(res))
case "stringToString":
return stringToStringConv(flag.ValueString())
found = append(found, stringToStringConv(flag.ValueString()))
default:
return flag.ValueString()
found = append(found, flag.ValueString())
}
}
if nested && v.isPathShadowedInFlatMap(path, v.pflags) != "" {
return nil
return found
}
// Env override next
@ -1286,49 +1290,50 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
// even if it hasn't been registered, if automaticEnv is used,
// check any Get request
if val, ok := v.getEnv(v.mergeWithEnvPrefix(lcaseKey)); ok {
return val
found = append(found, val)
}
if nested && v.isPathShadowedInAutoEnv(path) != "" {
return nil
return found
}
}
envkeys, exists := v.env[lcaseKey]
if exists {
for _, envkey := range envkeys {
if val, ok := v.getEnv(envkey); ok {
return val
found = append(found, val)
break
}
}
}
if nested && v.isPathShadowedInFlatMap(path, v.env) != "" {
return nil
return found
}
// Config file next
val = v.searchIndexableWithPathPrefixes(v.config, path)
if val != nil {
return val
found = append(found, val)
}
if nested && v.isPathShadowedInDeepMap(path, v.config) != "" {
return nil
return found
}
// K/V store next
val = v.searchMap(v.kvstore, path)
if val != nil {
return val
found = append(found, val)
}
if nested && v.isPathShadowedInDeepMap(path, v.kvstore) != "" {
return nil
return found
}
// Default next
val = v.searchMap(v.defaults, path)
if val != nil {
return val
found = append(found, val)
}
if nested && v.isPathShadowedInDeepMap(path, v.defaults) != "" {
return nil
return found
}
if flagDefault {
@ -1337,29 +1342,64 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
if flag, exists := v.pflags[lcaseKey]; exists {
switch flag.ValueType() {
case "int", "int8", "int16", "int32", "int64":
return cast.ToInt(flag.ValueString())
found = append(found, cast.ToInt(flag.ValueString()))
case "bool":
return cast.ToBool(flag.ValueString())
found = append(found, cast.ToBool(flag.ValueString()))
case "stringSlice", "stringArray":
s := strings.TrimPrefix(flag.ValueString(), "[")
s = strings.TrimSuffix(s, "]")
res, _ := readAsCSV(s)
return res
found = append(found, res)
case "intSlice":
s := strings.TrimPrefix(flag.ValueString(), "[")
s = strings.TrimSuffix(s, "]")
res, _ := readAsCSV(s)
return cast.ToIntSlice(res)
found = append(found, cast.ToIntSlice(res))
case "stringToString":
return stringToStringConv(flag.ValueString())
found = append(found, stringToStringConv(flag.ValueString()))
default:
return flag.ValueString()
found = append(found, flag.ValueString())
}
}
// last item, no need to check shadowing
}
return found
}
// found is a slice of values found across layers for lcaseKey in priority
// order (i.e. override, flag, env, config file, key/value store, default). If
// the highest priority value found is a map, traverse back across found
// values, and, while they are "mappable", merged them into the overall result.
func (v *Viper) mergeFoundMaps(lcaseKey string, found []interface{}) interface{} {
if len(found) == 0 {
return nil
}
var foundMaps []map[string]interface{}
for _, fv := range found {
fm, err := cast.ToStringMapE(fv)
if err != nil {
// A non-map value found. This shadows everything else
// further down the list.
break
}
foundMaps = append(foundMaps, fm)
}
// No non-shadowed maps found. Return the highest priority value.
if len(foundMaps) == 0 {
return found[0]
}
mergedMap := map[string]interface{}{}
// merge in reversed order (so that higher priority sources overwrite lower priority ones)
for i := len(foundMaps) - 1; i >= 0; i = i - 1 {
mergeMaps(foundMaps[i], mergedMap, nil)
}
return mergedMap
}
func readAsCSV(val string) ([]string, error) {
@ -1401,8 +1441,8 @@ func IsSet(key string) bool { return v.IsSet(key) }
func (v *Viper) IsSet(key string) bool {
lcaseKey := strings.ToLower(key)
val := v.find(lcaseKey, false)
return val != nil
vals := v.find(lcaseKey, false)
return len(vals) != 0
}
// AutomaticEnv makes Viper check if environment variables match any of the existing keys

View file

@ -2569,6 +2569,23 @@ func TestSliceIndexAccess(t *testing.T) {
assert.Equal(t, "Static", v.GetString("tv.0.episodes.1.2"))
}
func TestMapWithDefaults(t *testing.T) {
Set("config.value1", 1)
SetDefault("config.value2.internal", 3)
m, ok := Get("config").(map[string]interface{})
require.True(t, ok)
assert.Equal(t, map[string]interface{}{"internal": 3}, m["value2"])
}
func TestSubWithDefaults(t *testing.T) {
Set("config.value1", 1)
SetDefault("config.value2.internal", 3)
sub := Sub("config")
assert.Equal(t, 3, sub.Get("value2.internal"))
}
func BenchmarkGetBool(b *testing.B) {
key := "BenchmarkGetBool"
v = New()