diff --git a/util.go b/util.go index 52116ac..c6b3bd7 100644 --- a/util.go +++ b/util.go @@ -37,39 +37,42 @@ func (pe ConfigParseError) Unwrap() error { return pe.err } -// toCaseInsensitiveValue checks if the value is a map; -// if so, create a copy and lower-case the keys recursively. -func toCaseInsensitiveValue(value any) any { +// CopyMap returns a deep copy of a map[any]any or map[string]any. If value is +// not one of those map types, then it is returned as-is. If preserveCase is +// false, then all keys will be converted to lower-case in the copy that is +// returned. +func CopyMap(value any, preserveCase bool) any { + var copyMap func(map[string]any, bool) map[string]any + copyMap = func(m map[string]any, preserveCase bool) map[string]any { + nm := make(map[string]any) + + for key, val := range m { + if !preserveCase { + key = strings.ToLower(key) + } + switch v := val.(type) { + case map[any]any: + nm[key] = copyMap(cast.ToStringMap(v), preserveCase) + case map[string]any: + nm[key] = copyMap(v, preserveCase) + default: + nm[key] = v + } + } + + return nm + } + switch v := value.(type) { case map[any]any: - value = copyAndInsensitiviseMap(cast.ToStringMap(v)) + value = copyMap(cast.ToStringMap(v), preserveCase) case map[string]any: - value = copyAndInsensitiviseMap(v) + value = copyMap(v, preserveCase) } return value } -// copyAndInsensitiviseMap behaves like insensitiviseMap, but creates a copy of -// any map it makes case insensitive. -func copyAndInsensitiviseMap(m map[string]any) map[string]any { - nm := make(map[string]any) - - for key, val := range m { - lkey := strings.ToLower(key) - switch v := val.(type) { - case map[any]any: - nm[lkey] = copyAndInsensitiviseMap(cast.ToStringMap(v)) - case map[string]any: - nm[lkey] = copyAndInsensitiviseMap(v) - default: - nm[lkey] = v - } - } - - return nm -} - func insensitiviseVal(val any) any { switch v := val.(type) { case map[any]any: diff --git a/util_test.go b/util_test.go index 8d0bda8..d922947 100644 --- a/util_test.go +++ b/util_test.go @@ -19,7 +19,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestCopyAndInsensitiviseMap(t *testing.T) { +func TestCopyMap(t *testing.T) { var ( given = map[string]any{ "Foo": 32, @@ -35,19 +35,54 @@ func TestCopyAndInsensitiviseMap(t *testing.T) { "cde": "B", }, } + expectedPreserveCase = map[string]any{ + "Foo": 32, + "Bar": map[string]any{ + "ABc": "A", + "cDE": "B", + }, + } ) - got := copyAndInsensitiviseMap(given) + t.Run("convert to lower-case", func(t *testing.T) { + got := CopyMap(given, false) - assert.Equal(t, expected, got) - _, ok := given["foo"] - assert.False(t, ok) - _, ok = given["bar"] - assert.False(t, ok) + assert.Equal(t, expected, got) + _, ok := given["foo"] + assert.False(t, ok) + _, ok = given["bar"] + assert.False(t, ok) - m := given["Bar"].(map[any]any) - _, ok = m["ABc"] - assert.True(t, ok) + m := given["Bar"].(map[any]any) + _, ok = m["ABc"] + assert.True(t, ok) + }) + + t.Run("preserve case", func(t *testing.T) { + got := CopyMap(given, true) + + assert.Equal(t, expectedPreserveCase, got) + _, ok := given["foo"] + assert.False(t, ok) + _, ok = given["bar"] + assert.False(t, ok) + + m := given["Bar"].(map[any]any) + _, ok = m["ABc"] + assert.True(t, ok) + }) + + t.Run("not a map", func(t *testing.T) { + var ( + given = []any{42, "xyz"} + expected = []any{42, "xyz"} + ) + got := CopyMap(given, false) + assert.Equal(t, expected, got) + + got = CopyMap(given, true) + assert.Equal(t, expected, got) + }) } func TestAbsPathify(t *testing.T) { diff --git a/viper.go b/viper.go index c1eab71..72bac3e 100644 --- a/viper.go +++ b/viper.go @@ -206,6 +206,10 @@ type Viper struct { envKeyReplacer StringReplacer allowEmptyEnv bool + // When caseSensitiveKeys is true, keys are preserved in their original + // case (i.e., not modified to lower-case). + caseSensitiveKeys bool + parents []string config map[string]any override map[string]any @@ -283,6 +287,20 @@ func EnvKeyReplacer(r StringReplacer) Option { }) } +// CaseSensitiveKeys sets Viper to use case-sensitive keys, which preserves the +// case of keys. By default, all keys are converted to lower-case. +// +// Lookup keys (i.e., Get()) will always match environment variables in a +// case-insensitive manner because Viper always converts the lookup key to +// upper-case when searching for an environment variable. The case of the +// actual keys will be preserved, however, as seen in the output of AllKeys(), +// AllSettings(), etc. +func CaseSensitiveKeys(enable bool) Option { + return optionFunc(func(v *Viper) { + v.caseSensitiveKeys = enable + }) +} + // NewWithOptions creates a new Viper instance. func NewWithOptions(opts ...Option) *Viper { v := New() @@ -706,7 +724,7 @@ func (v *Viper) searchIndexableWithPathPrefixes(source any, path []string) any { // search for path prefixes, starting from the longest one for i := len(path); i > 0; i-- { - prefixKey := strings.ToLower(strings.Join(path[0:i], v.keyDelim)) + prefixKey := v.toLower(strings.Join(path[0:i], v.keyDelim)) var val any switch sourceIndexable := source.(type) { @@ -897,8 +915,7 @@ func GetViper() *Viper { func Get(key string) any { return v.Get(key) } func (v *Viper) Get(key string) any { - lcaseKey := strings.ToLower(key) - val := v.find(lcaseKey, true) + val := v.find(v.toLower(key), true) if val == nil { return nil } @@ -906,7 +923,7 @@ func (v *Viper) Get(key string) any { if v.typeByDefValue { // TODO(bep) this branch isn't covered by a single test. valType := val - path := strings.Split(lcaseKey, v.keyDelim) + path := strings.Split(key, v.keyDelim) defVal := v.searchMap(v.defaults, path) if defVal != nil { valType = defVal @@ -950,14 +967,14 @@ func (v *Viper) Get(key string) any { func Sub(key string) *Viper { return v.Sub(key) } func (v *Viper) Sub(key string) *Viper { - subv := New() data := v.Get(key) if data == nil { return nil } if reflect.TypeOf(data).Kind() == reflect.Map { - subv.parents = append(v.parents, strings.ToLower(key)) + subv := NewWithOptions(CaseSensitiveKeys(v.caseSensitiveKeys)) + subv.parents = append(v.parents, v.toLower(key)) subv.automaticEnvApplied = v.automaticEnvApplied subv.envPrefix = v.envPrefix subv.envKeyReplacer = v.envKeyReplacer @@ -1196,7 +1213,7 @@ func (v *Viper) BindFlagValue(key string, flag FlagValue) error { if flag == nil { return fmt.Errorf("flag for %q is nil", key) } - v.pflags[strings.ToLower(key)] = flag + v.pflags[v.toLower(key)] = flag return nil } @@ -1213,7 +1230,7 @@ func (v *Viper) BindEnv(input ...string) error { return fmt.Errorf("missing key to bind to") } - key := strings.ToLower(input[0]) + key := v.toLower(input[0]) if len(input) == 1 { v.env[key] = append(v.env[key], v.mergeWithEnvPrefix(key)) @@ -1243,12 +1260,13 @@ func (v *Viper) MustBindEnv(input ...string) { // Lastly, if no value was found and flagDefault is true, and if the key // corresponds to a flag, the flag's default value is returned. // -// Note: this assumes a lower-cased key given. -func (v *Viper) find(lcaseKey string, flagDefault bool) any { +// Note: By default, this assumes that a lowercase key is given. +// This behavior can be modified with viper.SetPreserveCase(). +func (v *Viper) find(key string, flagDefault bool) any { var ( val any exists bool - path = strings.Split(lcaseKey, v.keyDelim) + path = strings.Split(key, v.keyDelim) nested = len(path) > 1 ) @@ -1258,8 +1276,8 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) any { } // if the requested key is an alias, then return the proper key - lcaseKey = v.realKey(lcaseKey) - path = strings.Split(lcaseKey, v.keyDelim) + key = v.realKey(key) + path = strings.Split(key, v.keyDelim) nested = len(path) > 1 // Set() override first @@ -1272,7 +1290,7 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) any { } // PFlag override next - flag, exists := v.pflags[lcaseKey] + flag, exists := v.pflags[key] if exists && flag.HasChanged() { switch flag.ValueType() { case "int", "int8", "int16", "int32", "int64": @@ -1308,7 +1326,7 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) any { // Env override next if v.automaticEnvApplied { - envKey := strings.Join(append(v.parents, lcaseKey), ".") + envKey := strings.Join(append(v.parents, key), ".") // even if it hasn't been registered, if automaticEnv is used, // check any Get request if val, ok := v.getEnv(v.mergeWithEnvPrefix(envKey)); ok { @@ -1318,7 +1336,7 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) any { return nil } } - envkeys, exists := v.env[lcaseKey] + envkeys, exists := v.env[key] if exists { for _, envkey := range envkeys { if val, ok := v.getEnv(envkey); ok { @@ -1360,7 +1378,7 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) any { if flagDefault { // last chance: if no value is found and a flag does exist for the key, // get the flag's default value even if the flag's value has not been set. - if flag, exists := v.pflags[lcaseKey]; exists { + if flag, exists := v.pflags[key]; exists { switch flag.ValueType() { case "int", "int8", "int16", "int32", "int64": return cast.ToInt(flag.ValueString()) @@ -1457,8 +1475,7 @@ func stringToIntConv(val string) any { func IsSet(key string) bool { return v.IsSet(key) } func (v *Viper) IsSet(key string) bool { - lcaseKey := strings.ToLower(key) - val := v.find(lcaseKey, false) + val := v.find(v.toLower(key), false) return val != nil } @@ -1484,11 +1501,11 @@ func (v *Viper) SetEnvKeyReplacer(r *strings.Replacer) { func RegisterAlias(alias string, key string) { v.RegisterAlias(alias, key) } func (v *Viper) RegisterAlias(alias string, key string) { - v.registerAlias(alias, strings.ToLower(key)) + v.registerAlias(alias, v.toLower(key)) } func (v *Viper) registerAlias(alias string, key string) { - alias = strings.ToLower(alias) + alias = v.toLower(alias) if alias != key && alias != v.realKey(key) { _, exists := v.aliases[alias] @@ -1533,11 +1550,9 @@ func (v *Viper) realKey(key string) string { func InConfig(key string) bool { return v.InConfig(key) } func (v *Viper) InConfig(key string) bool { - lcaseKey := strings.ToLower(key) - // if the requested key is an alias, then return the proper key - lcaseKey = v.realKey(lcaseKey) - path := strings.Split(lcaseKey, v.keyDelim) + key = v.realKey(v.toLower(key)) + path := strings.Split(key, v.keyDelim) return v.searchIndexableWithPathPrefixes(v.config, path) != nil } @@ -1549,11 +1564,11 @@ func SetDefault(key string, value any) { v.SetDefault(key, value) } func (v *Viper) SetDefault(key string, value any) { // If alias passed in, then set the proper default - key = v.realKey(strings.ToLower(key)) - value = toCaseInsensitiveValue(value) + key = v.realKey(v.toLower(key)) + value = CopyMap(value, v.caseSensitiveKeys) path := strings.Split(key, v.keyDelim) - lastKey := strings.ToLower(path[len(path)-1]) + lastKey := v.toLower(path[len(path)-1]) deepestMap := deepSearch(v.defaults, path[0:len(path)-1]) // set innermost value @@ -1567,12 +1582,12 @@ func (v *Viper) SetDefault(key string, value any) { func Set(key string, value any) { v.Set(key, value) } func (v *Viper) Set(key string, value any) { - // If alias passed in, then set the proper override - key = v.realKey(strings.ToLower(key)) - value = toCaseInsensitiveValue(value) + // If alias passed in, then set the proper default + key = v.realKey(v.toLower(key)) + value = CopyMap(value, v.caseSensitiveKeys) path := strings.Split(key, v.keyDelim) - lastKey := strings.ToLower(path[len(path)-1]) + lastKey := v.toLower(path[len(path)-1]) deepestMap := deepSearch(v.override, path[0:len(path)-1]) // set innermost value @@ -1661,8 +1676,10 @@ func (v *Viper) MergeConfigMap(cfg map[string]any) error { if v.config == nil { v.config = make(map[string]any) } - insensitiviseMap(cfg) - mergeMaps(cfg, v.config, nil) + if !v.caseSensitiveKeys { + insensitiviseMap(cfg) + } + v.mergeMaps(cfg, v.config, nil) return nil } @@ -1761,7 +1778,9 @@ func (v *Viper) unmarshalReader(in io.Reader, c map[string]any) error { } } - insensitiviseMap(c) + if !v.caseSensitiveKeys { + insensitiviseMap(c) + } return nil } @@ -1783,10 +1802,10 @@ func (v *Viper) marshalWriter(f afero.File, configType string) error { return nil } -func keyExists(k string, m map[string]any) string { - lk := strings.ToLower(k) +func (v *Viper) keyExists(k string, m map[string]any) string { + lk := v.toLower(k) for mk := range m { - lmk := strings.ToLower(mk) + lmk := v.toLower(mk) if lmk == lk { return mk } @@ -1833,9 +1852,9 @@ func castMapFlagToMapInterface(src map[string]FlagValue) map[string]any { // instead of using a `string` as the key for nest structures beyond one level // deep. Both map types are supported as there is a go-yaml fork that uses // `map[string]any` instead. -func mergeMaps(src, tgt map[string]any, itgt map[any]any) { +func (v *Viper) mergeMaps(src, tgt map[string]any, itgt map[any]any) { for sk, sv := range src { - tk := keyExists(sk, tgt) + tk := v.keyExists(sk, tgt) if tk == "" { v.logger.Debug("", "tk", "\"\"", fmt.Sprintf("tgt[%s]", sk), sv) tgt[sk] = sv @@ -1885,7 +1904,7 @@ func mergeMaps(src, tgt map[string]any, itgt map[any]any) { ssv := castToMapStringInterface(tsv) stv := castToMapStringInterface(ttv) - mergeMaps(ssv, stv, ttv) + v.mergeMaps(ssv, stv, ttv) case map[string]any: v.logger.Debug("merging maps") tsv, ok := sv.(map[string]any) @@ -1900,7 +1919,7 @@ func mergeMaps(src, tgt map[string]any, itgt map[any]any) { ) continue } - mergeMaps(tsv, ttv, nil) + v.mergeMaps(tsv, ttv, nil) default: v.logger.Debug("setting value") tgt[tk] = sv @@ -2063,7 +2082,7 @@ func (v *Viper) flattenAndMergeMap(shadow map[string]bool, m map[string]any, pre m2 = cast.ToStringMap(val) default: // immediate value - shadow[strings.ToLower(fullKey)] = true + shadow[v.toLower(fullKey)] = true continue } // recursively merge to shadow map @@ -2089,7 +2108,7 @@ outer: } } // add key - shadow[strings.ToLower(k)] = true + shadow[v.toLower(k)] = true } return shadow } @@ -2108,7 +2127,7 @@ func (v *Viper) AllSettings() map[string]any { continue } path := strings.Split(k, v.keyDelim) - lastKey := strings.ToLower(path[len(path)-1]) + lastKey := v.toLower(path[len(path)-1]) deepestMap := deepSearch(m, path[0:len(path)-1]) // set innermost value deepestMap[lastKey] = value @@ -2204,3 +2223,13 @@ func (v *Viper) DebugTo(w io.Writer) { fmt.Fprintf(w, "Config:\n%#v\n", v.config) fmt.Fprintf(w, "Defaults:\n%#v\n", v.defaults) } + +// toLower returns a properly cased key based on the CaseSensitiveKeys option. +// If preserveCase is true, then the unmodifed key is returned. Otherwise, the +// lower-cased key is returned. +func (v *Viper) toLower(k string) string { + if v.caseSensitiveKeys { + return k + } + return strings.ToLower(k) +} diff --git a/viper_test.go b/viper_test.go index 6b1b316..fdc6703 100644 --- a/viper_test.go +++ b/viper_test.go @@ -171,6 +171,41 @@ func initConfigs() { unmarshalReader(r, v.config) } +func initAllConfigs(v *Viper) { + var r io.Reader + v.SetConfigType("yaml") + r = bytes.NewReader(yamlExample) + v.unmarshalReader(r, v.config) + + v.SetConfigType("json") + r = bytes.NewReader(jsonExample) + v.unmarshalReader(r, v.config) + + v.SetConfigType("hcl") + r = bytes.NewReader(hclExample) + v.unmarshalReader(r, v.config) + + v.SetConfigType("properties") + r = bytes.NewReader(propertiesExample) + v.unmarshalReader(r, v.config) + + v.SetConfigType("toml") + r = bytes.NewReader(tomlExample) + v.unmarshalReader(r, v.config) + + v.SetConfigType("env") + r = bytes.NewReader(dotenvExample) + v.unmarshalReader(r, v.config) + + v.SetConfigType("json") + remote := bytes.NewReader(remoteExample) + v.unmarshalReader(remote, v.kvstore) + + v.SetConfigType("ini") + r = bytes.NewReader(iniExample) + v.unmarshalReader(r, v.config) +} + func initConfig(typ, config string) { Reset() SetConfigType(typ) @@ -482,6 +517,23 @@ func TestDefault(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "leather", Get("clothing.jacket")) + assert.Equal(t, "leather", Get("clothing.Jacket")) +} + +func TestDefault_CaseSensitive(t *testing.T) { + v := NewWithOptions(CaseSensitiveKeys(true)) + v.SetDefault("age", 45) + assert.Equal(t, 45, v.Get("age")) + + v.SetDefault("clothing.jacket", "slacks") + assert.Equal(t, "slacks", v.Get("clothing.jacket")) + + v.SetConfigType("yaml") + err := v.ReadConfig(bytes.NewBuffer(yamlExample)) + + assert.NoError(t, err) + assert.Equal(t, "slacks", v.Get("clothing.jacket")) + assert.Equal(t, "leather", v.Get("clothing.Jacket")) } func TestUnmarshaling(t *testing.T) { @@ -497,6 +549,29 @@ func TestUnmarshaling(t *testing.T) { assert.Equal(t, []any{"skateboarding", "snowboarding", "go"}, Get("hobbies")) assert.Equal(t, map[string]any{"jacket": "leather", "trousers": "denim", "pants": map[string]any{"size": "large"}}, Get("clothing")) assert.Equal(t, 35, Get("age")) + // Lower-case key is found. + assert.Equal(t, true, Get("hacker")) + // Upper-case key is found. + assert.Equal(t, true, v.Get("Hacker")) +} + +func TestUnmarshaling_CaseSensitive(t *testing.T) { + // Test preserving the case of keys read from a configuration + v := NewWithOptions(CaseSensitiveKeys(true)) + v.SetConfigType("yaml") + v.ReadConfig(bytes.NewBuffer(yamlExample)) + + assert.True(t, v.InConfig("name")) + assert.True(t, v.InConfig("clothing.Jacket")) + assert.False(t, v.InConfig("state")) + assert.False(t, v.InConfig("clothing.hat")) + assert.Equal(t, "steve", v.Get("name")) + assert.Equal(t, []any{"skateboarding", "snowboarding", "go"}, v.Get("hobbies")) + assert.Equal(t, map[string]any{"Jacket": "leather", "trousers": "denim", "Pants": map[string]any{"size": "large"}}, v.Get("clothing")) + assert.Equal(t, 35, v.Get("age")) + assert.Equal(t, true, v.Get("Hacker")) + // Lower-case key is not found. + assert.Equal(t, nil, v.Get("hacker")) } func TestUnmarshalExact(t *testing.T) { @@ -730,6 +805,36 @@ func TestEnvSubConfig(t *testing.T) { assert.Equal(t, "large", subWithPrefix.Get("size")) } +// This is an interesting case because the key passed to Viper.Get() is +// converted to upper-case to perform a lookup in the environment variables. +// So, this is a case for lookup where case-sensitivity is not strict. +func TestEnvSubConfig_CaseSensitive(t *testing.T) { + v := NewWithOptions(CaseSensitiveKeys(true)) + v.SetConfigType("yaml") + r := strings.NewReader(string(yamlExample)) + if err := v.unmarshalReader(r, v.config); err != nil { + panic(err) + } + + v.AutomaticEnv() + + v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + + t.Setenv("CLOTHING_PANTS_SIZE", "small") + subv := v.Sub("clothing").Sub("Pants") + assert.Equal(t, "small", subv.Get("size")) + assert.Equal(t, "small", v.Get("clothing.Pants.size")) + assert.Equal(t, "small", v.Get("clothing.pants.size")) + + // again with EnvPrefix + v.SetEnvPrefix("foo") // will be uppercased automatically + subWithPrefix := v.Sub("clothing").Sub("Pants") + t.Setenv("FOO_CLOTHING_PANTS_SIZE", "large") + assert.Equal(t, "large", subWithPrefix.Get("size")) + assert.Equal(t, "large", v.Get("clothing.Pants.size")) + assert.Equal(t, "large", v.Get("clothing.pants.size")) +} + func TestAllKeys(t *testing.T) { initConfigs() @@ -841,6 +946,118 @@ func TestAllKeys(t *testing.T) { assert.Equal(t, all, AllSettings()) } +func TestAllKeys_CaseSensitive(t *testing.T) { + v := NewWithOptions(CaseSensitiveKeys(true)) + initAllConfigs(v) + + ks := []string{ + "title", + "author.BIO", + "author.E-MAIL", + "author.GITHUB", + "author.NAME", + "newkey", + "owner.organization", + "owner.dob", + "owner.Bio", + "name", + "beard", + "ppu", + "batters.batter", + "hobbies", + "clothing.Jacket", + "clothing.trousers", + "DEFAULT.IMPORT_PATH", + "DEFAULT.NAME", + "DEFAULT.VERSION", + "clothing.Pants.size", + "age", + "Hacker", + "id", + "type", + "eyes", + "p_id", + "p_ppu", + "p_batters.batter.type", + "p_type", + "p_name", + "foos", + "TITLE_DOTENV", + "TYPE_DOTENV", + "NAME_DOTENV", + } + dob, _ := time.Parse(time.RFC3339, "1979-05-27T07:32:00Z") + all := map[string]any{ + "owner": map[string]any{ + "organization": "MongoDB", + "Bio": "MongoDB Chief Developer Advocate & Hacker at Large", + "dob": dob, + }, + "title": "TOML Example", + "author": map[string]any{ + "E-MAIL": "fake@localhost", + "GITHUB": "https://github.com/Unknown", + "NAME": "Unknown", + "BIO": "Gopher.\nCoding addict.\nGood man.\n", + }, + "ppu": 0.55, + "eyes": "brown", + "clothing": map[string]any{ + "trousers": "denim", + "Jacket": "leather", + "Pants": map[string]any{"size": "large"}, + }, + "DEFAULT": map[string]any{ + "IMPORT_PATH": "gopkg.in/ini.v1", + "NAME": "ini", + "VERSION": "v1", + }, + "id": "0001", + "batters": map[string]any{ + "batter": []any{ + map[string]any{"type": "Regular"}, + map[string]any{"type": "Chocolate"}, + map[string]any{"type": "Blueberry"}, + map[string]any{"type": "Devil's Food"}, + }, + }, + "Hacker": true, + "beard": true, + "hobbies": []any{ + "skateboarding", + "snowboarding", + "go", + }, + "age": 35, + "type": "donut", + "newkey": "remote", + "name": "Cake", + "p_id": "0001", + "p_ppu": "0.55", + "p_name": "Cake", + "p_batters": map[string]any{ + "batter": map[string]any{"type": "Regular"}, + }, + "p_type": "donut", + "foos": []map[string]any{ + { + "foo": []map[string]any{ + {"key": 1}, + {"key": 2}, + {"key": 3}, + {"key": 4}, + }, + }, + }, + "TITLE_DOTENV": "DotEnv Example", + "TYPE_DOTENV": "donut", + "NAME_DOTENV": "Cake", + } + + assert.ElementsMatch(t, ks, v.AllKeys()) + assert.Equal(t, all, v.AllSettings()) +} + func TestAllKeysWithEnv(t *testing.T) { v := New() @@ -1526,6 +1743,8 @@ func TestSub(t *testing.T) { subv := v.Sub("clothing") assert.Equal(t, v.Get("clothing.pants.size"), subv.Get("pants.size")) + assert.Equal(t, v.Get("clothing.jacket"), subv.Get("jacket")) + assert.Equal(t, v.Get("clothing.jacket"), subv.Get("Jacket")) subv = v.Sub("clothing.pants") assert.Equal(t, v.Get("clothing.pants.size"), subv.Get("size")) @@ -1543,6 +1762,35 @@ func TestSub(t *testing.T) { assert.Equal(t, []string{"clothing", "pants"}, subv.parents) } +func TestSub_CaseSensitive(t *testing.T) { + v := NewWithOptions(CaseSensitiveKeys(true)) + v.SetConfigType("yaml") + v.ReadConfig(bytes.NewBuffer(yamlExample)) + + subv := v.Sub("clothing") + assert.Equal(t, v.Get("clothing.pants.size"), subv.Get("pants.size")) + assert.Equal(t, v.Get("clothing.Jacket"), subv.Get("Jacket")) + assert.Equal(t, nil, subv.Get("jacket")) + + subv = v.Sub("clothing.pants") + assert.Nil(t, subv) + subv = v.Sub("clothing.Pants") + //TODO: test Get with case-sensitive keys that are not found: "clothing.pants.size" + assert.Equal(t, v.Get("clothing.Pants.size"), subv.Get("size")) + + subv = v.Sub("clothing.pants.size") + assert.Equal(t, (*Viper)(nil), subv) + + subv = v.Sub("missing.key") + assert.Equal(t, (*Viper)(nil), subv) + + subv = v.Sub("clothing") + assert.Equal(t, []string{"clothing"}, subv.parents) + + subv = v.Sub("clothing").Sub("Pants") + assert.Equal(t, []string{"clothing", "Pants"}, subv.parents) +} + var hclWriteExpected = []byte(`"foos" = { "foo" = { "key" = 1 @@ -2046,6 +2294,44 @@ func TestMergeConfigMap(t *testing.T) { assertFn(1234) } +func TestMergeConfigMap_CaseSensitive(t *testing.T) { + v := NewWithOptions(CaseSensitiveKeys(true)) + v.SetConfigType("yml") + err := v.ReadConfig(bytes.NewBuffer(yamlMergeExampleTgt)) + require.NoError(t, err) + + assertFn := func(i int) { + large := v.GetInt64("hello.largenum") + pop := v.GetInt("hello.pop") + assert.Equal(t, int64(765432101234567), large) + assert.Equal(t, i, pop) + } + + assertFn(37890) + + update := map[string]any{ + "hello": map[string]any{ + "Pop": 98, + "pop": 76, + }, + "Hello": map[string]any{ + "Pop": 1234, + }, + "World": map[any]any{ + "Rock": 345, + }, + } + + err = v.MergeConfigMap(update) + require.NoError(t, err) + + assert.Equal(t, 345, v.GetInt("World.Rock")) + assert.Equal(t, 1234, v.GetInt("Hello.Pop")) + + assert.Equal(t, 98, v.GetInt("hello.Pop")) + assertFn(76) +} + func TestUnmarshalingWithAliases(t *testing.T) { v := New() v.SetDefault("ID", 1) @@ -2158,6 +2444,51 @@ R = 6 } } +func TestCaseSensitive(t *testing.T) { + for _, config := range []struct { + typ string + content string + }{ + {"yaml", ` +aBcD: 1 +eF: + gH: 2 + iJk: 3 + Lm: + nO: 4 + P: + Q: 5 + R: 6 +`}, + {"json", `{ + "aBcD": 1, + "eF": { + "iJk": 3, + "Lm": { + "P": { + "Q": 5, + "R": 6 + }, + "nO": 4 + }, + "gH": 2 + } +}`}, + {"toml", `aBcD = 1 +[eF] +gH = 2 +iJk = 3 +[eF.Lm] +nO = 4 +[eF.Lm.P] +Q = 5 +R = 6 +`}, + } { + doTestCaseSensitive(t, config.typ, config.content) + } +} + func TestCaseInsensitiveSet(t *testing.T) { Reset() m1 := map[string]any{ @@ -2197,6 +2528,58 @@ func TestCaseInsensitiveSet(t *testing.T) { assert.True(t, ok) } +func TestCaseSensitiveSet(t *testing.T) { + v := NewWithOptions(CaseSensitiveKeys(true)) + + m1 := map[string]any{ + "Foo": 32, + "Bar": map[any]any{ + "ABc": "A", + "cDE": "B", + }, + } + + m2 := map[string]any{ + "Foo": 52, + "Bar": map[any]any{ + "bCd": "A", + "eFG": "B", + }, + } + + v.Set("Given1", m1) + v.Set("Number1", 42) + + v.SetDefault("Given2", m2) + v.SetDefault("Number2", 62) + + // Verify SetDefault + assert.Equal(t, 62, v.Get("Number2")) + assert.Equal(t, nil, v.Get("number2")) + assert.Equal(t, 52, v.Get("Given2.Foo")) + assert.Equal(t, nil, v.Get("Given2.foo")) + assert.Equal(t, nil, v.Get("given2.Foo")) + assert.Equal(t, "A", v.Get("Given2.Bar.bCd")) + assert.Equal(t, nil, v.Get("Given2.Bar.bcd")) + assert.Equal(t, nil, v.Get("Given2.bar.bCd")) + assert.Equal(t, nil, v.Get("given2.Bar.bCd")) + _, ok := m2["Foo"] + assert.True(t, ok) + + // Verify Set + assert.Equal(t, 42, v.Get("Number1")) + assert.Equal(t, nil, v.Get("number1")) + assert.Equal(t, 32, v.Get("Given1.Foo")) + assert.Equal(t, nil, v.Get("Given1.foo")) + assert.Equal(t, nil, v.Get("given1.Foo")) + assert.Equal(t, "A", v.Get("Given1.Bar.ABc")) + assert.Equal(t, nil, v.Get("Given1.Bar.abc")) + assert.Equal(t, nil, v.Get("Given1.bar.ABc")) + assert.Equal(t, nil, v.Get("given1.Bar.ABc")) + _, ok = m1["Foo"] + assert.True(t, ok) +} + func TestParseNested(t *testing.T) { type duration struct { Delay time.Duration @@ -2237,6 +2620,30 @@ func doTestCaseInsensitive(t *testing.T, typ, config string) { assert.Equal(t, 5, cast.ToInt(Get("ef.lm.p.q"))) } +func doTestCaseSensitive(t *testing.T, typ, config string) { + v := NewWithOptions(CaseSensitiveKeys(true)) + v.SetConfigType(typ) + + r := strings.NewReader(config) + if err := v.unmarshalReader(r, v.config); err != nil { + panic(err) + } + + v.Set("RfD", true) + assert.Equal(t, nil, v.Get("rfd")) + assert.Equal(t, true, v.Get("RfD")) + assert.Equal(t, 0, cast.ToInt(v.Get("abcd"))) + assert.Equal(t, 1, cast.ToInt(v.Get("aBcD"))) + assert.Equal(t, 0, cast.ToInt(v.Get("ef.gh"))) + assert.Equal(t, 2, cast.ToInt(v.Get("eF.gH"))) + assert.Equal(t, 0, cast.ToInt(v.Get("ef.ijk"))) + assert.Equal(t, 3, cast.ToInt(v.Get("eF.iJk"))) + assert.Equal(t, 0, cast.ToInt(v.Get("ef.lm.no"))) + assert.Equal(t, 4, cast.ToInt(v.Get("eF.Lm.nO"))) + assert.Equal(t, 0, cast.ToInt(v.Get("ef.lm.p.q"))) + assert.Equal(t, 5, cast.ToInt(v.Get("eF.Lm.P.Q"))) +} + func newViperWithConfigFile(t *testing.T) (*Viper, string) { watchDir := t.TempDir() configFile := path.Join(watchDir, "config.yaml") diff --git a/viper_yaml_test.go b/viper_yaml_test.go index 264446b..1405ce8 100644 --- a/viper_yaml_test.go +++ b/viper_yaml_test.go @@ -7,9 +7,9 @@ hobbies: - snowboarding - go clothing: - jacket: leather + Jacket: leather trousers: denim - pants: + Pants: size: large age: 35 eyes : brown