mirror of
https://github.com/spf13/viper
synced 2025-01-23 19:06:38 +00:00
Merge defaults when Get()'ing maps
When fetching a configuration sub-tree (either via Get() that results in a map, or through Sub()), ensure that any defaults for keys under that tree are propagated into the returned result. Fixes #747 Signed-off-by: setrofim <setrofim@gmail.com>
This commit is contained in:
parent
f1d2c470bf
commit
1c3e461869
3 changed files with 92 additions and 35 deletions
|
@ -46,10 +46,10 @@ func TestNestedOverrides(t *testing.T) {
|
||||||
deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4)
|
deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4)
|
||||||
|
|
||||||
// Case 4: key:value overridden by a map
|
// Case 4: key:value overridden by a map
|
||||||
v = overrideDefault(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10}) // "tom.size" is first given "4" as default value, then "tom" is overridden by map{"age":10}
|
assert.Equal(map[string]interface{}{"size": 4, "age": 10}, v.Get("tom")) // "tom.size" is first given "4" as default value, then "tom" is set to map{"age":10}, size and age should be merged.
|
||||||
assert.Equal(4, v.Get("tom.size")) // "tom.size" should still be reachable
|
assert.Equal(4, v.Get("tom.size")) // "tom.size" should still be reachable
|
||||||
assert.Equal(10, v.Get("tom.age")) // new value should be there
|
assert.Equal(10, v.Get("tom.age")) // new value should be there
|
||||||
deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) // new value should be there
|
deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) // new value should be there
|
||||||
v = override(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10})
|
v = override(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10})
|
||||||
assert.Nil(v.Get("tom.size"))
|
assert.Nil(v.Get("tom.size"))
|
||||||
assert.Equal(10, v.Get("tom.age"))
|
assert.Equal(10, v.Get("tom.age"))
|
||||||
|
|
102
viper.go
102
viper.go
|
@ -886,11 +886,14 @@ func Get(key string) interface{} { return v.Get(key) }
|
||||||
|
|
||||||
func (v *Viper) Get(key string) interface{} {
|
func (v *Viper) Get(key string) interface{} {
|
||||||
lcaseKey := strings.ToLower(key)
|
lcaseKey := strings.ToLower(key)
|
||||||
val := v.find(lcaseKey, true)
|
|
||||||
if val == nil {
|
found := v.find(lcaseKey, true)
|
||||||
|
if len(found) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val := v.mergeFoundMaps(lcaseKey, found)
|
||||||
|
|
||||||
if v.typeByDefValue {
|
if v.typeByDefValue {
|
||||||
// TODO(bep) this branch isn't covered by a single test.
|
// TODO(bep) this branch isn't covered by a single test.
|
||||||
valType := val
|
valType := val
|
||||||
|
@ -1226,9 +1229,10 @@ func (v *Viper) MustBindEnv(input ...string) {
|
||||||
// corresponds to a flag, the flag's default value is returned.
|
// corresponds to a flag, the flag's default value is returned.
|
||||||
//
|
//
|
||||||
// Note: this assumes a lower-cased key given.
|
// Note: this assumes a lower-cased key given.
|
||||||
func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
|
func (v *Viper) find(lcaseKey string, flagDefault bool) []interface{} {
|
||||||
var (
|
var (
|
||||||
val interface{}
|
val interface{}
|
||||||
|
found []interface{}
|
||||||
exists bool
|
exists bool
|
||||||
path = strings.Split(lcaseKey, v.keyDelim)
|
path = strings.Split(lcaseKey, v.keyDelim)
|
||||||
nested = len(path) > 1
|
nested = len(path) > 1
|
||||||
|
@ -1247,10 +1251,10 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
|
||||||
// Set() override first
|
// Set() override first
|
||||||
val = v.searchMap(v.override, path)
|
val = v.searchMap(v.override, path)
|
||||||
if val != nil {
|
if val != nil {
|
||||||
return val
|
found = append(found, val)
|
||||||
}
|
}
|
||||||
if nested && v.isPathShadowedInDeepMap(path, v.override) != "" {
|
if nested && v.isPathShadowedInDeepMap(path, v.override) != "" {
|
||||||
return nil
|
return found
|
||||||
}
|
}
|
||||||
|
|
||||||
// PFlag override next
|
// PFlag override next
|
||||||
|
@ -1258,27 +1262,27 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
|
||||||
if exists && flag.HasChanged() {
|
if exists && flag.HasChanged() {
|
||||||
switch flag.ValueType() {
|
switch flag.ValueType() {
|
||||||
case "int", "int8", "int16", "int32", "int64":
|
case "int", "int8", "int16", "int32", "int64":
|
||||||
return cast.ToInt(flag.ValueString())
|
found = append(found, cast.ToInt(flag.ValueString()))
|
||||||
case "bool":
|
case "bool":
|
||||||
return cast.ToBool(flag.ValueString())
|
found = append(found, cast.ToBool(flag.ValueString()))
|
||||||
case "stringSlice", "stringArray":
|
case "stringSlice", "stringArray":
|
||||||
s := strings.TrimPrefix(flag.ValueString(), "[")
|
s := strings.TrimPrefix(flag.ValueString(), "[")
|
||||||
s = strings.TrimSuffix(s, "]")
|
s = strings.TrimSuffix(s, "]")
|
||||||
res, _ := readAsCSV(s)
|
res, _ := readAsCSV(s)
|
||||||
return res
|
found = append(found, res)
|
||||||
case "intSlice":
|
case "intSlice":
|
||||||
s := strings.TrimPrefix(flag.ValueString(), "[")
|
s := strings.TrimPrefix(flag.ValueString(), "[")
|
||||||
s = strings.TrimSuffix(s, "]")
|
s = strings.TrimSuffix(s, "]")
|
||||||
res, _ := readAsCSV(s)
|
res, _ := readAsCSV(s)
|
||||||
return cast.ToIntSlice(res)
|
found = append(found, cast.ToIntSlice(res))
|
||||||
case "stringToString":
|
case "stringToString":
|
||||||
return stringToStringConv(flag.ValueString())
|
found = append(found, stringToStringConv(flag.ValueString()))
|
||||||
default:
|
default:
|
||||||
return flag.ValueString()
|
found = append(found, flag.ValueString())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if nested && v.isPathShadowedInFlatMap(path, v.pflags) != "" {
|
if nested && v.isPathShadowedInFlatMap(path, v.pflags) != "" {
|
||||||
return nil
|
return found
|
||||||
}
|
}
|
||||||
|
|
||||||
// Env override next
|
// Env override next
|
||||||
|
@ -1286,49 +1290,50 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
|
||||||
// even if it hasn't been registered, if automaticEnv is used,
|
// even if it hasn't been registered, if automaticEnv is used,
|
||||||
// check any Get request
|
// check any Get request
|
||||||
if val, ok := v.getEnv(v.mergeWithEnvPrefix(lcaseKey)); ok {
|
if val, ok := v.getEnv(v.mergeWithEnvPrefix(lcaseKey)); ok {
|
||||||
return val
|
found = append(found, val)
|
||||||
}
|
}
|
||||||
if nested && v.isPathShadowedInAutoEnv(path) != "" {
|
if nested && v.isPathShadowedInAutoEnv(path) != "" {
|
||||||
return nil
|
return found
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
envkeys, exists := v.env[lcaseKey]
|
envkeys, exists := v.env[lcaseKey]
|
||||||
if exists {
|
if exists {
|
||||||
for _, envkey := range envkeys {
|
for _, envkey := range envkeys {
|
||||||
if val, ok := v.getEnv(envkey); ok {
|
if val, ok := v.getEnv(envkey); ok {
|
||||||
return val
|
found = append(found, val)
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if nested && v.isPathShadowedInFlatMap(path, v.env) != "" {
|
if nested && v.isPathShadowedInFlatMap(path, v.env) != "" {
|
||||||
return nil
|
return found
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config file next
|
// Config file next
|
||||||
val = v.searchIndexableWithPathPrefixes(v.config, path)
|
val = v.searchIndexableWithPathPrefixes(v.config, path)
|
||||||
if val != nil {
|
if val != nil {
|
||||||
return val
|
found = append(found, val)
|
||||||
}
|
}
|
||||||
if nested && v.isPathShadowedInDeepMap(path, v.config) != "" {
|
if nested && v.isPathShadowedInDeepMap(path, v.config) != "" {
|
||||||
return nil
|
return found
|
||||||
}
|
}
|
||||||
|
|
||||||
// K/V store next
|
// K/V store next
|
||||||
val = v.searchMap(v.kvstore, path)
|
val = v.searchMap(v.kvstore, path)
|
||||||
if val != nil {
|
if val != nil {
|
||||||
return val
|
found = append(found, val)
|
||||||
}
|
}
|
||||||
if nested && v.isPathShadowedInDeepMap(path, v.kvstore) != "" {
|
if nested && v.isPathShadowedInDeepMap(path, v.kvstore) != "" {
|
||||||
return nil
|
return found
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default next
|
// Default next
|
||||||
val = v.searchMap(v.defaults, path)
|
val = v.searchMap(v.defaults, path)
|
||||||
if val != nil {
|
if val != nil {
|
||||||
return val
|
found = append(found, val)
|
||||||
}
|
}
|
||||||
if nested && v.isPathShadowedInDeepMap(path, v.defaults) != "" {
|
if nested && v.isPathShadowedInDeepMap(path, v.defaults) != "" {
|
||||||
return nil
|
return found
|
||||||
}
|
}
|
||||||
|
|
||||||
if flagDefault {
|
if flagDefault {
|
||||||
|
@ -1337,29 +1342,64 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
|
||||||
if flag, exists := v.pflags[lcaseKey]; exists {
|
if flag, exists := v.pflags[lcaseKey]; exists {
|
||||||
switch flag.ValueType() {
|
switch flag.ValueType() {
|
||||||
case "int", "int8", "int16", "int32", "int64":
|
case "int", "int8", "int16", "int32", "int64":
|
||||||
return cast.ToInt(flag.ValueString())
|
found = append(found, cast.ToInt(flag.ValueString()))
|
||||||
case "bool":
|
case "bool":
|
||||||
return cast.ToBool(flag.ValueString())
|
found = append(found, cast.ToBool(flag.ValueString()))
|
||||||
case "stringSlice", "stringArray":
|
case "stringSlice", "stringArray":
|
||||||
s := strings.TrimPrefix(flag.ValueString(), "[")
|
s := strings.TrimPrefix(flag.ValueString(), "[")
|
||||||
s = strings.TrimSuffix(s, "]")
|
s = strings.TrimSuffix(s, "]")
|
||||||
res, _ := readAsCSV(s)
|
res, _ := readAsCSV(s)
|
||||||
return res
|
found = append(found, res)
|
||||||
case "intSlice":
|
case "intSlice":
|
||||||
s := strings.TrimPrefix(flag.ValueString(), "[")
|
s := strings.TrimPrefix(flag.ValueString(), "[")
|
||||||
s = strings.TrimSuffix(s, "]")
|
s = strings.TrimSuffix(s, "]")
|
||||||
res, _ := readAsCSV(s)
|
res, _ := readAsCSV(s)
|
||||||
return cast.ToIntSlice(res)
|
found = append(found, cast.ToIntSlice(res))
|
||||||
case "stringToString":
|
case "stringToString":
|
||||||
return stringToStringConv(flag.ValueString())
|
found = append(found, stringToStringConv(flag.ValueString()))
|
||||||
default:
|
default:
|
||||||
return flag.ValueString()
|
found = append(found, flag.ValueString())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// last item, no need to check shadowing
|
// last item, no need to check shadowing
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return found
|
||||||
|
}
|
||||||
|
|
||||||
|
// found is a slice of values found across layers for lcaseKey in priority
|
||||||
|
// order (i.e. override, flag, env, config file, key/value store, default). If
|
||||||
|
// the highest priority value found is a map, traverse back across found
|
||||||
|
// values, and, while they are "mappable", merged them into the overall result.
|
||||||
|
func (v *Viper) mergeFoundMaps(lcaseKey string, found []interface{}) interface{} {
|
||||||
|
if len(found) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var foundMaps []map[string]interface{}
|
||||||
|
for _, fv := range found {
|
||||||
|
fm, err := cast.ToStringMapE(fv)
|
||||||
|
if err != nil {
|
||||||
|
// A non-map value found. This shadows everything else
|
||||||
|
// further down the list.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
foundMaps = append(foundMaps, fm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// No non-shadowed maps found. Return the highest priority value.
|
||||||
|
if len(foundMaps) == 0 {
|
||||||
|
return found[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
mergedMap := map[string]interface{}{}
|
||||||
|
|
||||||
|
// merge in reversed order (so that higher priority sources overwrite lower priority ones)
|
||||||
|
for i := len(foundMaps) - 1; i >= 0; i = i - 1 {
|
||||||
|
mergeMaps(foundMaps[i], mergedMap, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return mergedMap
|
||||||
}
|
}
|
||||||
|
|
||||||
func readAsCSV(val string) ([]string, error) {
|
func readAsCSV(val string) ([]string, error) {
|
||||||
|
@ -1401,8 +1441,8 @@ func IsSet(key string) bool { return v.IsSet(key) }
|
||||||
|
|
||||||
func (v *Viper) IsSet(key string) bool {
|
func (v *Viper) IsSet(key string) bool {
|
||||||
lcaseKey := strings.ToLower(key)
|
lcaseKey := strings.ToLower(key)
|
||||||
val := v.find(lcaseKey, false)
|
vals := v.find(lcaseKey, false)
|
||||||
return val != nil
|
return len(vals) != 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// AutomaticEnv makes Viper check if environment variables match any of the existing keys
|
// AutomaticEnv makes Viper check if environment variables match any of the existing keys
|
||||||
|
|
|
@ -2569,6 +2569,23 @@ func TestSliceIndexAccess(t *testing.T) {
|
||||||
assert.Equal(t, "Static", v.GetString("tv.0.episodes.1.2"))
|
assert.Equal(t, "Static", v.GetString("tv.0.episodes.1.2"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMapWithDefaults(t *testing.T) {
|
||||||
|
Set("config.value1", 1)
|
||||||
|
SetDefault("config.value2.internal", 3)
|
||||||
|
|
||||||
|
m, ok := Get("config").(map[string]interface{})
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, map[string]interface{}{"internal": 3}, m["value2"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubWithDefaults(t *testing.T) {
|
||||||
|
Set("config.value1", 1)
|
||||||
|
SetDefault("config.value2.internal", 3)
|
||||||
|
|
||||||
|
sub := Sub("config")
|
||||||
|
assert.Equal(t, 3, sub.Get("value2.internal"))
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkGetBool(b *testing.B) {
|
func BenchmarkGetBool(b *testing.B) {
|
||||||
key := "BenchmarkGetBool"
|
key := "BenchmarkGetBool"
|
||||||
v = New()
|
v = New()
|
||||||
|
|
Loading…
Reference in a new issue