diff --git a/util.go b/util.go index 64e6575..f5272b0 100644 --- a/util.go +++ b/util.go @@ -31,31 +31,31 @@ func (pe ConfigParseError) Error() string { return fmt.Sprintf("While parsing config: %s", pe.err.Error()) } -// toCaseInsensitiveValue checks if the value is a map; -// if so, create a copy and lower-case the keys recursively. -func toCaseInsensitiveValue(value interface{}) interface{} { +// toCaseInsensitiveValue checks if the value is a map; +// if so, create a copy and lower-case (normalize) the keys recursively. +func toCaseInsensitiveValue(value interface{}, normalize keyNormalizeHookType) interface{} { switch v := value.(type) { case map[interface{}]interface{}: - value = copyAndInsensitiviseMap(cast.ToStringMap(v)) + value = copyAndInsensitiviseMap(cast.ToStringMap(v), normalize) case map[string]interface{}: - value = copyAndInsensitiviseMap(v) + value = copyAndInsensitiviseMap(v, normalize) } return value } // copyAndInsensitiviseMap behaves like insensitiviseMap, but creates a copy of -// any map it makes case insensitive. -func copyAndInsensitiviseMap(m map[string]interface{}) map[string]interface{} { +// any map it makes case insensitive (normalized). +func copyAndInsensitiviseMap(m map[string]interface{}, normalize keyNormalizeHookType) map[string]interface{} { nm := make(map[string]interface{}) for key, val := range m { - lkey := strings.ToLower(key) + lkey := normalize(key) switch v := val.(type) { case map[interface{}]interface{}: - nm[lkey] = copyAndInsensitiviseMap(cast.ToStringMap(v)) + nm[lkey] = copyAndInsensitiviseMap(cast.ToStringMap(v), normalize) case map[string]interface{}: - nm[lkey] = copyAndInsensitiviseMap(v) + nm[lkey] = copyAndInsensitiviseMap(v, normalize) default: nm[lkey] = v } @@ -64,26 +64,26 @@ func copyAndInsensitiviseMap(m map[string]interface{}) map[string]interface{} { return nm } -func insensitiviseVal(val interface{}) interface{} { - switch val.(type) { +func insensitiviseVal(val interface{}, normalize keyNormalizeHookType) interface{} { + switch valT := val.(type) { case map[interface{}]interface{}: // nested map: cast and recursively insensitivise val = cast.ToStringMap(val) - insensitiviseMap(val.(map[string]interface{})) + insensitiviseMap(val.(map[string]interface{}), normalize) case map[string]interface{}: // nested map: recursively insensitivise - insensitiviseMap(val.(map[string]interface{})) + insensitiviseMap(valT, normalize) case []interface{}: // nested array: recursively insensitivise - insensitiveArray(val.([]interface{})) + insensitiveArray(valT, normalize) } return val } -func insensitiviseMap(m map[string]interface{}) { +func insensitiviseMap(m map[string]interface{}, normalize keyNormalizeHookType) { for key, val := range m { - val = insensitiviseVal(val) - lower := strings.ToLower(key) + val = insensitiviseVal(val, normalize) + lower := normalize(key) if key != lower { // remove old key (not lower-cased) delete(m, key) @@ -93,9 +93,9 @@ func insensitiviseMap(m map[string]interface{}) { } } -func insensitiveArray(a []interface{}) { +func insensitiveArray(a []interface{}, normalize keyNormalizeHookType) { for i, val := range a { - a[i] = insensitiviseVal(val) + a[i] = insensitiviseVal(val, normalize) } } diff --git a/util_test.go b/util_test.go index cb4e620..889bdf7 100644 --- a/util_test.go +++ b/util_test.go @@ -14,6 +14,7 @@ import ( "os" "path/filepath" "reflect" + "strings" "testing" "github.com/spf13/viper/internal/testutil" @@ -37,7 +38,7 @@ func TestCopyAndInsensitiviseMap(t *testing.T) { } ) - got := copyAndInsensitiviseMap(given) + got := copyAndInsensitiviseMap(given, strings.ToLower) if !reflect.DeepEqual(got, expected) { t.Fatalf("Got %q\nexpected\n%q", got, expected) diff --git a/viper.go b/viper.go index fa6f3e3..2f20ce1 100644 --- a/viper.go +++ b/viper.go @@ -142,6 +142,10 @@ func DecodeHook(hook mapstructure.DecodeHookFunc) DecoderConfigOption { } } +type keyNormalizeHookType func(string) string + +var defaultKeyNormalizer = strings.ToLower + // Viper is a prioritized configuration registry. It // maintains a set of configuration sources, fetches // values to populate those, and provides them according @@ -183,6 +187,10 @@ type Viper struct { // used to access a nested value in one go keyDelim string + // Function to normalize keys + // by default, strings.ToLower + keyNormalizeHook keyNormalizeHookType + // A set of paths to look for the config file in configPaths []string @@ -229,6 +237,7 @@ type Viper struct { func New() *Viper { v := new(Viper) v.keyDelim = "." + v.keyNormalizeHook = defaultKeyNormalizer v.configName = "config" v.configPermissions = os.FileMode(0o644) v.fs = afero.NewOsFs() @@ -270,6 +279,23 @@ func KeyDelimiter(d string) Option { }) } +// KeyNormalizer is option to set arbitrary function for key normalization +// This function will be applied to all keys after unmarshal, during merge, search for duplicates, etc +// Default normalizer is strings.ToLower +func KeyNormalizer(n keyNormalizeHookType) Option { + return optionFunc(func(v *Viper) { + v.keyNormalizeHook = n + }) +} + +// KeyPreserveCase is option to disable key lowercasing +// By default, Viper converts all keys to lovercase +func KeyPreserveCase() Option { + return optionFunc(func(v *Viper) { + v.keyNormalizeHook = func(key string) string { return key } + }) +} + // StringReplacer applies a set of replacements to a string. type StringReplacer interface { // Replace returns a copy of s with all replacements performed. @@ -523,6 +549,13 @@ func (v *Viper) SetEnvPrefix(in string) { } } +func (v *Viper) keyNormalize(k string) string { + if v.keyNormalizeHook != nil { + return v.keyNormalizeHook(k) + } + return defaultKeyNormalizer(k) +} + func (v *Viper) mergeWithEnvPrefix(in string) string { if v.envPrefix != "" { return strings.ToUpper(v.envPrefix + "_" + in) @@ -652,7 +685,7 @@ func (v *Viper) providerPathExists(p *defaultRemoteProvider) bool { // searchMap recursively searches for a value for path in source map. // Returns nil if not found. -// Note: This assumes that the path entries and map keys are lower cased. +// Note: This assumes that the path entries and map keys are normalized (by default, lowercased). func (v *Viper) searchMap(source map[string]interface{}, path []string) interface{} { if len(path) == 0 { return source @@ -691,7 +724,7 @@ func (v *Viper) searchMap(source map[string]interface{}, path []string) interfac // This should be useful only at config level (other maps may not contain dots // in their keys). // -// Note: This assumes that the path entries and map keys are lower cased. +// Note: This assumes that the path entries and map keys are lower cased (normalized). func (v *Viper) searchIndexableWithPathPrefixes(source interface{}, path []string) interface{} { if len(path) == 0 { return source @@ -699,7 +732,7 @@ func (v *Viper) searchIndexableWithPathPrefixes(source interface{}, path []strin // search for path prefixes, starting from the longest one for i := len(path); i > 0; i-- { - prefixKey := strings.ToLower(strings.Join(path[0:i], v.keyDelim)) + prefixKey := v.keyNormalize(strings.Join(path[0:i], v.keyDelim)) var val interface{} switch sourceIndexable := source.(type) { @@ -890,7 +923,7 @@ func GetViper() *Viper { func Get(key string) interface{} { return v.Get(key) } func (v *Viper) Get(key string) interface{} { - lcaseKey := strings.ToLower(key) + lcaseKey := v.keyNormalize(key) val := v.find(lcaseKey, true) if val == nil { return nil @@ -950,7 +983,7 @@ func (v *Viper) Sub(key string) *Viper { } if reflect.TypeOf(data).Kind() == reflect.Map { - subv.parents = append(v.parents, strings.ToLower(key)) + subv.parents = append(v.parents, v.keyNormalize(key)) subv.automaticEnvApplied = v.automaticEnvApplied subv.envPrefix = v.envPrefix subv.envKeyReplacer = v.envKeyReplacer @@ -1189,7 +1222,7 @@ func (v *Viper) BindFlagValue(key string, flag FlagValue) error { if flag == nil { return fmt.Errorf("flag for %q is nil", key) } - v.pflags[strings.ToLower(key)] = flag + v.pflags[v.keyNormalize(key)] = flag return nil } @@ -1206,7 +1239,7 @@ func (v *Viper) BindEnv(input ...string) error { return fmt.Errorf("missing key to bind to") } - key := strings.ToLower(input[0]) + key := v.keyNormalize(input[0]) if len(input) == 1 { v.env[key] = append(v.env[key], v.mergeWithEnvPrefix(key)) @@ -1236,7 +1269,7 @@ func (v *Viper) MustBindEnv(input ...string) { // Lastly, if no value was found and flagDefault is true, and if the key // corresponds to a flag, the flag's default value is returned. // -// Note: this assumes a lower-cased key given. +// Note: this assumes a normalized key given. func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} { var ( val interface{} @@ -1419,12 +1452,12 @@ func stringToStringConv(val string) interface{} { } // IsSet checks to see if the key has been set in any of the data locations. -// IsSet is case-insensitive for a key. +// IsSet normalizes the key. func IsSet(key string) bool { return v.IsSet(key) } func (v *Viper) IsSet(key string) bool { - lcaseKey := strings.ToLower(key) - val := v.find(lcaseKey, false) + normKey := v.keyNormalize(key) + val := v.find(normKey, false) return val != nil } @@ -1450,11 +1483,11 @@ func (v *Viper) SetEnvKeyReplacer(r *strings.Replacer) { func RegisterAlias(alias string, key string) { v.RegisterAlias(alias, key) } func (v *Viper) RegisterAlias(alias string, key string) { - v.registerAlias(alias, strings.ToLower(key)) + v.registerAlias(alias, v.keyNormalize(key)) } func (v *Viper) registerAlias(alias string, key string) { - alias = strings.ToLower(alias) + alias = v.keyNormalize(alias) if alias != key && alias != v.realKey(key) { _, exists := v.aliases[alias] @@ -1499,7 +1532,7 @@ func (v *Viper) realKey(key string) string { func InConfig(key string) bool { return v.InConfig(key) } func (v *Viper) InConfig(key string) bool { - lcaseKey := strings.ToLower(key) + lcaseKey := v.keyNormalize(key) // if the requested key is an alias, then return the proper key lcaseKey = v.realKey(lcaseKey) @@ -1509,17 +1542,18 @@ func (v *Viper) InConfig(key string) bool { } // SetDefault sets the default value for this key. -// SetDefault is case-insensitive for a key. +// SetDefault applies normalization (by default, lowercases) a key. // Default only used when no value is provided by the user via flag, config or ENV. func SetDefault(key string, value interface{}) { v.SetDefault(key, value) } func (v *Viper) SetDefault(key string, value interface{}) { // If alias passed in, then set the proper default - key = v.realKey(strings.ToLower(key)) - value = toCaseInsensitiveValue(value) + key = v.keyNormalize(key) + value = toCaseInsensitiveValue(value, v.keyNormalize) + key = v.realKey(key) path := strings.Split(key, v.keyDelim) - lastKey := strings.ToLower(path[len(path)-1]) + lastKey := v.keyNormalize(path[len(path)-1]) deepestMap := deepSearch(v.defaults, path[0:len(path)-1]) // set innermost value @@ -1527,18 +1561,19 @@ func (v *Viper) SetDefault(key string, value interface{}) { } // Set sets the value for the key in the override register. -// Set is case-insensitive for a key. +// Set normalizes a key. // Will be used instead of values obtained via // flags, config file, ENV, default, or key/value store. func Set(key string, value interface{}) { v.Set(key, value) } func (v *Viper) Set(key string, value interface{}) { // If alias passed in, then set the proper override - key = v.realKey(strings.ToLower(key)) - value = toCaseInsensitiveValue(value) + key = v.keyNormalize(key) + value = toCaseInsensitiveValue(value, v.keyNormalize) + key = v.realKey(key) path := strings.Split(key, v.keyDelim) - lastKey := strings.ToLower(path[len(path)-1]) + lastKey := v.keyNormalize(path[len(path)-1]) deepestMap := deepSearch(v.override, path[0:len(path)-1]) // set innermost value @@ -1627,7 +1662,7 @@ func (v *Viper) MergeConfigMap(cfg map[string]interface{}) error { if v.config == nil { v.config = make(map[string]interface{}) } - insensitiviseMap(cfg) + insensitiviseMap(cfg, v.keyNormalize) mergeMaps(cfg, v.config, nil) return nil } @@ -1727,7 +1762,8 @@ func (v *Viper) unmarshalReader(in io.Reader, c map[string]interface{}) error { } } - insensitiviseMap(c) + insensitiviseMap(c, v.keyNormalize) + return nil } @@ -1750,9 +1786,9 @@ func (v *Viper) marshalWriter(f afero.File, configType string) error { } func keyExists(k string, m map[string]interface{}) string { - lk := strings.ToLower(k) + lk := v.keyNormalize(k) for mk := range m { - lmk := strings.ToLower(mk) + lmk := v.keyNormalize(mk) if lmk == lk { return mk } @@ -2031,7 +2067,7 @@ func (v *Viper) flattenAndMergeMap(shadow map[string]bool, m map[string]interfac m2 = cast.ToStringMap(val) default: // immediate value - shadow[strings.ToLower(fullKey)] = true + shadow[v.keyNormalize(fullKey)] = true continue } // recursively merge to shadow map @@ -2057,7 +2093,7 @@ outer: } } // add key - shadow[strings.ToLower(k)] = true + shadow[v.keyNormalize(k)] = true } return shadow } @@ -2076,7 +2112,7 @@ func (v *Viper) AllSettings() map[string]interface{} { continue } path := strings.Split(k, v.keyDelim) - lastKey := strings.ToLower(path[len(path)-1]) + lastKey := v.keyNormalize(path[len(path)-1]) deepestMap := deepSearch(m, path[0:len(path)-1]) // set innermost value deepestMap[lastKey] = value diff --git a/viper_test.go b/viper_test.go index b867337..871ac83 100644 --- a/viper_test.go +++ b/viper_test.go @@ -2334,6 +2334,76 @@ func TestCaseInsensitiveSet(t *testing.T) { } } +func TestCaseSensitive(t *testing.T) { + for _, config := range []struct { + typ string + content string + }{ + {"yaml", ` +aBcD: 1 +eF: + gH: 2 + iJk: 3 + Lm: + nO: 4 + P: + Q: 5 + R: 6 +`}, + {"json", `{ + "aBcD": 1, + "eF": { + "iJk": 3, + "Lm": { + "P": { + "Q": 5, + "R": 6 + }, + "nO": 4 + }, + "gH": 2 + } +}`}, + {"toml", `aBcD = 1 +[eF] +gH = 2 +iJk = 3 +[eF.Lm] +nO = 4 +[eF.Lm.P] +Q = 5 +R = 6 +`}, + } { + doTestCaseSensitive(t, config.typ, config.content) + } +} + +func doTestCaseSensitive(t *testing.T, typ, config string) { + // Create case-sensitive instance + v := NewWithOptions(KeyPreserveCase()) + v.SetConfigType(typ) + + r := strings.NewReader(config) + if err := v.unmarshalReader(r, v.config); err != nil { + panic(err) + } + + v.Set("RfD", true) + assert.Equal(t, nil, v.Get("rfd")) + assert.Equal(t, true, v.Get("RfD")) + assert.Equal(t, 0, cast.ToInt(v.Get("abcd"))) + assert.Equal(t, 1, cast.ToInt(v.Get("aBcD"))) + assert.Equal(t, 0, cast.ToInt(v.Get("ef.gh"))) + assert.Equal(t, 2, cast.ToInt(v.Get("eF.gH"))) + assert.Equal(t, 0, cast.ToInt(v.Get("ef.ijk"))) + assert.Equal(t, 3, cast.ToInt(v.Get("eF.iJk"))) + assert.Equal(t, 0, cast.ToInt(v.Get("ef.lm.no"))) + assert.Equal(t, 4, cast.ToInt(v.Get("eF.Lm.nO"))) + assert.Equal(t, 0, cast.ToInt(v.Get("ef.lm.p.q"))) + assert.Equal(t, 5, cast.ToInt(v.Get("eF.Lm.P.Q"))) +} + func TestParseNested(t *testing.T) { type duration struct { Delay time.Duration