From a62e20b9ad407a2330e74404770eeb81fce8080b Mon Sep 17 00:00:00 2001 From: Nothin Date: Sat, 15 Apr 2023 14:58:25 +0800 Subject: [PATCH] [FEATURE] viper add MapTo to quick map to struct or base type --- README.md | 36 +++++++ internal/convert/convert.go | 166 +++++++++++++++++++++++++++++++ internal/convert/convert_test.go | 108 ++++++++++++++++++++ viper_convert.go | 31 ++++++ viper_test.go | 19 ++++ 5 files changed, 360 insertions(+) create mode 100644 internal/convert/convert.go create mode 100644 internal/convert/convert_test.go create mode 100644 viper_convert.go diff --git a/README.md b/README.md index c86b9b7..e398dab 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,42 @@ if err := viper.ReadInConfig(); err != nil { *NOTE [since 1.6]:* You can also have a file without an extension and specify the format programmaticaly. For those configuration files that lie in the home of the user without any extension like `.bashrc` +### MapTo +- source file +```yaml +service: + port: 1234 + ip: "127.0.0.1" +version: 1.0.01 +``` +- use MapTo + +```go +type Service struct { + Port int `viper:"port"` + IP string `viper:"ip"` +} +//your prepare code ... + +var service Service +var version string + +if err := viper.MapTo("service",&service); err != nil { + //error handler... +} +if err := viper.MapTo("version",&version); err != nil { + //error handler... +} + + + +log.Println(service,version) + +//.... + +``` + + ### Writing Config Files Reading from config files is useful, but at times you want to store all modifications made at run time. diff --git a/internal/convert/convert.go b/internal/convert/convert.go new file mode 100644 index 0000000..415c158 --- /dev/null +++ b/internal/convert/convert.go @@ -0,0 +1,166 @@ +package convert + +import ( + "fmt" + "reflect" + "strings" +) + +var convertUtils = map[reflect.Kind]func(reflect.Value, reflect.Value) error{ + reflect.String: converNormal, + reflect.Int: converNormal, + reflect.Int16: converNormal, + reflect.Int32: converNormal, + reflect.Int64: converNormal, + reflect.Uint: converNormal, + reflect.Uint16: converNormal, + reflect.Uint32: converNormal, + reflect.Uint64: converNormal, + reflect.Float32: converNormal, + reflect.Float64: converNormal, + reflect.Uint8: converNormal, + reflect.Int8: converNormal, +} + +//Convert +//示例 +/* + type Target struct { + A int `viper:"aint"` + B string `viper:"bstr"` + } + src :=map[string]interface{}{ + "aint":1224, + "bstr":"124132" + } + + var t Target + Convert(src,&t) + +*/ +//fix循环引用的问题 +var _ = func() struct{} { + convertUtils[reflect.Map] = convertMap + convertUtils[reflect.Array] = convertSlice + convertUtils[reflect.Slice] = convertSlice + return struct{}{} +}() + +func Convert(src interface{}, dst interface{}) (err error) { + + dstRef := reflect.ValueOf(dst) + if dstRef.Kind() != reflect.Ptr { + return fmt.Errorf("dst is not ptr") + } + + dstRef = reflect.Indirect(dstRef) + + srcRef := reflect.ValueOf(src) + if srcRef.Kind() == reflect.Ptr || srcRef.Kind() == reflect.Interface { + srcRef = srcRef.Elem() + } + if f, ok := convertUtils[srcRef.Kind()]; ok { + return f(srcRef, dstRef) + } + + return fmt.Errorf("no implemented:%s", srcRef.Type()) +} + +func converNormal(src reflect.Value, dst reflect.Value) error { + if dst.CanSet() { + if src.Type() == dst.Type() { + dst.Set(src) + } else if src.CanConvert(dst.Type()) { + dst.Set(src.Convert(dst.Type())) + } else { + return fmt.Errorf("can not convert:%s:%s", src.Type().String(), dst.Type().String()) + } + } + return nil +} + +func convertSlice(src reflect.Value, dst reflect.Value) error { + if dst.Kind() != reflect.Array && dst.Kind() != reflect.Slice { + return fmt.Errorf("error type:%s", dst.Type().String()) + } + l := src.Len() + target := reflect.MakeSlice(dst.Type(), l, l) + if dst.CanSet() { + dst.Set(target) + } + for i := 0; i < l; i++ { + srcValue := src.Index(i) + if srcValue.Kind() == reflect.Ptr || srcValue.Kind() == reflect.Interface { + srcValue = srcValue.Elem() + } + if f, ok := convertUtils[srcValue.Kind()]; ok { + err := f(srcValue, dst.Index(i)) + if err != nil { + return err + } + } + } + + return nil +} + +func convertMap(src reflect.Value, dst reflect.Value) error { + if src.Kind() != reflect.Map || dst.Kind() != reflect.Struct { + if src.Kind() == reflect.Interface { + return convertMap(src.Elem(), dst) + } else { + return fmt.Errorf("src or dst type error,%s,%s", src.Type().String(), dst.Type().String()) + } + } + dstType := dst.Type() + num := dstType.NumField() + exist := map[string]int{} + for i := 0; i < num; i++ { + k := dstType.Field(i).Tag.Get("viper") + if k == "" { + k = dstType.Field(i).Name + } + if strings.Contains(k, ",") { + taglist := strings.Split(k, ",") + if taglist[0] == "" { + + k = dstType.Field(i).Name + } else { + k = taglist[0] + + } + + } + exist[k] = i + } + + keys := src.MapKeys() + for _, key := range keys { + if index, ok := exist[key.String()]; ok { + v := dst.Field(index) + if v.Kind() == reflect.Struct { + err := convertMap(src.MapIndex(key), v) + if err != nil { + return err + } + } else { + if v.CanSet() { + if v.Type() == src.MapIndex(key).Elem().Type() { + v.Set(src.MapIndex(key).Elem()) + } else if src.MapIndex(key).Elem().CanConvert(v.Type()) { + v.Set(src.MapIndex(key).Elem().Convert(v.Type())) + } else if f, ok := convertUtils[src.MapIndex(key).Elem().Kind()]; ok && f != nil { + err := f(src.MapIndex(key).Elem(), v) + if err != nil { + return err + } + } else { + return fmt.Errorf("error type:d(%s)s(%s)", v.Type(), src.Type()) + } + } + } + } + } + + return nil +} diff --git a/internal/convert/convert_test.go b/internal/convert/convert_test.go new file mode 100644 index 0000000..05c1ea3 --- /dev/null +++ b/internal/convert/convert_test.go @@ -0,0 +1,108 @@ +package convert + +import ( + "testing" +) + +func TestConvert(t *testing.T) { + type Tmp1 struct { + Str string `viper:"str"` + I8 int8 `viper:"i8"` + Int16 int16 `viper:"i16"` + Int32 int32 `viper:"i32"` + Int64 int64 `viper:"i64"` + I int `viper:"i"` + U8 int8 `viper:"u8"` + Uint16 int16 `viper:"u16"` + Uint32 int32 `viper:"u32"` + Uint64 int64 `viper:"u64"` + U int `viper:"u"` + F32 float32 `viper:"f32"` + F64 float64 `viper:"f64"` + TF bool `viper:"tf"` + M map[string]interface{} `viper:"m"` + S []interface{} `viper:"s"` + } + tc := map[string]interface{}{ + "str": "Hello world", + "i8": -8, + "i16": -16, + "i32": -32, + "i64": -64, + "i": -1, + "u8": 8, + "u16": 16, + "u32": 32, + "u64": 64, + "u": 1, + "f32": 3.32, + "f64": 3.64, + "tf": true, + "m": map[string]interface{}{ + "im": 123, + }, + "s": []interface{}{ + "1234", + 1.23, + }, + } + + var tmp Tmp1 + err := Convert(tc, &tmp) + if err != nil { + t.Error(err) + } + // t.Error(tmp) + +} + +func BenchmarkConvert(b *testing.B) { + type Tmp1 struct { + Str string `viper:"str"` + I8 int8 `viper:"i8"` + Int16 int16 `viper:"i16"` + Int32 int32 `viper:"i32"` + Int64 int64 `viper:"i64"` + I int `viper:"i"` + U8 int8 `viper:"u8"` + Uint16 int16 `viper:"u16"` + Uint32 int32 `viper:"u32"` + Uint64 int64 `viper:"u64"` + U int `viper:"u"` + F32 float32 `viper:"f32"` + F64 float64 `viper:"f64"` + TF bool `viper:"tf"` + M map[string]interface{} `viper:"m"` + S []interface{} `viper:"s"` + } + tc := map[string]interface{}{ + "str": "Hello world", + "i8": -8, + "i16": -16, + "i32": -32, + "i64": -64, + "i": -1, + "u8": 8, + "u16": 16, + "u32": 32, + "u64": 64, + "u": 1, + "f32": 3.32, + "f64": 3.64, + "tf": true, + "m": map[string]interface{}{ + "im": 123, + }, + "s": []interface{}{ + "1234", + 1.23, + }, + } + for i := 0; i < b.N; i++ { + var tmp Tmp1 + err := Convert(tc, &tmp) + if err != nil { + b.Error(err) + } + } +} diff --git a/viper_convert.go b/viper_convert.go new file mode 100644 index 0000000..71b2197 --- /dev/null +++ b/viper_convert.go @@ -0,0 +1,31 @@ +package viper + +import "github.com/spf13/viper/internal/convert" + +//MapTo quick map to struct if know what the value carries +//using `viper:"key"`` tag to specify keys +/* + EG: + type Service struct { + Port int `viper:"port"` + IP string `viper:"ip"` + } + + SetDefault("service", map[string]interface{}{ + "ip": "127.0.0.1", + "port": 1234, + }) + + var service Service + err := MapTo("service", &service) + assert.NoError(t, err) + assert.Equal(t, Get("service.port"), service.Port) + assert.Equal(t, Get("service.ip"), service.IP) +*/ +func MapTo(key string, target interface{}) error { + return v.MapTo(key, target) +} + +func (v *Viper) MapTo(key string, target interface{}) error { + return convert.Convert(v.Get(key), target) +} diff --git a/viper_test.go b/viper_test.go index 8283b5c..5596724 100644 --- a/viper_test.go +++ b/viper_test.go @@ -503,6 +503,25 @@ func TestDefault(t *testing.T) { assert.Equal(t, "leather", Get("clothing.jacket")) } +func TestMapTo(t *testing.T) { + type Service struct { + Port int `viper:"port"` + IP string `viper:"ip"` + } + + SetDefault("service", map[string]interface{}{ + "ip": "127.0.0.1", + "port": 1234, + }) + + var service Service + err := MapTo("service", &service) + assert.NoError(t, err) + assert.Equal(t, Get("service.port"), service.Port) + assert.Equal(t, Get("service.ip"), service.IP) + +} + func TestUnmarshaling(t *testing.T) { SetConfigType("yaml") r := bytes.NewReader(yamlExample)