diff --git a/viper.go b/viper.go index 8889f87..82a88d7 100644 --- a/viper.go +++ b/viper.go @@ -205,7 +205,7 @@ type Viper struct { defaults map[string]interface{} kvstore map[string]interface{} pflags map[string]FlagValue - env map[string]string + env map[string][]string aliases map[string]string typeByDefValue bool @@ -228,7 +228,7 @@ func New() *Viper { v.defaults = make(map[string]interface{}) v.kvstore = make(map[string]interface{}) v.pflags = make(map[string]FlagValue) - v.env = make(map[string]string) + v.env = make(map[string][]string) v.aliases = make(map[string]string) v.typeByDefValue = false @@ -1029,21 +1029,18 @@ func (v *Viper) BindFlagValue(key string, flag FlagValue) error { func BindEnv(input ...string) error { return v.BindEnv(input...) } func (v *Viper) BindEnv(input ...string) error { - var key, envkey string if len(input) == 0 { return fmt.Errorf("missing key to bind to") } - key = strings.ToLower(input[0]) + key := strings.ToLower(input[0]) if len(input) == 1 { - envkey = v.mergeWithEnvPrefix(key) + v.env[key] = append(v.env[key], v.mergeWithEnvPrefix(key)) } else { - envkey = input[1] + v.env[key] = append(v.env[key], input[1:]...) } - v.env[key] = envkey - return nil } @@ -1122,10 +1119,12 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} { return nil } } - envkey, exists := v.env[lcaseKey] + envkeys, exists := v.env[lcaseKey] if exists { - if val, ok := v.getEnv(envkey); ok { - return val + for _, envkey := range envkeys { + if val, ok := v.getEnv(envkey); ok { + return val + } } } if nested && v.isPathShadowedInFlatMap(path, v.env) != "" { @@ -1711,6 +1710,14 @@ func castToMapStringInterface( return tgt } +func castMapStringSliceToMapInterface(src map[string][]string) map[string]interface{} { + tgt := map[string]interface{}{} + for k, v := range src { + tgt[k] = v + } + return tgt +} + func castMapStringToMapInterface(src map[string]string) map[string]interface{} { tgt := map[string]interface{}{} for k, v := range src { @@ -1883,7 +1890,7 @@ func (v *Viper) AllKeys() []string { m = v.flattenAndMergeMap(m, castMapStringToMapInterface(v.aliases), "") m = v.flattenAndMergeMap(m, v.override, "") m = v.mergeFlatMap(m, castMapFlagToMapInterface(v.pflags)) - m = v.mergeFlatMap(m, castMapStringToMapInterface(v.env)) + m = v.mergeFlatMap(m, castMapStringSliceToMapInterface(v.env)) m = v.flattenAndMergeMap(m, v.config, "") m = v.flattenAndMergeMap(m, v.kvstore, "") m = v.flattenAndMergeMap(m, v.defaults, "") diff --git a/viper_test.go b/viper_test.go index c38c665..28c7f13 100644 --- a/viper_test.go +++ b/viper_test.go @@ -486,10 +486,11 @@ func TestEnv(t *testing.T) { initJSON() BindEnv("id") - BindEnv("f", "FOOD") + BindEnv("f", "FOOD", "OLD_FOOD") os.Setenv("ID", "13") os.Setenv("FOOD", "apple") + os.Setenv("OLD_FOOD", "banana") os.Setenv("NAME", "crunk") assert.Equal(t, "13", Get("id")) @@ -501,6 +502,17 @@ func TestEnv(t *testing.T) { assert.Equal(t, "crunk", Get("name")) } +func TestMultipleEnv(t *testing.T) { + initJSON() + + BindEnv("f", "FOOD", "OLD_FOOD") + + os.Unsetenv("FOOD") + os.Setenv("OLD_FOOD", "banana") + + assert.Equal(t, "banana", Get("f")) +} + func TestEmptyEnv(t *testing.T) { initJSON()