mirror of
https://github.com/spf13/viper
synced 2025-01-07 11:16:37 +00:00
Merge defaults when Get()'ing maps
When fetching a configuration sub-tree (either via Get() that results in a map, or through Sub()), ensure that any defaults for keys under that tree are propagated into the returned result. Fixes #747 Signed-off-by: setrofim <setrofim@gmail.com>
This commit is contained in:
parent
f1d2c470bf
commit
1c3e461869
3 changed files with 92 additions and 35 deletions
|
@ -46,10 +46,10 @@ func TestNestedOverrides(t *testing.T) {
|
|||
deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4)
|
||||
|
||||
// Case 4: key:value overridden by a map
|
||||
v = overrideDefault(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10}) // "tom.size" is first given "4" as default value, then "tom" is overridden by map{"age":10}
|
||||
assert.Equal(4, v.Get("tom.size")) // "tom.size" should still be reachable
|
||||
assert.Equal(10, v.Get("tom.age")) // new value should be there
|
||||
deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) // new value should be there
|
||||
assert.Equal(map[string]interface{}{"size": 4, "age": 10}, v.Get("tom")) // "tom.size" is first given "4" as default value, then "tom" is set to map{"age":10}, size and age should be merged.
|
||||
assert.Equal(4, v.Get("tom.size")) // "tom.size" should still be reachable
|
||||
assert.Equal(10, v.Get("tom.age")) // new value should be there
|
||||
deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) // new value should be there
|
||||
v = override(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10})
|
||||
assert.Nil(v.Get("tom.size"))
|
||||
assert.Equal(10, v.Get("tom.age"))
|
||||
|
|
102
viper.go
102
viper.go
|
@ -886,11 +886,14 @@ func Get(key string) interface{} { return v.Get(key) }
|
|||
|
||||
func (v *Viper) Get(key string) interface{} {
|
||||
lcaseKey := strings.ToLower(key)
|
||||
val := v.find(lcaseKey, true)
|
||||
if val == nil {
|
||||
|
||||
found := v.find(lcaseKey, true)
|
||||
if len(found) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
val := v.mergeFoundMaps(lcaseKey, found)
|
||||
|
||||
if v.typeByDefValue {
|
||||
// TODO(bep) this branch isn't covered by a single test.
|
||||
valType := val
|
||||
|
@ -1226,9 +1229,10 @@ func (v *Viper) MustBindEnv(input ...string) {
|
|||
// corresponds to a flag, the flag's default value is returned.
|
||||
//
|
||||
// Note: this assumes a lower-cased key given.
|
||||
func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
|
||||
func (v *Viper) find(lcaseKey string, flagDefault bool) []interface{} {
|
||||
var (
|
||||
val interface{}
|
||||
found []interface{}
|
||||
exists bool
|
||||
path = strings.Split(lcaseKey, v.keyDelim)
|
||||
nested = len(path) > 1
|
||||
|
@ -1247,10 +1251,10 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
|
|||
// Set() override first
|
||||
val = v.searchMap(v.override, path)
|
||||
if val != nil {
|
||||
return val
|
||||
found = append(found, val)
|
||||
}
|
||||
if nested && v.isPathShadowedInDeepMap(path, v.override) != "" {
|
||||
return nil
|
||||
return found
|
||||
}
|
||||
|
||||
// PFlag override next
|
||||
|
@ -1258,27 +1262,27 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
|
|||
if exists && flag.HasChanged() {
|
||||
switch flag.ValueType() {
|
||||
case "int", "int8", "int16", "int32", "int64":
|
||||
return cast.ToInt(flag.ValueString())
|
||||
found = append(found, cast.ToInt(flag.ValueString()))
|
||||
case "bool":
|
||||
return cast.ToBool(flag.ValueString())
|
||||
found = append(found, cast.ToBool(flag.ValueString()))
|
||||
case "stringSlice", "stringArray":
|
||||
s := strings.TrimPrefix(flag.ValueString(), "[")
|
||||
s = strings.TrimSuffix(s, "]")
|
||||
res, _ := readAsCSV(s)
|
||||
return res
|
||||
found = append(found, res)
|
||||
case "intSlice":
|
||||
s := strings.TrimPrefix(flag.ValueString(), "[")
|
||||
s = strings.TrimSuffix(s, "]")
|
||||
res, _ := readAsCSV(s)
|
||||
return cast.ToIntSlice(res)
|
||||
found = append(found, cast.ToIntSlice(res))
|
||||
case "stringToString":
|
||||
return stringToStringConv(flag.ValueString())
|
||||
found = append(found, stringToStringConv(flag.ValueString()))
|
||||
default:
|
||||
return flag.ValueString()
|
||||
found = append(found, flag.ValueString())
|
||||
}
|
||||
}
|
||||
if nested && v.isPathShadowedInFlatMap(path, v.pflags) != "" {
|
||||
return nil
|
||||
return found
|
||||
}
|
||||
|
||||
// Env override next
|
||||
|
@ -1286,49 +1290,50 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
|
|||
// even if it hasn't been registered, if automaticEnv is used,
|
||||
// check any Get request
|
||||
if val, ok := v.getEnv(v.mergeWithEnvPrefix(lcaseKey)); ok {
|
||||
return val
|
||||
found = append(found, val)
|
||||
}
|
||||
if nested && v.isPathShadowedInAutoEnv(path) != "" {
|
||||
return nil
|
||||
return found
|
||||
}
|
||||
}
|
||||
envkeys, exists := v.env[lcaseKey]
|
||||
if exists {
|
||||
for _, envkey := range envkeys {
|
||||
if val, ok := v.getEnv(envkey); ok {
|
||||
return val
|
||||
found = append(found, val)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if nested && v.isPathShadowedInFlatMap(path, v.env) != "" {
|
||||
return nil
|
||||
return found
|
||||
}
|
||||
|
||||
// Config file next
|
||||
val = v.searchIndexableWithPathPrefixes(v.config, path)
|
||||
if val != nil {
|
||||
return val
|
||||
found = append(found, val)
|
||||
}
|
||||
if nested && v.isPathShadowedInDeepMap(path, v.config) != "" {
|
||||
return nil
|
||||
return found
|
||||
}
|
||||
|
||||
// K/V store next
|
||||
val = v.searchMap(v.kvstore, path)
|
||||
if val != nil {
|
||||
return val
|
||||
found = append(found, val)
|
||||
}
|
||||
if nested && v.isPathShadowedInDeepMap(path, v.kvstore) != "" {
|
||||
return nil
|
||||
return found
|
||||
}
|
||||
|
||||
// Default next
|
||||
val = v.searchMap(v.defaults, path)
|
||||
if val != nil {
|
||||
return val
|
||||
found = append(found, val)
|
||||
}
|
||||
if nested && v.isPathShadowedInDeepMap(path, v.defaults) != "" {
|
||||
return nil
|
||||
return found
|
||||
}
|
||||
|
||||
if flagDefault {
|
||||
|
@ -1337,29 +1342,64 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
|
|||
if flag, exists := v.pflags[lcaseKey]; exists {
|
||||
switch flag.ValueType() {
|
||||
case "int", "int8", "int16", "int32", "int64":
|
||||
return cast.ToInt(flag.ValueString())
|
||||
found = append(found, cast.ToInt(flag.ValueString()))
|
||||
case "bool":
|
||||
return cast.ToBool(flag.ValueString())
|
||||
found = append(found, cast.ToBool(flag.ValueString()))
|
||||
case "stringSlice", "stringArray":
|
||||
s := strings.TrimPrefix(flag.ValueString(), "[")
|
||||
s = strings.TrimSuffix(s, "]")
|
||||
res, _ := readAsCSV(s)
|
||||
return res
|
||||
found = append(found, res)
|
||||
case "intSlice":
|
||||
s := strings.TrimPrefix(flag.ValueString(), "[")
|
||||
s = strings.TrimSuffix(s, "]")
|
||||
res, _ := readAsCSV(s)
|
||||
return cast.ToIntSlice(res)
|
||||
found = append(found, cast.ToIntSlice(res))
|
||||
case "stringToString":
|
||||
return stringToStringConv(flag.ValueString())
|
||||
found = append(found, stringToStringConv(flag.ValueString()))
|
||||
default:
|
||||
return flag.ValueString()
|
||||
found = append(found, flag.ValueString())
|
||||
}
|
||||
}
|
||||
// last item, no need to check shadowing
|
||||
}
|
||||
|
||||
return nil
|
||||
return found
|
||||
}
|
||||
|
||||
// found is a slice of values found across layers for lcaseKey in priority
|
||||
// order (i.e. override, flag, env, config file, key/value store, default). If
|
||||
// the highest priority value found is a map, traverse back across found
|
||||
// values, and, while they are "mappable", merged them into the overall result.
|
||||
func (v *Viper) mergeFoundMaps(lcaseKey string, found []interface{}) interface{} {
|
||||
if len(found) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var foundMaps []map[string]interface{}
|
||||
for _, fv := range found {
|
||||
fm, err := cast.ToStringMapE(fv)
|
||||
if err != nil {
|
||||
// A non-map value found. This shadows everything else
|
||||
// further down the list.
|
||||
break
|
||||
}
|
||||
foundMaps = append(foundMaps, fm)
|
||||
}
|
||||
|
||||
// No non-shadowed maps found. Return the highest priority value.
|
||||
if len(foundMaps) == 0 {
|
||||
return found[0]
|
||||
}
|
||||
|
||||
mergedMap := map[string]interface{}{}
|
||||
|
||||
// merge in reversed order (so that higher priority sources overwrite lower priority ones)
|
||||
for i := len(foundMaps) - 1; i >= 0; i = i - 1 {
|
||||
mergeMaps(foundMaps[i], mergedMap, nil)
|
||||
}
|
||||
|
||||
return mergedMap
|
||||
}
|
||||
|
||||
func readAsCSV(val string) ([]string, error) {
|
||||
|
@ -1401,8 +1441,8 @@ func IsSet(key string) bool { return v.IsSet(key) }
|
|||
|
||||
func (v *Viper) IsSet(key string) bool {
|
||||
lcaseKey := strings.ToLower(key)
|
||||
val := v.find(lcaseKey, false)
|
||||
return val != nil
|
||||
vals := v.find(lcaseKey, false)
|
||||
return len(vals) != 0
|
||||
}
|
||||
|
||||
// AutomaticEnv makes Viper check if environment variables match any of the existing keys
|
||||
|
|
|
@ -2569,6 +2569,23 @@ func TestSliceIndexAccess(t *testing.T) {
|
|||
assert.Equal(t, "Static", v.GetString("tv.0.episodes.1.2"))
|
||||
}
|
||||
|
||||
func TestMapWithDefaults(t *testing.T) {
|
||||
Set("config.value1", 1)
|
||||
SetDefault("config.value2.internal", 3)
|
||||
|
||||
m, ok := Get("config").(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, map[string]interface{}{"internal": 3}, m["value2"])
|
||||
}
|
||||
|
||||
func TestSubWithDefaults(t *testing.T) {
|
||||
Set("config.value1", 1)
|
||||
SetDefault("config.value2.internal", 3)
|
||||
|
||||
sub := Sub("config")
|
||||
assert.Equal(t, 3, sub.Get("value2.internal"))
|
||||
}
|
||||
|
||||
func BenchmarkGetBool(b *testing.B) {
|
||||
key := "BenchmarkGetBool"
|
||||
v = New()
|
||||
|
|
Loading…
Reference in a new issue