feat: allow setting a default decode hook

Signed-off-by: Mark Sagi-Kazar <mark.sagikazar@gmail.com>
This commit is contained in:
Mark Sagi-Kazar 2024-06-24 18:20:56 +02:00 committed by Márk Sági-Kazár
parent 5964efa262
commit d2458a2abc
2 changed files with 63 additions and 10 deletions

View file

@ -183,6 +183,8 @@ type Viper struct {
encoderRegistry EncoderRegistry encoderRegistry EncoderRegistry
decoderRegistry DecoderRegistry decoderRegistry DecoderRegistry
decodeHook mapstructure.DecodeHookFunc
experimentalFinder bool experimentalFinder bool
experimentalBindStruct bool experimentalBindStruct bool
} }
@ -255,6 +257,17 @@ func EnvKeyReplacer(r StringReplacer) Option {
}) })
} }
// WithDecodeHook sets a default decode hook for mapstructure.
func WithDecodeHook(h mapstructure.DecodeHookFunc) Option {
return optionFunc(func(v *Viper) {
if h == nil {
return
}
v.decodeHook = h
})
}
// NewWithOptions creates a new Viper instance. // NewWithOptions creates a new Viper instance.
func NewWithOptions(opts ...Option) *Viper { func NewWithOptions(opts ...Option) *Viper {
v := New() v := New()
@ -900,7 +913,7 @@ func UnmarshalKey(key string, rawVal any, opts ...DecoderConfigOption) error {
} }
func (v *Viper) UnmarshalKey(key string, rawVal any, opts ...DecoderConfigOption) error { func (v *Viper) UnmarshalKey(key string, rawVal any, opts ...DecoderConfigOption) error {
return decode(v.Get(key), defaultDecoderConfig(rawVal, opts...)) return decode(v.Get(key), v.defaultDecoderConfig(rawVal, opts...))
} }
// Unmarshal unmarshals the config into a Struct. Make sure that the tags // Unmarshal unmarshals the config into a Struct. Make sure that the tags
@ -923,13 +936,13 @@ func (v *Viper) Unmarshal(rawVal any, opts ...DecoderConfigOption) error {
} }
// TODO: struct keys should be enough? // TODO: struct keys should be enough?
return decode(v.getSettings(keys), defaultDecoderConfig(rawVal, opts...)) return decode(v.getSettings(keys), v.defaultDecoderConfig(rawVal, opts...))
} }
func (v *Viper) decodeStructKeys(input any, opts ...DecoderConfigOption) ([]string, error) { func (v *Viper) decodeStructKeys(input any, opts ...DecoderConfigOption) ([]string, error) {
var structKeyMap map[string]any var structKeyMap map[string]any
err := decode(input, defaultDecoderConfig(&structKeyMap, opts...)) err := decode(input, v.defaultDecoderConfig(&structKeyMap, opts...))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -946,15 +959,20 @@ func (v *Viper) decodeStructKeys(input any, opts ...DecoderConfigOption) ([]stri
// defaultDecoderConfig returns default mapstructure.DecoderConfig with support // defaultDecoderConfig returns default mapstructure.DecoderConfig with support
// of time.Duration values & string slices. // of time.Duration values & string slices.
func defaultDecoderConfig(output any, opts ...DecoderConfigOption) *mapstructure.DecoderConfig { func (v *Viper) defaultDecoderConfig(output any, opts ...DecoderConfigOption) *mapstructure.DecoderConfig {
c := &mapstructure.DecoderConfig{ decodeHook := v.decodeHook
Metadata: nil, if decodeHook == nil {
WeaklyTypedInput: true, decodeHook = mapstructure.ComposeDecodeHookFunc(
DecodeHook: mapstructure.ComposeDecodeHookFunc(
mapstructure.StringToTimeDurationHookFunc(), mapstructure.StringToTimeDurationHookFunc(),
// mapstructure.StringToSliceHookFunc(","), // mapstructure.StringToSliceHookFunc(","),
stringToWeakSliceHookFunc(","), stringToWeakSliceHookFunc(","),
), )
}
c := &mapstructure.DecoderConfig{
Metadata: nil,
WeaklyTypedInput: true,
DecodeHook: decodeHook,
} }
for _, opt := range opts { for _, opt := range opts {
@ -1005,7 +1023,7 @@ func UnmarshalExact(rawVal any, opts ...DecoderConfigOption) error {
} }
func (v *Viper) UnmarshalExact(rawVal any, opts ...DecoderConfigOption) error { func (v *Viper) UnmarshalExact(rawVal any, opts ...DecoderConfigOption) error {
config := defaultDecoderConfig(rawVal, opts...) config := v.defaultDecoderConfig(rawVal, opts...)
config.ErrorUnused = true config.ErrorUnused = true
keys := v.AllKeys() keys := v.AllKeys()

View file

@ -894,6 +894,41 @@ func TestUnmarshal(t *testing.T) {
) )
} }
func TestUnmarshalWithDefaultDecodeHook(t *testing.T) {
opt := mapstructure.ComposeDecodeHookFunc(
mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToSliceHookFunc(","),
// Custom Decode Hook Function
func(rf reflect.Kind, rt reflect.Kind, data any) (any, error) {
if rf != reflect.String || rt != reflect.Map {
return data, nil
}
m := map[string]string{}
raw := data.(string)
if raw == "" {
return m, nil
}
err := json.Unmarshal([]byte(raw), &m)
return m, err
},
)
v := NewWithOptions(WithDecodeHook(opt))
v.Set("credentials", "{\"foo\":\"bar\"}")
type config struct {
Credentials map[string]string
}
var C config
require.NoError(t, v.Unmarshal(&C), "unable to decode into struct")
assert.Equal(t, &config{
Credentials: map[string]string{"foo": "bar"},
}, &C)
}
func TestUnmarshalWithDecoderOptions(t *testing.T) { func TestUnmarshalWithDecoderOptions(t *testing.T) {
v := New() v := New()
v.Set("credentials", "{\"foo\":\"bar\"}") v.Set("credentials", "{\"foo\":\"bar\"}")