Add CaseSensitiveKeys option

When reading configuration from sources with case-sensitive keys,
such as YAML, TOML, and JSON, a user may wish to preserve the case
of keys that appear in maps.  For example, consider when the value
of a setting is a map with string keys that are case-sensitive.
Ideally, if the value is not going to be indexed by a Viper lookup
key, then the map value should be treated as an opaque value by
Viper, and its keys should not be modified.  See #1014

Viper's default behaviour is that keys are case-sensitive, and this
behavior is implemented by converting all keys to lower-case.  For
users that wish to preserve the case of keys, this commit introduces
an Option `CaseSensitiveKeys()` that can be used to configure Viper
to use case-sensitive keys.  When CaseSensitiveKeys is enabled, all
keys retain the original case, and lookups become case-sensitive
(except for lookups of values bound to environment variables).

The behavior of Viper could become hard to understand if a user
could change the CaseSensitiveKeys setting after values have been
stored.  For this reason, the setting may only be set when creating
a Viper instance, and it cannot be set on the "global" Viper.
This commit is contained in:
Travis Newhouse 2023-10-20 15:09:25 -07:00
parent b5daec6e7b
commit 21336ce35f
5 changed files with 556 additions and 82 deletions

43
util.go
View file

@ -37,39 +37,42 @@ func (pe ConfigParseError) Unwrap() error {
return pe.err return pe.err
} }
// toCaseInsensitiveValue checks if the value is a map; // CopyMap returns a deep copy of a map[any]any or map[string]any. If value is
// if so, create a copy and lower-case the keys recursively. // not one of those map types, then it is returned as-is. If preserveCase is
func toCaseInsensitiveValue(value any) any { // false, then all keys will be converted to lower-case in the copy that is
switch v := value.(type) { // returned.
case map[any]any: func CopyMap(value any, preserveCase bool) any {
value = copyAndInsensitiviseMap(cast.ToStringMap(v)) var copyMap func(map[string]any, bool) map[string]any
case map[string]any: copyMap = func(m map[string]any, preserveCase bool) map[string]any {
value = copyAndInsensitiviseMap(v)
}
return value
}
// copyAndInsensitiviseMap behaves like insensitiviseMap, but creates a copy of
// any map it makes case insensitive.
func copyAndInsensitiviseMap(m map[string]any) map[string]any {
nm := make(map[string]any) nm := make(map[string]any)
for key, val := range m { for key, val := range m {
lkey := strings.ToLower(key) if !preserveCase {
key = strings.ToLower(key)
}
switch v := val.(type) { switch v := val.(type) {
case map[any]any: case map[any]any:
nm[lkey] = copyAndInsensitiviseMap(cast.ToStringMap(v)) nm[key] = copyMap(cast.ToStringMap(v), preserveCase)
case map[string]any: case map[string]any:
nm[lkey] = copyAndInsensitiviseMap(v) nm[key] = copyMap(v, preserveCase)
default: default:
nm[lkey] = v nm[key] = v
} }
} }
return nm return nm
} }
switch v := value.(type) {
case map[any]any:
value = copyMap(cast.ToStringMap(v), preserveCase)
case map[string]any:
value = copyMap(v, preserveCase)
}
return value
}
func insensitiviseVal(val any) any { func insensitiviseVal(val any) any {
switch v := val.(type) { switch v := val.(type) {
case map[any]any: case map[any]any:

View file

@ -19,7 +19,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestCopyAndInsensitiviseMap(t *testing.T) { func TestCopyMap(t *testing.T) {
var ( var (
given = map[string]any{ given = map[string]any{
"Foo": 32, "Foo": 32,
@ -35,9 +35,17 @@ func TestCopyAndInsensitiviseMap(t *testing.T) {
"cde": "B", "cde": "B",
}, },
} }
expectedPreserveCase = map[string]any{
"Foo": 32,
"Bar": map[string]any{
"ABc": "A",
"cDE": "B",
},
}
) )
got := copyAndInsensitiviseMap(given) t.Run("convert to lower-case", func(t *testing.T) {
got := CopyMap(given, false)
assert.Equal(t, expected, got) assert.Equal(t, expected, got)
_, ok := given["foo"] _, ok := given["foo"]
@ -48,6 +56,33 @@ func TestCopyAndInsensitiviseMap(t *testing.T) {
m := given["Bar"].(map[any]any) m := given["Bar"].(map[any]any)
_, ok = m["ABc"] _, ok = m["ABc"]
assert.True(t, ok) assert.True(t, ok)
})
t.Run("preserve case", func(t *testing.T) {
got := CopyMap(given, true)
assert.Equal(t, expectedPreserveCase, got)
_, ok := given["foo"]
assert.False(t, ok)
_, ok = given["bar"]
assert.False(t, ok)
m := given["Bar"].(map[any]any)
_, ok = m["ABc"]
assert.True(t, ok)
})
t.Run("not a map", func(t *testing.T) {
var (
given = []any{42, "xyz"}
expected = []any{42, "xyz"}
)
got := CopyMap(given, false)
assert.Equal(t, expected, got)
got = CopyMap(given, true)
assert.Equal(t, expected, got)
})
} }
func TestAbsPathify(t *testing.T) { func TestAbsPathify(t *testing.T) {

115
viper.go
View file

@ -206,6 +206,10 @@ type Viper struct {
envKeyReplacer StringReplacer envKeyReplacer StringReplacer
allowEmptyEnv bool allowEmptyEnv bool
// When caseSensitiveKeys is true, keys are preserved in their original
// case (i.e., not modified to lower-case).
caseSensitiveKeys bool
parents []string parents []string
config map[string]any config map[string]any
override map[string]any override map[string]any
@ -283,6 +287,20 @@ func EnvKeyReplacer(r StringReplacer) Option {
}) })
} }
// CaseSensitiveKeys sets Viper to use case-sensitive keys, which preserves the
// case of keys. By default, all keys are converted to lower-case.
//
// Lookup keys (i.e., Get()) will always match environment variables in a
// case-insensitive manner because Viper always converts the lookup key to
// upper-case when searching for an environment variable. The case of the
// actual keys will be preserved, however, as seen in the output of AllKeys(),
// AllSettings(), etc.
func CaseSensitiveKeys(enable bool) Option {
return optionFunc(func(v *Viper) {
v.caseSensitiveKeys = enable
})
}
// NewWithOptions creates a new Viper instance. // NewWithOptions creates a new Viper instance.
func NewWithOptions(opts ...Option) *Viper { func NewWithOptions(opts ...Option) *Viper {
v := New() v := New()
@ -706,7 +724,7 @@ func (v *Viper) searchIndexableWithPathPrefixes(source any, path []string) any {
// search for path prefixes, starting from the longest one // search for path prefixes, starting from the longest one
for i := len(path); i > 0; i-- { for i := len(path); i > 0; i-- {
prefixKey := strings.ToLower(strings.Join(path[0:i], v.keyDelim)) prefixKey := v.toLower(strings.Join(path[0:i], v.keyDelim))
var val any var val any
switch sourceIndexable := source.(type) { switch sourceIndexable := source.(type) {
@ -897,8 +915,7 @@ func GetViper() *Viper {
func Get(key string) any { return v.Get(key) } func Get(key string) any { return v.Get(key) }
func (v *Viper) Get(key string) any { func (v *Viper) Get(key string) any {
lcaseKey := strings.ToLower(key) val := v.find(v.toLower(key), true)
val := v.find(lcaseKey, true)
if val == nil { if val == nil {
return nil return nil
} }
@ -906,7 +923,7 @@ func (v *Viper) Get(key string) any {
if v.typeByDefValue { if v.typeByDefValue {
// TODO(bep) this branch isn't covered by a single test. // TODO(bep) this branch isn't covered by a single test.
valType := val valType := val
path := strings.Split(lcaseKey, v.keyDelim) path := strings.Split(key, v.keyDelim)
defVal := v.searchMap(v.defaults, path) defVal := v.searchMap(v.defaults, path)
if defVal != nil { if defVal != nil {
valType = defVal valType = defVal
@ -950,14 +967,14 @@ func (v *Viper) Get(key string) any {
func Sub(key string) *Viper { return v.Sub(key) } func Sub(key string) *Viper { return v.Sub(key) }
func (v *Viper) Sub(key string) *Viper { func (v *Viper) Sub(key string) *Viper {
subv := New()
data := v.Get(key) data := v.Get(key)
if data == nil { if data == nil {
return nil return nil
} }
if reflect.TypeOf(data).Kind() == reflect.Map { if reflect.TypeOf(data).Kind() == reflect.Map {
subv.parents = append(v.parents, strings.ToLower(key)) subv := NewWithOptions(CaseSensitiveKeys(v.caseSensitiveKeys))
subv.parents = append(v.parents, v.toLower(key))
subv.automaticEnvApplied = v.automaticEnvApplied subv.automaticEnvApplied = v.automaticEnvApplied
subv.envPrefix = v.envPrefix subv.envPrefix = v.envPrefix
subv.envKeyReplacer = v.envKeyReplacer subv.envKeyReplacer = v.envKeyReplacer
@ -1196,7 +1213,7 @@ func (v *Viper) BindFlagValue(key string, flag FlagValue) error {
if flag == nil { if flag == nil {
return fmt.Errorf("flag for %q is nil", key) return fmt.Errorf("flag for %q is nil", key)
} }
v.pflags[strings.ToLower(key)] = flag v.pflags[v.toLower(key)] = flag
return nil return nil
} }
@ -1213,7 +1230,7 @@ func (v *Viper) BindEnv(input ...string) error {
return fmt.Errorf("missing key to bind to") return fmt.Errorf("missing key to bind to")
} }
key := strings.ToLower(input[0]) key := v.toLower(input[0])
if len(input) == 1 { if len(input) == 1 {
v.env[key] = append(v.env[key], v.mergeWithEnvPrefix(key)) v.env[key] = append(v.env[key], v.mergeWithEnvPrefix(key))
@ -1243,12 +1260,13 @@ func (v *Viper) MustBindEnv(input ...string) {
// Lastly, if no value was found and flagDefault is true, and if the key // Lastly, if no value was found and flagDefault is true, and if the key
// corresponds to a flag, the flag's default value is returned. // corresponds to a flag, the flag's default value is returned.
// //
// Note: this assumes a lower-cased key given. // Note: By default, this assumes that a lowercase key is given.
func (v *Viper) find(lcaseKey string, flagDefault bool) any { // This behavior can be modified with viper.SetPreserveCase().
func (v *Viper) find(key string, flagDefault bool) any {
var ( var (
val any val any
exists bool exists bool
path = strings.Split(lcaseKey, v.keyDelim) path = strings.Split(key, v.keyDelim)
nested = len(path) > 1 nested = len(path) > 1
) )
@ -1258,8 +1276,8 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) any {
} }
// if the requested key is an alias, then return the proper key // if the requested key is an alias, then return the proper key
lcaseKey = v.realKey(lcaseKey) key = v.realKey(key)
path = strings.Split(lcaseKey, v.keyDelim) path = strings.Split(key, v.keyDelim)
nested = len(path) > 1 nested = len(path) > 1
// Set() override first // Set() override first
@ -1272,7 +1290,7 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) any {
} }
// PFlag override next // PFlag override next
flag, exists := v.pflags[lcaseKey] flag, exists := v.pflags[key]
if exists && flag.HasChanged() { if exists && flag.HasChanged() {
switch flag.ValueType() { switch flag.ValueType() {
case "int", "int8", "int16", "int32", "int64": case "int", "int8", "int16", "int32", "int64":
@ -1308,7 +1326,7 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) any {
// Env override next // Env override next
if v.automaticEnvApplied { if v.automaticEnvApplied {
envKey := strings.Join(append(v.parents, lcaseKey), ".") envKey := strings.Join(append(v.parents, key), ".")
// even if it hasn't been registered, if automaticEnv is used, // even if it hasn't been registered, if automaticEnv is used,
// check any Get request // check any Get request
if val, ok := v.getEnv(v.mergeWithEnvPrefix(envKey)); ok { if val, ok := v.getEnv(v.mergeWithEnvPrefix(envKey)); ok {
@ -1318,7 +1336,7 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) any {
return nil return nil
} }
} }
envkeys, exists := v.env[lcaseKey] envkeys, exists := v.env[key]
if exists { if exists {
for _, envkey := range envkeys { for _, envkey := range envkeys {
if val, ok := v.getEnv(envkey); ok { if val, ok := v.getEnv(envkey); ok {
@ -1360,7 +1378,7 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) any {
if flagDefault { if flagDefault {
// last chance: if no value is found and a flag does exist for the key, // last chance: if no value is found and a flag does exist for the key,
// get the flag's default value even if the flag's value has not been set. // get the flag's default value even if the flag's value has not been set.
if flag, exists := v.pflags[lcaseKey]; exists { if flag, exists := v.pflags[key]; exists {
switch flag.ValueType() { switch flag.ValueType() {
case "int", "int8", "int16", "int32", "int64": case "int", "int8", "int16", "int32", "int64":
return cast.ToInt(flag.ValueString()) return cast.ToInt(flag.ValueString())
@ -1457,8 +1475,7 @@ func stringToIntConv(val string) any {
func IsSet(key string) bool { return v.IsSet(key) } func IsSet(key string) bool { return v.IsSet(key) }
func (v *Viper) IsSet(key string) bool { func (v *Viper) IsSet(key string) bool {
lcaseKey := strings.ToLower(key) val := v.find(v.toLower(key), false)
val := v.find(lcaseKey, false)
return val != nil return val != nil
} }
@ -1484,11 +1501,11 @@ func (v *Viper) SetEnvKeyReplacer(r *strings.Replacer) {
func RegisterAlias(alias string, key string) { v.RegisterAlias(alias, key) } func RegisterAlias(alias string, key string) { v.RegisterAlias(alias, key) }
func (v *Viper) RegisterAlias(alias string, key string) { func (v *Viper) RegisterAlias(alias string, key string) {
v.registerAlias(alias, strings.ToLower(key)) v.registerAlias(alias, v.toLower(key))
} }
func (v *Viper) registerAlias(alias string, key string) { func (v *Viper) registerAlias(alias string, key string) {
alias = strings.ToLower(alias) alias = v.toLower(alias)
if alias != key && alias != v.realKey(key) { if alias != key && alias != v.realKey(key) {
_, exists := v.aliases[alias] _, exists := v.aliases[alias]
@ -1533,11 +1550,9 @@ func (v *Viper) realKey(key string) string {
func InConfig(key string) bool { return v.InConfig(key) } func InConfig(key string) bool { return v.InConfig(key) }
func (v *Viper) InConfig(key string) bool { func (v *Viper) InConfig(key string) bool {
lcaseKey := strings.ToLower(key)
// if the requested key is an alias, then return the proper key // if the requested key is an alias, then return the proper key
lcaseKey = v.realKey(lcaseKey) key = v.realKey(v.toLower(key))
path := strings.Split(lcaseKey, v.keyDelim) path := strings.Split(key, v.keyDelim)
return v.searchIndexableWithPathPrefixes(v.config, path) != nil return v.searchIndexableWithPathPrefixes(v.config, path) != nil
} }
@ -1549,11 +1564,11 @@ func SetDefault(key string, value any) { v.SetDefault(key, value) }
func (v *Viper) SetDefault(key string, value any) { func (v *Viper) SetDefault(key string, value any) {
// If alias passed in, then set the proper default // If alias passed in, then set the proper default
key = v.realKey(strings.ToLower(key)) key = v.realKey(v.toLower(key))
value = toCaseInsensitiveValue(value) value = CopyMap(value, v.caseSensitiveKeys)
path := strings.Split(key, v.keyDelim) path := strings.Split(key, v.keyDelim)
lastKey := strings.ToLower(path[len(path)-1]) lastKey := v.toLower(path[len(path)-1])
deepestMap := deepSearch(v.defaults, path[0:len(path)-1]) deepestMap := deepSearch(v.defaults, path[0:len(path)-1])
// set innermost value // set innermost value
@ -1567,12 +1582,12 @@ func (v *Viper) SetDefault(key string, value any) {
func Set(key string, value any) { v.Set(key, value) } func Set(key string, value any) { v.Set(key, value) }
func (v *Viper) Set(key string, value any) { func (v *Viper) Set(key string, value any) {
// If alias passed in, then set the proper override // If alias passed in, then set the proper default
key = v.realKey(strings.ToLower(key)) key = v.realKey(v.toLower(key))
value = toCaseInsensitiveValue(value) value = CopyMap(value, v.caseSensitiveKeys)
path := strings.Split(key, v.keyDelim) path := strings.Split(key, v.keyDelim)
lastKey := strings.ToLower(path[len(path)-1]) lastKey := v.toLower(path[len(path)-1])
deepestMap := deepSearch(v.override, path[0:len(path)-1]) deepestMap := deepSearch(v.override, path[0:len(path)-1])
// set innermost value // set innermost value
@ -1661,8 +1676,10 @@ func (v *Viper) MergeConfigMap(cfg map[string]any) error {
if v.config == nil { if v.config == nil {
v.config = make(map[string]any) v.config = make(map[string]any)
} }
if !v.caseSensitiveKeys {
insensitiviseMap(cfg) insensitiviseMap(cfg)
mergeMaps(cfg, v.config, nil) }
v.mergeMaps(cfg, v.config, nil)
return nil return nil
} }
@ -1761,7 +1778,9 @@ func (v *Viper) unmarshalReader(in io.Reader, c map[string]any) error {
} }
} }
if !v.caseSensitiveKeys {
insensitiviseMap(c) insensitiviseMap(c)
}
return nil return nil
} }
@ -1783,10 +1802,10 @@ func (v *Viper) marshalWriter(f afero.File, configType string) error {
return nil return nil
} }
func keyExists(k string, m map[string]any) string { func (v *Viper) keyExists(k string, m map[string]any) string {
lk := strings.ToLower(k) lk := v.toLower(k)
for mk := range m { for mk := range m {
lmk := strings.ToLower(mk) lmk := v.toLower(mk)
if lmk == lk { if lmk == lk {
return mk return mk
} }
@ -1833,9 +1852,9 @@ func castMapFlagToMapInterface(src map[string]FlagValue) map[string]any {
// instead of using a `string` as the key for nest structures beyond one level // instead of using a `string` as the key for nest structures beyond one level
// deep. Both map types are supported as there is a go-yaml fork that uses // deep. Both map types are supported as there is a go-yaml fork that uses
// `map[string]any` instead. // `map[string]any` instead.
func mergeMaps(src, tgt map[string]any, itgt map[any]any) { func (v *Viper) mergeMaps(src, tgt map[string]any, itgt map[any]any) {
for sk, sv := range src { for sk, sv := range src {
tk := keyExists(sk, tgt) tk := v.keyExists(sk, tgt)
if tk == "" { if tk == "" {
v.logger.Debug("", "tk", "\"\"", fmt.Sprintf("tgt[%s]", sk), sv) v.logger.Debug("", "tk", "\"\"", fmt.Sprintf("tgt[%s]", sk), sv)
tgt[sk] = sv tgt[sk] = sv
@ -1885,7 +1904,7 @@ func mergeMaps(src, tgt map[string]any, itgt map[any]any) {
ssv := castToMapStringInterface(tsv) ssv := castToMapStringInterface(tsv)
stv := castToMapStringInterface(ttv) stv := castToMapStringInterface(ttv)
mergeMaps(ssv, stv, ttv) v.mergeMaps(ssv, stv, ttv)
case map[string]any: case map[string]any:
v.logger.Debug("merging maps") v.logger.Debug("merging maps")
tsv, ok := sv.(map[string]any) tsv, ok := sv.(map[string]any)
@ -1900,7 +1919,7 @@ func mergeMaps(src, tgt map[string]any, itgt map[any]any) {
) )
continue continue
} }
mergeMaps(tsv, ttv, nil) v.mergeMaps(tsv, ttv, nil)
default: default:
v.logger.Debug("setting value") v.logger.Debug("setting value")
tgt[tk] = sv tgt[tk] = sv
@ -2063,7 +2082,7 @@ func (v *Viper) flattenAndMergeMap(shadow map[string]bool, m map[string]any, pre
m2 = cast.ToStringMap(val) m2 = cast.ToStringMap(val)
default: default:
// immediate value // immediate value
shadow[strings.ToLower(fullKey)] = true shadow[v.toLower(fullKey)] = true
continue continue
} }
// recursively merge to shadow map // recursively merge to shadow map
@ -2089,7 +2108,7 @@ outer:
} }
} }
// add key // add key
shadow[strings.ToLower(k)] = true shadow[v.toLower(k)] = true
} }
return shadow return shadow
} }
@ -2108,7 +2127,7 @@ func (v *Viper) AllSettings() map[string]any {
continue continue
} }
path := strings.Split(k, v.keyDelim) path := strings.Split(k, v.keyDelim)
lastKey := strings.ToLower(path[len(path)-1]) lastKey := v.toLower(path[len(path)-1])
deepestMap := deepSearch(m, path[0:len(path)-1]) deepestMap := deepSearch(m, path[0:len(path)-1])
// set innermost value // set innermost value
deepestMap[lastKey] = value deepestMap[lastKey] = value
@ -2204,3 +2223,13 @@ func (v *Viper) DebugTo(w io.Writer) {
fmt.Fprintf(w, "Config:\n%#v\n", v.config) fmt.Fprintf(w, "Config:\n%#v\n", v.config)
fmt.Fprintf(w, "Defaults:\n%#v\n", v.defaults) fmt.Fprintf(w, "Defaults:\n%#v\n", v.defaults)
} }
// toLower returns a properly cased key based on the CaseSensitiveKeys option.
// If preserveCase is true, then the unmodifed key is returned. Otherwise, the
// lower-cased key is returned.
func (v *Viper) toLower(k string) string {
if v.caseSensitiveKeys {
return k
}
return strings.ToLower(k)
}

View file

@ -171,6 +171,41 @@ func initConfigs() {
unmarshalReader(r, v.config) unmarshalReader(r, v.config)
} }
func initAllConfigs(v *Viper) {
var r io.Reader
v.SetConfigType("yaml")
r = bytes.NewReader(yamlExample)
v.unmarshalReader(r, v.config)
v.SetConfigType("json")
r = bytes.NewReader(jsonExample)
v.unmarshalReader(r, v.config)
v.SetConfigType("hcl")
r = bytes.NewReader(hclExample)
v.unmarshalReader(r, v.config)
v.SetConfigType("properties")
r = bytes.NewReader(propertiesExample)
v.unmarshalReader(r, v.config)
v.SetConfigType("toml")
r = bytes.NewReader(tomlExample)
v.unmarshalReader(r, v.config)
v.SetConfigType("env")
r = bytes.NewReader(dotenvExample)
v.unmarshalReader(r, v.config)
v.SetConfigType("json")
remote := bytes.NewReader(remoteExample)
v.unmarshalReader(remote, v.kvstore)
v.SetConfigType("ini")
r = bytes.NewReader(iniExample)
v.unmarshalReader(r, v.config)
}
func initConfig(typ, config string) { func initConfig(typ, config string) {
Reset() Reset()
SetConfigType(typ) SetConfigType(typ)
@ -482,6 +517,23 @@ func TestDefault(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "leather", Get("clothing.jacket")) assert.Equal(t, "leather", Get("clothing.jacket"))
assert.Equal(t, "leather", Get("clothing.Jacket"))
}
func TestDefault_CaseSensitive(t *testing.T) {
v := NewWithOptions(CaseSensitiveKeys(true))
v.SetDefault("age", 45)
assert.Equal(t, 45, v.Get("age"))
v.SetDefault("clothing.jacket", "slacks")
assert.Equal(t, "slacks", v.Get("clothing.jacket"))
v.SetConfigType("yaml")
err := v.ReadConfig(bytes.NewBuffer(yamlExample))
assert.NoError(t, err)
assert.Equal(t, "slacks", v.Get("clothing.jacket"))
assert.Equal(t, "leather", v.Get("clothing.Jacket"))
} }
func TestUnmarshaling(t *testing.T) { func TestUnmarshaling(t *testing.T) {
@ -497,6 +549,29 @@ func TestUnmarshaling(t *testing.T) {
assert.Equal(t, []any{"skateboarding", "snowboarding", "go"}, Get("hobbies")) assert.Equal(t, []any{"skateboarding", "snowboarding", "go"}, Get("hobbies"))
assert.Equal(t, map[string]any{"jacket": "leather", "trousers": "denim", "pants": map[string]any{"size": "large"}}, Get("clothing")) assert.Equal(t, map[string]any{"jacket": "leather", "trousers": "denim", "pants": map[string]any{"size": "large"}}, Get("clothing"))
assert.Equal(t, 35, Get("age")) assert.Equal(t, 35, Get("age"))
// Lower-case key is found.
assert.Equal(t, true, Get("hacker"))
// Upper-case key is found.
assert.Equal(t, true, v.Get("Hacker"))
}
func TestUnmarshaling_CaseSensitive(t *testing.T) {
// Test preserving the case of keys read from a configuration
v := NewWithOptions(CaseSensitiveKeys(true))
v.SetConfigType("yaml")
v.ReadConfig(bytes.NewBuffer(yamlExample))
assert.True(t, v.InConfig("name"))
assert.True(t, v.InConfig("clothing.Jacket"))
assert.False(t, v.InConfig("state"))
assert.False(t, v.InConfig("clothing.hat"))
assert.Equal(t, "steve", v.Get("name"))
assert.Equal(t, []any{"skateboarding", "snowboarding", "go"}, v.Get("hobbies"))
assert.Equal(t, map[string]any{"Jacket": "leather", "trousers": "denim", "Pants": map[string]any{"size": "large"}}, v.Get("clothing"))
assert.Equal(t, 35, v.Get("age"))
assert.Equal(t, true, v.Get("Hacker"))
// Lower-case key is not found.
assert.Equal(t, nil, v.Get("hacker"))
} }
func TestUnmarshalExact(t *testing.T) { func TestUnmarshalExact(t *testing.T) {
@ -730,6 +805,36 @@ func TestEnvSubConfig(t *testing.T) {
assert.Equal(t, "large", subWithPrefix.Get("size")) assert.Equal(t, "large", subWithPrefix.Get("size"))
} }
// This is an interesting case because the key passed to Viper.Get() is
// converted to upper-case to perform a lookup in the environment variables.
// So, this is a case for lookup where case-sensitivity is not strict.
func TestEnvSubConfig_CaseSensitive(t *testing.T) {
v := NewWithOptions(CaseSensitiveKeys(true))
v.SetConfigType("yaml")
r := strings.NewReader(string(yamlExample))
if err := v.unmarshalReader(r, v.config); err != nil {
panic(err)
}
v.AutomaticEnv()
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
t.Setenv("CLOTHING_PANTS_SIZE", "small")
subv := v.Sub("clothing").Sub("Pants")
assert.Equal(t, "small", subv.Get("size"))
assert.Equal(t, "small", v.Get("clothing.Pants.size"))
assert.Equal(t, "small", v.Get("clothing.pants.size"))
// again with EnvPrefix
v.SetEnvPrefix("foo") // will be uppercased automatically
subWithPrefix := v.Sub("clothing").Sub("Pants")
t.Setenv("FOO_CLOTHING_PANTS_SIZE", "large")
assert.Equal(t, "large", subWithPrefix.Get("size"))
assert.Equal(t, "large", v.Get("clothing.Pants.size"))
assert.Equal(t, "large", v.Get("clothing.pants.size"))
}
func TestAllKeys(t *testing.T) { func TestAllKeys(t *testing.T) {
initConfigs() initConfigs()
@ -841,6 +946,118 @@ func TestAllKeys(t *testing.T) {
assert.Equal(t, all, AllSettings()) assert.Equal(t, all, AllSettings())
} }
func TestAllKeys_CaseSensitive(t *testing.T) {
v := NewWithOptions(CaseSensitiveKeys(true))
initAllConfigs(v)
ks := []string{
"title",
"author.BIO",
"author.E-MAIL",
"author.GITHUB",
"author.NAME",
"newkey",
"owner.organization",
"owner.dob",
"owner.Bio",
"name",
"beard",
"ppu",
"batters.batter",
"hobbies",
"clothing.Jacket",
"clothing.trousers",
"DEFAULT.IMPORT_PATH",
"DEFAULT.NAME",
"DEFAULT.VERSION",
"clothing.Pants.size",
"age",
"Hacker",
"id",
"type",
"eyes",
"p_id",
"p_ppu",
"p_batters.batter.type",
"p_type",
"p_name",
"foos",
"TITLE_DOTENV",
"TYPE_DOTENV",
"NAME_DOTENV",
}
dob, _ := time.Parse(time.RFC3339, "1979-05-27T07:32:00Z")
all := map[string]any{
"owner": map[string]any{
"organization": "MongoDB",
"Bio": "MongoDB Chief Developer Advocate & Hacker at Large",
"dob": dob,
},
"title": "TOML Example",
"author": map[string]any{
"E-MAIL": "fake@localhost",
"GITHUB": "https://github.com/Unknown",
"NAME": "Unknown",
"BIO": "Gopher.\nCoding addict.\nGood man.\n",
},
"ppu": 0.55,
"eyes": "brown",
"clothing": map[string]any{
"trousers": "denim",
"Jacket": "leather",
"Pants": map[string]any{"size": "large"},
},
"DEFAULT": map[string]any{
"IMPORT_PATH": "gopkg.in/ini.v1",
"NAME": "ini",
"VERSION": "v1",
},
"id": "0001",
"batters": map[string]any{
"batter": []any{
map[string]any{"type": "Regular"},
map[string]any{"type": "Chocolate"},
map[string]any{"type": "Blueberry"},
map[string]any{"type": "Devil's Food"},
},
},
"Hacker": true,
"beard": true,
"hobbies": []any{
"skateboarding",
"snowboarding",
"go",
},
"age": 35,
"type": "donut",
"newkey": "remote",
"name": "Cake",
"p_id": "0001",
"p_ppu": "0.55",
"p_name": "Cake",
"p_batters": map[string]any{
"batter": map[string]any{"type": "Regular"},
},
"p_type": "donut",
"foos": []map[string]any{
{
"foo": []map[string]any{
{"key": 1},
{"key": 2},
{"key": 3},
{"key": 4},
},
},
},
"TITLE_DOTENV": "DotEnv Example",
"TYPE_DOTENV": "donut",
"NAME_DOTENV": "Cake",
}
assert.ElementsMatch(t, ks, v.AllKeys())
assert.Equal(t, all, v.AllSettings())
}
func TestAllKeysWithEnv(t *testing.T) { func TestAllKeysWithEnv(t *testing.T) {
v := New() v := New()
@ -1526,6 +1743,8 @@ func TestSub(t *testing.T) {
subv := v.Sub("clothing") subv := v.Sub("clothing")
assert.Equal(t, v.Get("clothing.pants.size"), subv.Get("pants.size")) assert.Equal(t, v.Get("clothing.pants.size"), subv.Get("pants.size"))
assert.Equal(t, v.Get("clothing.jacket"), subv.Get("jacket"))
assert.Equal(t, v.Get("clothing.jacket"), subv.Get("Jacket"))
subv = v.Sub("clothing.pants") subv = v.Sub("clothing.pants")
assert.Equal(t, v.Get("clothing.pants.size"), subv.Get("size")) assert.Equal(t, v.Get("clothing.pants.size"), subv.Get("size"))
@ -1543,6 +1762,35 @@ func TestSub(t *testing.T) {
assert.Equal(t, []string{"clothing", "pants"}, subv.parents) assert.Equal(t, []string{"clothing", "pants"}, subv.parents)
} }
func TestSub_CaseSensitive(t *testing.T) {
v := NewWithOptions(CaseSensitiveKeys(true))
v.SetConfigType("yaml")
v.ReadConfig(bytes.NewBuffer(yamlExample))
subv := v.Sub("clothing")
assert.Equal(t, v.Get("clothing.pants.size"), subv.Get("pants.size"))
assert.Equal(t, v.Get("clothing.Jacket"), subv.Get("Jacket"))
assert.Equal(t, nil, subv.Get("jacket"))
subv = v.Sub("clothing.pants")
assert.Nil(t, subv)
subv = v.Sub("clothing.Pants")
//TODO: test Get with case-sensitive keys that are not found: "clothing.pants.size"
assert.Equal(t, v.Get("clothing.Pants.size"), subv.Get("size"))
subv = v.Sub("clothing.pants.size")
assert.Equal(t, (*Viper)(nil), subv)
subv = v.Sub("missing.key")
assert.Equal(t, (*Viper)(nil), subv)
subv = v.Sub("clothing")
assert.Equal(t, []string{"clothing"}, subv.parents)
subv = v.Sub("clothing").Sub("Pants")
assert.Equal(t, []string{"clothing", "Pants"}, subv.parents)
}
var hclWriteExpected = []byte(`"foos" = { var hclWriteExpected = []byte(`"foos" = {
"foo" = { "foo" = {
"key" = 1 "key" = 1
@ -2046,6 +2294,44 @@ func TestMergeConfigMap(t *testing.T) {
assertFn(1234) assertFn(1234)
} }
func TestMergeConfigMap_CaseSensitive(t *testing.T) {
v := NewWithOptions(CaseSensitiveKeys(true))
v.SetConfigType("yml")
err := v.ReadConfig(bytes.NewBuffer(yamlMergeExampleTgt))
require.NoError(t, err)
assertFn := func(i int) {
large := v.GetInt64("hello.largenum")
pop := v.GetInt("hello.pop")
assert.Equal(t, int64(765432101234567), large)
assert.Equal(t, i, pop)
}
assertFn(37890)
update := map[string]any{
"hello": map[string]any{
"Pop": 98,
"pop": 76,
},
"Hello": map[string]any{
"Pop": 1234,
},
"World": map[any]any{
"Rock": 345,
},
}
err = v.MergeConfigMap(update)
require.NoError(t, err)
assert.Equal(t, 345, v.GetInt("World.Rock"))
assert.Equal(t, 1234, v.GetInt("Hello.Pop"))
assert.Equal(t, 98, v.GetInt("hello.Pop"))
assertFn(76)
}
func TestUnmarshalingWithAliases(t *testing.T) { func TestUnmarshalingWithAliases(t *testing.T) {
v := New() v := New()
v.SetDefault("ID", 1) v.SetDefault("ID", 1)
@ -2158,6 +2444,51 @@ R = 6
} }
} }
func TestCaseSensitive(t *testing.T) {
for _, config := range []struct {
typ string
content string
}{
{"yaml", `
aBcD: 1
eF:
gH: 2
iJk: 3
Lm:
nO: 4
P:
Q: 5
R: 6
`},
{"json", `{
"aBcD": 1,
"eF": {
"iJk": 3,
"Lm": {
"P": {
"Q": 5,
"R": 6
},
"nO": 4
},
"gH": 2
}
}`},
{"toml", `aBcD = 1
[eF]
gH = 2
iJk = 3
[eF.Lm]
nO = 4
[eF.Lm.P]
Q = 5
R = 6
`},
} {
doTestCaseSensitive(t, config.typ, config.content)
}
}
func TestCaseInsensitiveSet(t *testing.T) { func TestCaseInsensitiveSet(t *testing.T) {
Reset() Reset()
m1 := map[string]any{ m1 := map[string]any{
@ -2197,6 +2528,58 @@ func TestCaseInsensitiveSet(t *testing.T) {
assert.True(t, ok) assert.True(t, ok)
} }
func TestCaseSensitiveSet(t *testing.T) {
v := NewWithOptions(CaseSensitiveKeys(true))
m1 := map[string]any{
"Foo": 32,
"Bar": map[any]any{
"ABc": "A",
"cDE": "B",
},
}
m2 := map[string]any{
"Foo": 52,
"Bar": map[any]any{
"bCd": "A",
"eFG": "B",
},
}
v.Set("Given1", m1)
v.Set("Number1", 42)
v.SetDefault("Given2", m2)
v.SetDefault("Number2", 62)
// Verify SetDefault
assert.Equal(t, 62, v.Get("Number2"))
assert.Equal(t, nil, v.Get("number2"))
assert.Equal(t, 52, v.Get("Given2.Foo"))
assert.Equal(t, nil, v.Get("Given2.foo"))
assert.Equal(t, nil, v.Get("given2.Foo"))
assert.Equal(t, "A", v.Get("Given2.Bar.bCd"))
assert.Equal(t, nil, v.Get("Given2.Bar.bcd"))
assert.Equal(t, nil, v.Get("Given2.bar.bCd"))
assert.Equal(t, nil, v.Get("given2.Bar.bCd"))
_, ok := m2["Foo"]
assert.True(t, ok)
// Verify Set
assert.Equal(t, 42, v.Get("Number1"))
assert.Equal(t, nil, v.Get("number1"))
assert.Equal(t, 32, v.Get("Given1.Foo"))
assert.Equal(t, nil, v.Get("Given1.foo"))
assert.Equal(t, nil, v.Get("given1.Foo"))
assert.Equal(t, "A", v.Get("Given1.Bar.ABc"))
assert.Equal(t, nil, v.Get("Given1.Bar.abc"))
assert.Equal(t, nil, v.Get("Given1.bar.ABc"))
assert.Equal(t, nil, v.Get("given1.Bar.ABc"))
_, ok = m1["Foo"]
assert.True(t, ok)
}
func TestParseNested(t *testing.T) { func TestParseNested(t *testing.T) {
type duration struct { type duration struct {
Delay time.Duration Delay time.Duration
@ -2237,6 +2620,30 @@ func doTestCaseInsensitive(t *testing.T, typ, config string) {
assert.Equal(t, 5, cast.ToInt(Get("ef.lm.p.q"))) assert.Equal(t, 5, cast.ToInt(Get("ef.lm.p.q")))
} }
func doTestCaseSensitive(t *testing.T, typ, config string) {
v := NewWithOptions(CaseSensitiveKeys(true))
v.SetConfigType(typ)
r := strings.NewReader(config)
if err := v.unmarshalReader(r, v.config); err != nil {
panic(err)
}
v.Set("RfD", true)
assert.Equal(t, nil, v.Get("rfd"))
assert.Equal(t, true, v.Get("RfD"))
assert.Equal(t, 0, cast.ToInt(v.Get("abcd")))
assert.Equal(t, 1, cast.ToInt(v.Get("aBcD")))
assert.Equal(t, 0, cast.ToInt(v.Get("ef.gh")))
assert.Equal(t, 2, cast.ToInt(v.Get("eF.gH")))
assert.Equal(t, 0, cast.ToInt(v.Get("ef.ijk")))
assert.Equal(t, 3, cast.ToInt(v.Get("eF.iJk")))
assert.Equal(t, 0, cast.ToInt(v.Get("ef.lm.no")))
assert.Equal(t, 4, cast.ToInt(v.Get("eF.Lm.nO")))
assert.Equal(t, 0, cast.ToInt(v.Get("ef.lm.p.q")))
assert.Equal(t, 5, cast.ToInt(v.Get("eF.Lm.P.Q")))
}
func newViperWithConfigFile(t *testing.T) (*Viper, string) { func newViperWithConfigFile(t *testing.T) (*Viper, string) {
watchDir := t.TempDir() watchDir := t.TempDir()
configFile := path.Join(watchDir, "config.yaml") configFile := path.Join(watchDir, "config.yaml")

View file

@ -7,9 +7,9 @@ hobbies:
- snowboarding - snowboarding
- go - go
clothing: clothing:
jacket: leather Jacket: leather
trousers: denim trousers: denim
pants: Pants:
size: large size: large
age: 35 age: 35
eyes : brown eyes : brown