diff --git a/viper.go b/viper.go index d96e1d9..d2f85d8 100644 --- a/viper.go +++ b/viper.go @@ -183,6 +183,8 @@ type Viper struct { encoderRegistry EncoderRegistry decoderRegistry DecoderRegistry + decodeHook mapstructure.DecodeHookFunc + experimentalFinder bool experimentalBindStruct bool } @@ -255,6 +257,17 @@ func EnvKeyReplacer(r StringReplacer) Option { }) } +// WithDecodeHook sets a default decode hook for mapstructure. +func WithDecodeHook(h mapstructure.DecodeHookFunc) Option { + return optionFunc(func(v *Viper) { + if h == nil { + return + } + + v.decodeHook = h + }) +} + // NewWithOptions creates a new Viper instance. func NewWithOptions(opts ...Option) *Viper { v := New() @@ -900,7 +913,7 @@ func UnmarshalKey(key string, rawVal any, opts ...DecoderConfigOption) error { } func (v *Viper) UnmarshalKey(key string, rawVal any, opts ...DecoderConfigOption) error { - return decode(v.Get(key), defaultDecoderConfig(rawVal, opts...)) + return decode(v.Get(key), v.defaultDecoderConfig(rawVal, opts...)) } // Unmarshal unmarshals the config into a Struct. Make sure that the tags @@ -923,13 +936,13 @@ func (v *Viper) Unmarshal(rawVal any, opts ...DecoderConfigOption) error { } // TODO: struct keys should be enough? - return decode(v.getSettings(keys), defaultDecoderConfig(rawVal, opts...)) + return decode(v.getSettings(keys), v.defaultDecoderConfig(rawVal, opts...)) } func (v *Viper) decodeStructKeys(input any, opts ...DecoderConfigOption) ([]string, error) { var structKeyMap map[string]any - err := decode(input, defaultDecoderConfig(&structKeyMap, opts...)) + err := decode(input, v.defaultDecoderConfig(&structKeyMap, opts...)) if err != nil { return nil, err } @@ -946,15 +959,20 @@ func (v *Viper) decodeStructKeys(input any, opts ...DecoderConfigOption) ([]stri // defaultDecoderConfig returns default mapstructure.DecoderConfig with support // of time.Duration values & string slices. -func defaultDecoderConfig(output any, opts ...DecoderConfigOption) *mapstructure.DecoderConfig { - c := &mapstructure.DecoderConfig{ - Metadata: nil, - WeaklyTypedInput: true, - DecodeHook: mapstructure.ComposeDecodeHookFunc( +func (v *Viper) defaultDecoderConfig(output any, opts ...DecoderConfigOption) *mapstructure.DecoderConfig { + decodeHook := v.decodeHook + if decodeHook == nil { + decodeHook = mapstructure.ComposeDecodeHookFunc( mapstructure.StringToTimeDurationHookFunc(), // mapstructure.StringToSliceHookFunc(","), stringToWeakSliceHookFunc(","), - ), + ) + } + + c := &mapstructure.DecoderConfig{ + Metadata: nil, + WeaklyTypedInput: true, + DecodeHook: decodeHook, } for _, opt := range opts { @@ -1005,7 +1023,7 @@ func UnmarshalExact(rawVal any, opts ...DecoderConfigOption) error { } func (v *Viper) UnmarshalExact(rawVal any, opts ...DecoderConfigOption) error { - config := defaultDecoderConfig(rawVal, opts...) + config := v.defaultDecoderConfig(rawVal, opts...) config.ErrorUnused = true keys := v.AllKeys() diff --git a/viper_test.go b/viper_test.go index a4df402..1419113 100644 --- a/viper_test.go +++ b/viper_test.go @@ -894,6 +894,41 @@ func TestUnmarshal(t *testing.T) { ) } +func TestUnmarshalWithDefaultDecodeHook(t *testing.T) { + opt := mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + // Custom Decode Hook Function + func(rf reflect.Kind, rt reflect.Kind, data any) (any, error) { + if rf != reflect.String || rt != reflect.Map { + return data, nil + } + m := map[string]string{} + raw := data.(string) + if raw == "" { + return m, nil + } + err := json.Unmarshal([]byte(raw), &m) + return m, err + }, + ) + + v := NewWithOptions(WithDecodeHook(opt)) + v.Set("credentials", "{\"foo\":\"bar\"}") + + type config struct { + Credentials map[string]string + } + + var C config + + require.NoError(t, v.Unmarshal(&C), "unable to decode into struct") + + assert.Equal(t, &config{ + Credentials: map[string]string{"foo": "bar"}, + }, &C) +} + func TestUnmarshalWithDecoderOptions(t *testing.T) { v := New() v.Set("credentials", "{\"foo\":\"bar\"}")