diff --git a/encoding.go b/encoding.go index 7d48d0d..2563348 100644 --- a/encoding.go +++ b/encoding.go @@ -28,16 +28,32 @@ type Codec interface { Decoder } +type encodingError string + +func (e encodingError) Error() string { + return string(e) +} + +const ( + // ErrEncoderNotFound is returned when there is no encoder registered for a format. + ErrEncoderNotFound = encodingError("encoder not found for this format") + + // ErrDecoderNotFound is returned when there is no decoder registered for a format. + ErrDecoderNotFound = encodingError("decoder not found for this format") +) + // EncoderRegistry returns an [Encoder] for a given format. -// The second return value is false if no [Encoder] is registered for the format. +// +// The error is [ErrEncoderNotFound] if no [Encoder] is registered for the format. type EncoderRegistry interface { - Encoder(format string) (Encoder, bool) + Encoder(format string) (Encoder, error) } // DecoderRegistry returns an [Decoder] for a given format. -// The second return value is false if no [Decoder] is registered for the format. +// +// The error is [ErrDecoderNotFound] if no [Decoder] is registered for the format. type DecoderRegistry interface { - Decoder(format string) (Decoder, bool) + Decoder(format string) (Decoder, error) } // [CodecRegistry] combines [EncoderRegistry] and [DecoderRegistry] interfaces. @@ -72,12 +88,22 @@ type codecRegistry struct { v *Viper } -func (r codecRegistry) Encoder(format string) (Encoder, bool) { - return r.codec(format) +func (r codecRegistry) Encoder(format string) (Encoder, error) { + encoder, ok := r.codec(format) + if !ok { + return nil, ErrEncoderNotFound + } + + return encoder, nil } -func (r codecRegistry) Decoder(format string) (Decoder, bool) { - return r.codec(format) +func (r codecRegistry) Decoder(format string) (Decoder, error) { + decoder, ok := r.codec(format) + if !ok { + return nil, ErrDecoderNotFound + } + + return decoder, nil } func (r codecRegistry) codec(format string) (Codec, bool) { diff --git a/viper.go b/viper.go index 151ec2d..89b6780 100644 --- a/viper.go +++ b/viper.go @@ -1723,12 +1723,12 @@ func (v *Viper) unmarshalReader(in io.Reader, c map[string]any) error { switch format := strings.ToLower(v.getConfigType()); format { case "yaml", "yml", "json", "toml", "hcl", "tfvars", "ini", "properties", "props", "prop", "dotenv", "env": - decoder, ok := v.decoderRegistry2.Decoder(format) - if !ok { - return ConfigParseError{errors.New("decoder not found")} + decoder, err := v.decoderRegistry2.Decoder(format) + if err != nil { + return ConfigParseError{err} } - err := decoder.Decode(buf.Bytes(), c) + err = decoder.Decode(buf.Bytes(), c) if err != nil { return ConfigParseError{err} } @@ -1743,10 +1743,9 @@ func (v *Viper) marshalWriter(f afero.File, configType string) error { c := v.AllSettings() switch configType { case "yaml", "yml", "json", "toml", "hcl", "tfvars", "ini", "prop", "props", "properties", "dotenv", "env": - encoder, ok := v.encoderRegistry2.Encoder(configType) - if !ok { - // TODO: return a proper error - return ConfigMarshalError{errors.New("encoder not found")} + encoder, err := v.encoderRegistry2.Encoder(configType) + if err != nil { + return ConfigMarshalError{err} } b, err := encoder.Encode(c)