From 29dcff61f265fbfe3caa4d5fca302cbdbb173267 Mon Sep 17 00:00:00 2001 From: Arieh Schneier <15041913+AriehSchneier@users.noreply.github.com> Date: Thu, 5 Aug 2021 02:00:45 +1000 Subject: [PATCH] AllKeys() should include slices allowing overriding of keys with slices with environment variables --- viper.go | 54 +++++++++++++++++++++++++++++++++++++++++++++++++-- viper_test.go | 34 +++++++++++++++++++++++++++++--- 2 files changed, 83 insertions(+), 5 deletions(-) diff --git a/viper.go b/viper.go index 46b1a85..329abd2 100644 --- a/viper.go +++ b/viper.go @@ -1035,8 +1035,8 @@ func (v *Viper) Unmarshal(rawVal interface{}, opts ...DecoderConfigOption) error return decode(v.AllSettings(), defaultDecoderConfig(rawVal, opts...)) } -// defaultDecoderConfig returns default mapsstructure.DecoderConfig with suppot -// of time.Duration values & string slices +// defaultDecoderConfig returns a default mapstructure.DecoderConfig with support +// for time.Duration values & string slices func defaultDecoderConfig(output interface{}, opts ...DecoderConfigOption) *mapstructure.DecoderConfig { c := &mapstructure.DecoderConfig{ Metadata: nil, @@ -1976,6 +1976,8 @@ func (v *Viper) flattenAndMergeMap(shadow map[string]bool, m map[string]interfac m2 = val.(map[string]interface{}) case map[interface{}]interface{}: m2 = cast.ToStringMap(val) + case []interface{}: + m2 = castSliceToStringMap(val.([]interface{})) default: // immediate value shadow[strings.ToLower(fullKey)] = true @@ -1987,6 +1989,17 @@ func (v *Viper) flattenAndMergeMap(shadow map[string]bool, m map[string]interfac return shadow } +// castSliceToStringMap converts a slice to a map where the keys are the indices. +func castSliceToStringMap(v []interface{}) map[string]interface{} { + m := make(map[string]interface{}, len(v)) + + for i, val := range v { + m[strconv.Itoa(i)] = val + } + + return m +} + // mergeFlatMap merges the given maps, excluding values of the second map // shadowed by values from the first map. func (v *Viper) mergeFlatMap(shadow map[string]bool, m map[string]interface{}) map[string]bool { @@ -2028,9 +2041,46 @@ func (v *Viper) AllSettings() map[string]interface{} { // set innermost value deepestMap[lastKey] = value } + + // convert any maps of integer keys back to slices + i := convertMapsToSlices(m) + if m2, ok := i.(map[string]interface{}); ok { + m = m2 + } + return m } +// convertMapsToSlices will do a deep check for any maps where the keys are all integers (as strings) +// as well as being contiguous from 0, and convert them to a slice. +func convertMapsToSlices(m map[string]interface{}) interface{} { + allInts := true + for k, v := range m { + if _, ok := strconv.Atoi(k); ok != nil { + allInts = false + } + + if m2, ok := v.(map[string]interface{}); ok { + m[k] = convertMapsToSlices(m2) + } + } + + if !allInts { + return m + } + + s := make([]interface{}, len(m)) + for i := 0; i < len(m); i++ { + v, ok := m[strconv.Itoa(i)] + if !ok { + return m + } + s[i] = v + } + + return s +} + // SetFs sets the filesystem to use to read configuration. func SetFs(fs afero.Fs) { v.SetFs(fs) } diff --git a/viper_test.go b/viper_test.go index 4192748..ee3bf2a 100644 --- a/viper_test.go +++ b/viper_test.go @@ -570,6 +570,29 @@ func TestAutoEnv(t *testing.T) { assert.Equal(t, "13", Get("foo_bar")) } +func TestAutoEnvWithSlice(t *testing.T) { + initJSON() + + AutomaticEnv() + + type config struct { + Batters struct { + Batter []struct { + Type string + } + } + } + + testutil.Setenv(t, "BATTERS.BATTER.1.TYPE", "Small") + + var C config + err := Unmarshal(&C) + if err != nil { + t.Fatalf("unable to decode into struct, %v", err) + } + assert.Equal(t, []struct{ Type string }{{"Regular"}, {"Small"}, {"Blueberry"}, {"Devil's Food"}}, C.Batters.Batter) +} + func TestAutoEnvWithPrefix(t *testing.T) { Reset() @@ -620,8 +643,13 @@ func TestAllKeys(t *testing.T) { "name", "beard", "ppu", - "batters.batter", - "hobbies", + "batters.batter.0.type", + "batters.batter.1.type", + "batters.batter.2.type", + "batters.batter.3.type", + "hobbies.0", + "hobbies.1", + "hobbies.2", "clothing.jacket", "clothing.trousers", "default.import_path", @@ -1983,7 +2011,7 @@ clothing: func TestDotParameter(t *testing.T) { initJSON() - // shoud take precedence over batters defined in jsonExample + // should take precedence over batters defined in jsonExample r := bytes.NewReader([]byte(`{ "batters.batter": [ { "type": "Small" } ] }`)) unmarshalReader(r, v.config)