add unit tests to catch a bug/gap where slice values are not overriden by env variables

This commit is contained in:
Jason Lee 2024-01-03 12:17:45 +08:00
parent e36638d878
commit c0546f7419
2 changed files with 148 additions and 1 deletions

View file

@ -1128,7 +1128,83 @@ func (v *Viper) Unmarshal(rawVal any, opts ...DecoderConfigOption) error {
} }
// TODO: struct keys should be enough? // TODO: struct keys should be enough?
return decode(v.getSettings(keys), defaultDecoderConfig(rawVal, opts...)) err := decode(v.getSettings(keys), defaultDecoderConfig(rawVal, opts...))
// Post processing for slice of maps
// if features.BindStruct {
err = unmarshalPostProcess(rawVal, opts...)
if err != nil {
return err
}
// }
return err
}
func unmarshalPostProcess(input any, opts ...DecoderConfigOption) error {
var structKeyMap map[string]any
err := decode(input, defaultDecoderConfig(&structKeyMap, opts...))
if err != nil {
return err
}
v.postProcessingSliceFields(map[string]bool{}, structKeyMap, "")
return nil
}
// TODO remove shadow
func (v *Viper) postProcessingSliceFields(shadow map[string]bool, m map[string]any, prefix string) map[string]bool {
if shadow != nil && prefix != "" && shadow[prefix] {
// prefix is shadowed => nothing more to flatten
return shadow
}
if shadow == nil {
shadow = make(map[string]bool)
}
var m2 map[string]any
if prefix != "" {
prefix += v.keyDelim
}
for k, val := range m {
fullKey := prefix + k
valValue := reflect.ValueOf(val)
if valValue.Kind() == reflect.Slice {
for i := 0; i < valValue.Len(); i++ {
item := valValue.Index(i)
if item.Kind() != reflect.Struct || !item.CanSet() {
continue
}
itemType := item.Type()
for j := 0; j < item.NumField(); j++ {
field := itemType.Field(j)
// fmt.Printf("Field %d: Name=%s, Type=%v, Value=%v\n", j, field.Name, field.Type, item.Field(j).Interface())
sliceKey := fmt.Sprintf("%s%s%s%d%s%s", prefix, k, v.keyDelim, i, v.keyDelim, field.Name)
shadow[strings.ToLower(sliceKey)] = true
// fmt.Printf("%s is slice\n", sliceKey)
if val, ok := v.getEnv(v.mergeWithEnvPrefix(sliceKey)); ok {
// fmt.Printf("Val is %v\n", val)
item.Field(j).SetString(val)
}
}
}
}
switch val := val.(type) {
case map[string]any:
m2 = val
case map[any]any:
m2 = cast.ToStringMap(val)
default:
continue
}
// recursively merge to shadow map
shadow = v.postProcessingSliceFields(shadow, m2, fullKey)
}
return shadow
} }
func (v *Viper) decodeStructKeys(input any, opts ...DecoderConfigOption) ([]string, error) { func (v *Viper) decodeStructKeys(input any, opts ...DecoderConfigOption) ([]string, error) {

View file

@ -2606,6 +2606,77 @@ func TestSliceIndexAccess(t *testing.T) {
assert.Equal(t, "Static", v.GetString("tv.0.episodes.1.2")) assert.Equal(t, "Static", v.GetString("tv.0.episodes.1.2"))
} }
var yamlSimpleSlice = []byte(`
name: Steve
port: 8080
auth:
secret: 88888-88888
clients:
- name: foo
- name: bar
proxy:
clients:
- name: proxy_foo
- name: proxy_bar
- name: proxy_baz
`)
func TestSliceIndexAutomaticEnv(t *testing.T) {
v.SetConfigType("yaml")
r := strings.NewReader(string(yamlSimpleSlice))
type ClientConfig struct {
Name string
}
type AuthConfig struct {
Secret string
}
type ProxyConfig struct {
Clients []ClientConfig
}
type Configuration struct {
Port int
Name string
Auth AuthConfig
Clients []ClientConfig
Proxy ProxyConfig
}
// Read yaml as default value
err := v.unmarshalReader(r, v.config)
require.NoError(t, err)
assert.Equal(t, "Steve", v.GetString("name"))
assert.Equal(t, 8080, v.GetInt("port"))
assert.Equal(t, "88888-88888", v.GetString("auth.secret"))
assert.Equal(t, "foo", v.GetString("clients.0.name"))
assert.Equal(t, "bar", v.GetString("clients.1.name"))
assert.Equal(t, "proxy_foo", v.GetString("proxy.clients.0.name"))
// Override with env variable
t.Setenv("NAME", "Steven")
t.Setenv("AUTH_SECRET", "99999-99999")
t.Setenv("CLIENTS_1_NAME", "baz")
t.Setenv("PROXY_CLIENTS_0_NAME", "ProxyFoo")
SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
AutomaticEnv()
// Unmarshal into struct
var config Configuration
v.Unmarshal(&config)
assert.Equal(t, "Steven", config.Name)
assert.Equal(t, 8080, config.Port)
assert.Equal(t, "99999-99999", config.Auth.Secret)
assert.Equal(t, "foo", config.Clients[0].Name)
assert.Equal(t, "baz", config.Clients[1].Name)
assert.Equal(t, "ProxyFoo", config.Proxy.Clients[0].Name)
}
func TestIsPathShadowedInFlatMap(t *testing.T) { func TestIsPathShadowedInFlatMap(t *testing.T) {
v := New() v := New()