diff --git a/overrides_test.go b/overrides_test.go index 42da3ba..f06ec9d 100644 --- a/overrides_test.go +++ b/overrides_test.go @@ -45,10 +45,10 @@ func TestNestedOverrides(t *testing.T) { deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4) // Case 4: key:value overridden by a map - v = overrideDefault(assert, "tom.size", 4, "tom", map[string]any{"age": 10}) // "tom.size" is first given "4" as default value, then "tom" is overridden by map{"age":10} - assert.Equal(4, v.Get("tom.size")) // "tom.size" should still be reachable - assert.Equal(10, v.Get("tom.age")) // new value should be there - deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) // new value should be there + v = overrideDefault(assert, "tom.size", 4, "tom", map[string]any{"age": 10, "size": 4}) // "tom.size" is first given "4" as default value, then "tom" is overridden by map{"age":10} + assert.Equal(4, v.Get("tom.size")) // "tom.size" should still be reachable + assert.Equal(10, v.Get("tom.age")) // new value should be there + deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10) // new value should be there v = override(assert, "tom.size", 4, "tom", map[string]any{"age": 10}) assert.Nil(v.Get("tom.size")) assert.Equal(10, v.Get("tom.age")) diff --git a/viper.go b/viper.go index 20eb4da..320d782 100644 --- a/viper.go +++ b/viper.go @@ -899,6 +899,13 @@ func GetViper() *Viper { // Get returns an interface. For a specific value use one of the Get____ methods. func Get(key string) any { return v.Get(key) } +func isStringMapInterface(val any) bool { + vt := reflect.TypeOf(val) + return vt.Kind() == reflect.Map && + vt.Key().Kind() == reflect.String && + vt.Elem().Kind() == reflect.Interface +} + func (v *Viper) Get(key string) any { lcaseKey := strings.ToLower(key) val := v.find(lcaseKey, true) @@ -906,6 +913,29 @@ func (v *Viper) Get(key string) any { return nil } + // when section is partially overridden, + // make sure to return the complete map. + if isStringMapInterface(val) { + val := val.(map[string]interface{}) + prefix := lcaseKey + v.keyDelim + keys := v.AllKeys() + for _, key := range keys { + if !strings.HasPrefix(key, prefix) { + continue + } + mk := strings.TrimPrefix(key, prefix) + mk = strings.Split(mk, v.keyDelim)[0] + if _, exists := val[mk]; exists { + continue + } + mv := v.Get(lcaseKey + v.keyDelim + mk) + if mv == nil { + continue + } + val[mk] = mv + } + } + if v.typeByDefValue { // TODO(bep) this branch isn't covered by a single test. valType := val