Add options KeyPreserveCase() and KeyNormalizer(func (string) string)

This commit is contained in:
Dmitry Irtegov 2020-03-03 00:14:43 +07:00 committed by Oleg Chunikhin
parent 3970ad177e
commit bde4ac488a
4 changed files with 157 additions and 50 deletions

38
util.go
View file

@ -32,30 +32,30 @@ func (pe ConfigParseError) Error() string {
} }
// toCaseInsensitiveValue checks if the value is a map; // toCaseInsensitiveValue checks if the value is a map;
// if so, create a copy and lower-case the keys recursively. // if so, create a copy and lower-case (normalize) the keys recursively.
func toCaseInsensitiveValue(value interface{}) interface{} { func toCaseInsensitiveValue(value interface{}, normalize keyNormalizeHookType) interface{} {
switch v := value.(type) { switch v := value.(type) {
case map[interface{}]interface{}: case map[interface{}]interface{}:
value = copyAndInsensitiviseMap(cast.ToStringMap(v)) value = copyAndInsensitiviseMap(cast.ToStringMap(v), normalize)
case map[string]interface{}: case map[string]interface{}:
value = copyAndInsensitiviseMap(v) value = copyAndInsensitiviseMap(v, normalize)
} }
return value return value
} }
// copyAndInsensitiviseMap behaves like insensitiviseMap, but creates a copy of // copyAndInsensitiviseMap behaves like insensitiviseMap, but creates a copy of
// any map it makes case insensitive. // any map it makes case insensitive (normalized).
func copyAndInsensitiviseMap(m map[string]interface{}) map[string]interface{} { func copyAndInsensitiviseMap(m map[string]interface{}, normalize keyNormalizeHookType) map[string]interface{} {
nm := make(map[string]interface{}) nm := make(map[string]interface{})
for key, val := range m { for key, val := range m {
lkey := strings.ToLower(key) lkey := normalize(key)
switch v := val.(type) { switch v := val.(type) {
case map[interface{}]interface{}: case map[interface{}]interface{}:
nm[lkey] = copyAndInsensitiviseMap(cast.ToStringMap(v)) nm[lkey] = copyAndInsensitiviseMap(cast.ToStringMap(v), normalize)
case map[string]interface{}: case map[string]interface{}:
nm[lkey] = copyAndInsensitiviseMap(v) nm[lkey] = copyAndInsensitiviseMap(v, normalize)
default: default:
nm[lkey] = v nm[lkey] = v
} }
@ -64,26 +64,26 @@ func copyAndInsensitiviseMap(m map[string]interface{}) map[string]interface{} {
return nm return nm
} }
func insensitiviseVal(val interface{}) interface{} { func insensitiviseVal(val interface{}, normalize keyNormalizeHookType) interface{} {
switch val.(type) { switch valT := val.(type) {
case map[interface{}]interface{}: case map[interface{}]interface{}:
// nested map: cast and recursively insensitivise // nested map: cast and recursively insensitivise
val = cast.ToStringMap(val) val = cast.ToStringMap(val)
insensitiviseMap(val.(map[string]interface{})) insensitiviseMap(val.(map[string]interface{}), normalize)
case map[string]interface{}: case map[string]interface{}:
// nested map: recursively insensitivise // nested map: recursively insensitivise
insensitiviseMap(val.(map[string]interface{})) insensitiviseMap(valT, normalize)
case []interface{}: case []interface{}:
// nested array: recursively insensitivise // nested array: recursively insensitivise
insensitiveArray(val.([]interface{})) insensitiveArray(valT, normalize)
} }
return val return val
} }
func insensitiviseMap(m map[string]interface{}) { func insensitiviseMap(m map[string]interface{}, normalize keyNormalizeHookType) {
for key, val := range m { for key, val := range m {
val = insensitiviseVal(val) val = insensitiviseVal(val, normalize)
lower := strings.ToLower(key) lower := normalize(key)
if key != lower { if key != lower {
// remove old key (not lower-cased) // remove old key (not lower-cased)
delete(m, key) delete(m, key)
@ -93,9 +93,9 @@ func insensitiviseMap(m map[string]interface{}) {
} }
} }
func insensitiveArray(a []interface{}) { func insensitiveArray(a []interface{}, normalize keyNormalizeHookType) {
for i, val := range a { for i, val := range a {
a[i] = insensitiviseVal(val) a[i] = insensitiviseVal(val, normalize)
} }
} }

View file

@ -14,6 +14,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"strings"
"testing" "testing"
"github.com/spf13/viper/internal/testutil" "github.com/spf13/viper/internal/testutil"
@ -37,7 +38,7 @@ func TestCopyAndInsensitiviseMap(t *testing.T) {
} }
) )
got := copyAndInsensitiviseMap(given) got := copyAndInsensitiviseMap(given, strings.ToLower)
if !reflect.DeepEqual(got, expected) { if !reflect.DeepEqual(got, expected) {
t.Fatalf("Got %q\nexpected\n%q", got, expected) t.Fatalf("Got %q\nexpected\n%q", got, expected)

View file

@ -142,6 +142,10 @@ func DecodeHook(hook mapstructure.DecodeHookFunc) DecoderConfigOption {
} }
} }
type keyNormalizeHookType func(string) string
var defaultKeyNormalizer = strings.ToLower
// Viper is a prioritized configuration registry. It // Viper is a prioritized configuration registry. It
// maintains a set of configuration sources, fetches // maintains a set of configuration sources, fetches
// values to populate those, and provides them according // values to populate those, and provides them according
@ -183,6 +187,10 @@ type Viper struct {
// used to access a nested value in one go // used to access a nested value in one go
keyDelim string keyDelim string
// Function to normalize keys
// by default, strings.ToLower
keyNormalizeHook keyNormalizeHookType
// A set of paths to look for the config file in // A set of paths to look for the config file in
configPaths []string configPaths []string
@ -229,6 +237,7 @@ type Viper struct {
func New() *Viper { func New() *Viper {
v := new(Viper) v := new(Viper)
v.keyDelim = "." v.keyDelim = "."
v.keyNormalizeHook = defaultKeyNormalizer
v.configName = "config" v.configName = "config"
v.configPermissions = os.FileMode(0o644) v.configPermissions = os.FileMode(0o644)
v.fs = afero.NewOsFs() v.fs = afero.NewOsFs()
@ -270,6 +279,23 @@ func KeyDelimiter(d string) Option {
}) })
} }
// KeyNormalizer is option to set arbitrary function for key normalization
// This function will be applied to all keys after unmarshal, during merge, search for duplicates, etc
// Default normalizer is strings.ToLower
func KeyNormalizer(n keyNormalizeHookType) Option {
return optionFunc(func(v *Viper) {
v.keyNormalizeHook = n
})
}
// KeyPreserveCase is option to disable key lowercasing
// By default, Viper converts all keys to lovercase
func KeyPreserveCase() Option {
return optionFunc(func(v *Viper) {
v.keyNormalizeHook = func(key string) string { return key }
})
}
// StringReplacer applies a set of replacements to a string. // StringReplacer applies a set of replacements to a string.
type StringReplacer interface { type StringReplacer interface {
// Replace returns a copy of s with all replacements performed. // Replace returns a copy of s with all replacements performed.
@ -523,6 +549,13 @@ func (v *Viper) SetEnvPrefix(in string) {
} }
} }
func (v *Viper) keyNormalize(k string) string {
if v.keyNormalizeHook != nil {
return v.keyNormalizeHook(k)
}
return defaultKeyNormalizer(k)
}
func (v *Viper) mergeWithEnvPrefix(in string) string { func (v *Viper) mergeWithEnvPrefix(in string) string {
if v.envPrefix != "" { if v.envPrefix != "" {
return strings.ToUpper(v.envPrefix + "_" + in) return strings.ToUpper(v.envPrefix + "_" + in)
@ -652,7 +685,7 @@ func (v *Viper) providerPathExists(p *defaultRemoteProvider) bool {
// searchMap recursively searches for a value for path in source map. // searchMap recursively searches for a value for path in source map.
// Returns nil if not found. // Returns nil if not found.
// Note: This assumes that the path entries and map keys are lower cased. // Note: This assumes that the path entries and map keys are normalized (by default, lowercased).
func (v *Viper) searchMap(source map[string]interface{}, path []string) interface{} { func (v *Viper) searchMap(source map[string]interface{}, path []string) interface{} {
if len(path) == 0 { if len(path) == 0 {
return source return source
@ -691,7 +724,7 @@ func (v *Viper) searchMap(source map[string]interface{}, path []string) interfac
// This should be useful only at config level (other maps may not contain dots // This should be useful only at config level (other maps may not contain dots
// in their keys). // in their keys).
// //
// Note: This assumes that the path entries and map keys are lower cased. // Note: This assumes that the path entries and map keys are lower cased (normalized).
func (v *Viper) searchIndexableWithPathPrefixes(source interface{}, path []string) interface{} { func (v *Viper) searchIndexableWithPathPrefixes(source interface{}, path []string) interface{} {
if len(path) == 0 { if len(path) == 0 {
return source return source
@ -699,7 +732,7 @@ func (v *Viper) searchIndexableWithPathPrefixes(source interface{}, path []strin
// search for path prefixes, starting from the longest one // search for path prefixes, starting from the longest one
for i := len(path); i > 0; i-- { for i := len(path); i > 0; i-- {
prefixKey := strings.ToLower(strings.Join(path[0:i], v.keyDelim)) prefixKey := v.keyNormalize(strings.Join(path[0:i], v.keyDelim))
var val interface{} var val interface{}
switch sourceIndexable := source.(type) { switch sourceIndexable := source.(type) {
@ -890,7 +923,7 @@ func GetViper() *Viper {
func Get(key string) interface{} { return v.Get(key) } func Get(key string) interface{} { return v.Get(key) }
func (v *Viper) Get(key string) interface{} { func (v *Viper) Get(key string) interface{} {
lcaseKey := strings.ToLower(key) lcaseKey := v.keyNormalize(key)
val := v.find(lcaseKey, true) val := v.find(lcaseKey, true)
if val == nil { if val == nil {
return nil return nil
@ -950,7 +983,7 @@ func (v *Viper) Sub(key string) *Viper {
} }
if reflect.TypeOf(data).Kind() == reflect.Map { if reflect.TypeOf(data).Kind() == reflect.Map {
subv.parents = append(v.parents, strings.ToLower(key)) subv.parents = append(v.parents, v.keyNormalize(key))
subv.automaticEnvApplied = v.automaticEnvApplied subv.automaticEnvApplied = v.automaticEnvApplied
subv.envPrefix = v.envPrefix subv.envPrefix = v.envPrefix
subv.envKeyReplacer = v.envKeyReplacer subv.envKeyReplacer = v.envKeyReplacer
@ -1189,7 +1222,7 @@ func (v *Viper) BindFlagValue(key string, flag FlagValue) error {
if flag == nil { if flag == nil {
return fmt.Errorf("flag for %q is nil", key) return fmt.Errorf("flag for %q is nil", key)
} }
v.pflags[strings.ToLower(key)] = flag v.pflags[v.keyNormalize(key)] = flag
return nil return nil
} }
@ -1206,7 +1239,7 @@ func (v *Viper) BindEnv(input ...string) error {
return fmt.Errorf("missing key to bind to") return fmt.Errorf("missing key to bind to")
} }
key := strings.ToLower(input[0]) key := v.keyNormalize(input[0])
if len(input) == 1 { if len(input) == 1 {
v.env[key] = append(v.env[key], v.mergeWithEnvPrefix(key)) v.env[key] = append(v.env[key], v.mergeWithEnvPrefix(key))
@ -1236,7 +1269,7 @@ func (v *Viper) MustBindEnv(input ...string) {
// Lastly, if no value was found and flagDefault is true, and if the key // Lastly, if no value was found and flagDefault is true, and if the key
// corresponds to a flag, the flag's default value is returned. // corresponds to a flag, the flag's default value is returned.
// //
// Note: this assumes a lower-cased key given. // Note: this assumes a normalized key given.
func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} { func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
var ( var (
val interface{} val interface{}
@ -1419,12 +1452,12 @@ func stringToStringConv(val string) interface{} {
} }
// IsSet checks to see if the key has been set in any of the data locations. // IsSet checks to see if the key has been set in any of the data locations.
// IsSet is case-insensitive for a key. // IsSet normalizes the key.
func IsSet(key string) bool { return v.IsSet(key) } func IsSet(key string) bool { return v.IsSet(key) }
func (v *Viper) IsSet(key string) bool { func (v *Viper) IsSet(key string) bool {
lcaseKey := strings.ToLower(key) normKey := v.keyNormalize(key)
val := v.find(lcaseKey, false) val := v.find(normKey, false)
return val != nil return val != nil
} }
@ -1450,11 +1483,11 @@ func (v *Viper) SetEnvKeyReplacer(r *strings.Replacer) {
func RegisterAlias(alias string, key string) { v.RegisterAlias(alias, key) } func RegisterAlias(alias string, key string) { v.RegisterAlias(alias, key) }
func (v *Viper) RegisterAlias(alias string, key string) { func (v *Viper) RegisterAlias(alias string, key string) {
v.registerAlias(alias, strings.ToLower(key)) v.registerAlias(alias, v.keyNormalize(key))
} }
func (v *Viper) registerAlias(alias string, key string) { func (v *Viper) registerAlias(alias string, key string) {
alias = strings.ToLower(alias) alias = v.keyNormalize(alias)
if alias != key && alias != v.realKey(key) { if alias != key && alias != v.realKey(key) {
_, exists := v.aliases[alias] _, exists := v.aliases[alias]
@ -1499,7 +1532,7 @@ func (v *Viper) realKey(key string) string {
func InConfig(key string) bool { return v.InConfig(key) } func InConfig(key string) bool { return v.InConfig(key) }
func (v *Viper) InConfig(key string) bool { func (v *Viper) InConfig(key string) bool {
lcaseKey := strings.ToLower(key) lcaseKey := v.keyNormalize(key)
// if the requested key is an alias, then return the proper key // if the requested key is an alias, then return the proper key
lcaseKey = v.realKey(lcaseKey) lcaseKey = v.realKey(lcaseKey)
@ -1509,17 +1542,18 @@ func (v *Viper) InConfig(key string) bool {
} }
// SetDefault sets the default value for this key. // SetDefault sets the default value for this key.
// SetDefault is case-insensitive for a key. // SetDefault applies normalization (by default, lowercases) a key.
// Default only used when no value is provided by the user via flag, config or ENV. // Default only used when no value is provided by the user via flag, config or ENV.
func SetDefault(key string, value interface{}) { v.SetDefault(key, value) } func SetDefault(key string, value interface{}) { v.SetDefault(key, value) }
func (v *Viper) SetDefault(key string, value interface{}) { func (v *Viper) SetDefault(key string, value interface{}) {
// If alias passed in, then set the proper default // If alias passed in, then set the proper default
key = v.realKey(strings.ToLower(key)) key = v.keyNormalize(key)
value = toCaseInsensitiveValue(value) value = toCaseInsensitiveValue(value, v.keyNormalize)
key = v.realKey(key)
path := strings.Split(key, v.keyDelim) path := strings.Split(key, v.keyDelim)
lastKey := strings.ToLower(path[len(path)-1]) lastKey := v.keyNormalize(path[len(path)-1])
deepestMap := deepSearch(v.defaults, path[0:len(path)-1]) deepestMap := deepSearch(v.defaults, path[0:len(path)-1])
// set innermost value // set innermost value
@ -1527,18 +1561,19 @@ func (v *Viper) SetDefault(key string, value interface{}) {
} }
// Set sets the value for the key in the override register. // Set sets the value for the key in the override register.
// Set is case-insensitive for a key. // Set normalizes a key.
// Will be used instead of values obtained via // Will be used instead of values obtained via
// flags, config file, ENV, default, or key/value store. // flags, config file, ENV, default, or key/value store.
func Set(key string, value interface{}) { v.Set(key, value) } func Set(key string, value interface{}) { v.Set(key, value) }
func (v *Viper) Set(key string, value interface{}) { func (v *Viper) Set(key string, value interface{}) {
// If alias passed in, then set the proper override // If alias passed in, then set the proper override
key = v.realKey(strings.ToLower(key)) key = v.keyNormalize(key)
value = toCaseInsensitiveValue(value) value = toCaseInsensitiveValue(value, v.keyNormalize)
key = v.realKey(key)
path := strings.Split(key, v.keyDelim) path := strings.Split(key, v.keyDelim)
lastKey := strings.ToLower(path[len(path)-1]) lastKey := v.keyNormalize(path[len(path)-1])
deepestMap := deepSearch(v.override, path[0:len(path)-1]) deepestMap := deepSearch(v.override, path[0:len(path)-1])
// set innermost value // set innermost value
@ -1627,7 +1662,7 @@ func (v *Viper) MergeConfigMap(cfg map[string]interface{}) error {
if v.config == nil { if v.config == nil {
v.config = make(map[string]interface{}) v.config = make(map[string]interface{})
} }
insensitiviseMap(cfg) insensitiviseMap(cfg, v.keyNormalize)
mergeMaps(cfg, v.config, nil) mergeMaps(cfg, v.config, nil)
return nil return nil
} }
@ -1727,7 +1762,8 @@ func (v *Viper) unmarshalReader(in io.Reader, c map[string]interface{}) error {
} }
} }
insensitiviseMap(c) insensitiviseMap(c, v.keyNormalize)
return nil return nil
} }
@ -1750,9 +1786,9 @@ func (v *Viper) marshalWriter(f afero.File, configType string) error {
} }
func keyExists(k string, m map[string]interface{}) string { func keyExists(k string, m map[string]interface{}) string {
lk := strings.ToLower(k) lk := v.keyNormalize(k)
for mk := range m { for mk := range m {
lmk := strings.ToLower(mk) lmk := v.keyNormalize(mk)
if lmk == lk { if lmk == lk {
return mk return mk
} }
@ -2031,7 +2067,7 @@ func (v *Viper) flattenAndMergeMap(shadow map[string]bool, m map[string]interfac
m2 = cast.ToStringMap(val) m2 = cast.ToStringMap(val)
default: default:
// immediate value // immediate value
shadow[strings.ToLower(fullKey)] = true shadow[v.keyNormalize(fullKey)] = true
continue continue
} }
// recursively merge to shadow map // recursively merge to shadow map
@ -2057,7 +2093,7 @@ outer:
} }
} }
// add key // add key
shadow[strings.ToLower(k)] = true shadow[v.keyNormalize(k)] = true
} }
return shadow return shadow
} }
@ -2076,7 +2112,7 @@ func (v *Viper) AllSettings() map[string]interface{} {
continue continue
} }
path := strings.Split(k, v.keyDelim) path := strings.Split(k, v.keyDelim)
lastKey := strings.ToLower(path[len(path)-1]) lastKey := v.keyNormalize(path[len(path)-1])
deepestMap := deepSearch(m, path[0:len(path)-1]) deepestMap := deepSearch(m, path[0:len(path)-1])
// set innermost value // set innermost value
deepestMap[lastKey] = value deepestMap[lastKey] = value

View file

@ -2334,6 +2334,76 @@ func TestCaseInsensitiveSet(t *testing.T) {
} }
} }
func TestCaseSensitive(t *testing.T) {
for _, config := range []struct {
typ string
content string
}{
{"yaml", `
aBcD: 1
eF:
gH: 2
iJk: 3
Lm:
nO: 4
P:
Q: 5
R: 6
`},
{"json", `{
"aBcD": 1,
"eF": {
"iJk": 3,
"Lm": {
"P": {
"Q": 5,
"R": 6
},
"nO": 4
},
"gH": 2
}
}`},
{"toml", `aBcD = 1
[eF]
gH = 2
iJk = 3
[eF.Lm]
nO = 4
[eF.Lm.P]
Q = 5
R = 6
`},
} {
doTestCaseSensitive(t, config.typ, config.content)
}
}
func doTestCaseSensitive(t *testing.T, typ, config string) {
// Create case-sensitive instance
v := NewWithOptions(KeyPreserveCase())
v.SetConfigType(typ)
r := strings.NewReader(config)
if err := v.unmarshalReader(r, v.config); err != nil {
panic(err)
}
v.Set("RfD", true)
assert.Equal(t, nil, v.Get("rfd"))
assert.Equal(t, true, v.Get("RfD"))
assert.Equal(t, 0, cast.ToInt(v.Get("abcd")))
assert.Equal(t, 1, cast.ToInt(v.Get("aBcD")))
assert.Equal(t, 0, cast.ToInt(v.Get("ef.gh")))
assert.Equal(t, 2, cast.ToInt(v.Get("eF.gH")))
assert.Equal(t, 0, cast.ToInt(v.Get("ef.ijk")))
assert.Equal(t, 3, cast.ToInt(v.Get("eF.iJk")))
assert.Equal(t, 0, cast.ToInt(v.Get("ef.lm.no")))
assert.Equal(t, 4, cast.ToInt(v.Get("eF.Lm.nO")))
assert.Equal(t, 0, cast.ToInt(v.Get("ef.lm.p.q")))
assert.Equal(t, 5, cast.ToInt(v.Get("eF.Lm.P.Q")))
}
func TestParseNested(t *testing.T) { func TestParseNested(t *testing.T) {
type duration struct { type duration struct {
Delay time.Duration Delay time.Duration