diff --git a/viper.go b/viper.go index 20eb4da..b658ffd 100644 --- a/viper.go +++ b/viper.go @@ -1128,7 +1128,83 @@ func (v *Viper) Unmarshal(rawVal any, opts ...DecoderConfigOption) error { } // TODO: struct keys should be enough? - return decode(v.getSettings(keys), defaultDecoderConfig(rawVal, opts...)) + err := decode(v.getSettings(keys), defaultDecoderConfig(rawVal, opts...)) + + // Post processing for slice of maps + // if features.BindStruct { + err = unmarshalPostProcess(rawVal, opts...) + if err != nil { + return err + } + // } + + return err +} + +func unmarshalPostProcess(input any, opts ...DecoderConfigOption) error { + var structKeyMap map[string]any + + err := decode(input, defaultDecoderConfig(&structKeyMap, opts...)) + if err != nil { + return err + } + + v.postProcessingSliceFields(map[string]bool{}, structKeyMap, "") + return nil +} + +// TODO remove shadow +func (v *Viper) postProcessingSliceFields(shadow map[string]bool, m map[string]any, prefix string) map[string]bool { + if shadow != nil && prefix != "" && shadow[prefix] { + // prefix is shadowed => nothing more to flatten + return shadow + } + if shadow == nil { + shadow = make(map[string]bool) + } + + var m2 map[string]any + if prefix != "" { + prefix += v.keyDelim + } + for k, val := range m { + fullKey := prefix + k + valValue := reflect.ValueOf(val) + if valValue.Kind() == reflect.Slice { + for i := 0; i < valValue.Len(); i++ { + item := valValue.Index(i) + if item.Kind() != reflect.Struct || !item.CanSet() { + continue + } + itemType := item.Type() + for j := 0; j < item.NumField(); j++ { + field := itemType.Field(j) + // fmt.Printf("Field %d: Name=%s, Type=%v, Value=%v\n", j, field.Name, field.Type, item.Field(j).Interface()) + + sliceKey := fmt.Sprintf("%s%s%s%d%s%s", prefix, k, v.keyDelim, i, v.keyDelim, field.Name) + shadow[strings.ToLower(sliceKey)] = true + // fmt.Printf("%s is slice\n", sliceKey) + + if val, ok := v.getEnv(v.mergeWithEnvPrefix(sliceKey)); ok { + // fmt.Printf("Val is %v\n", val) + item.Field(j).SetString(val) + } + } + } + } + + switch val := val.(type) { + case map[string]any: + m2 = val + case map[any]any: + m2 = cast.ToStringMap(val) + default: + continue + } + // recursively merge to shadow map + shadow = v.postProcessingSliceFields(shadow, m2, fullKey) + } + return shadow } func (v *Viper) decodeStructKeys(input any, opts ...DecoderConfigOption) ([]string, error) { diff --git a/viper_test.go b/viper_test.go index 0b1f407..30460b0 100644 --- a/viper_test.go +++ b/viper_test.go @@ -2606,6 +2606,77 @@ func TestSliceIndexAccess(t *testing.T) { assert.Equal(t, "Static", v.GetString("tv.0.episodes.1.2")) } +var yamlSimpleSlice = []byte(` +name: Steve +port: 8080 +auth: + secret: 88888-88888 +clients: + - name: foo + - name: bar +proxy: + clients: + - name: proxy_foo + - name: proxy_bar + - name: proxy_baz +`) + +func TestSliceIndexAutomaticEnv(t *testing.T) { + v.SetConfigType("yaml") + r := strings.NewReader(string(yamlSimpleSlice)) + + type ClientConfig struct { + Name string + } + + type AuthConfig struct { + Secret string + } + + type ProxyConfig struct { + Clients []ClientConfig + } + + type Configuration struct { + Port int + Name string + Auth AuthConfig + Clients []ClientConfig + Proxy ProxyConfig + } + + // Read yaml as default value + err := v.unmarshalReader(r, v.config) + require.NoError(t, err) + + assert.Equal(t, "Steve", v.GetString("name")) + assert.Equal(t, 8080, v.GetInt("port")) + assert.Equal(t, "88888-88888", v.GetString("auth.secret")) + assert.Equal(t, "foo", v.GetString("clients.0.name")) + assert.Equal(t, "bar", v.GetString("clients.1.name")) + assert.Equal(t, "proxy_foo", v.GetString("proxy.clients.0.name")) + + // Override with env variable + t.Setenv("NAME", "Steven") + t.Setenv("AUTH_SECRET", "99999-99999") + t.Setenv("CLIENTS_1_NAME", "baz") + t.Setenv("PROXY_CLIENTS_0_NAME", "ProxyFoo") + + SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + AutomaticEnv() + + // Unmarshal into struct + var config Configuration + v.Unmarshal(&config) + + assert.Equal(t, "Steven", config.Name) + assert.Equal(t, 8080, config.Port) + assert.Equal(t, "99999-99999", config.Auth.Secret) + assert.Equal(t, "foo", config.Clients[0].Name) + assert.Equal(t, "baz", config.Clients[1].Name) + assert.Equal(t, "ProxyFoo", config.Proxy.Clients[0].Name) +} + func TestIsPathShadowedInFlatMap(t *testing.T) { v := New()