diff --git a/viper.go b/viper.go index 7eac4b7..d519e17 100644 --- a/viper.go +++ b/viper.go @@ -1289,7 +1289,8 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} { return cast.ToDurationSlice(slice) case "stringToString": return stringToStringConv(flag.ValueString()) - + case "stringToInt": + return stringToIntConv(flag.ValueString()) default: return flag.ValueString() } @@ -1370,6 +1371,8 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} { return cast.ToIntSlice(res) case "stringToString": return stringToStringConv(flag.ValueString()) + case "stringToInt": + return stringToIntConv(flag.ValueString()) case "durationSlice": s := strings.TrimPrefix(flag.ValueString(), "[") s = strings.TrimSuffix(s, "]") @@ -1418,6 +1421,30 @@ func stringToStringConv(val string) interface{} { return out } +// mostly copied from pflag's implementation of this operation here https://github.com/spf13/pflag/blob/d5e0c0615acee7028e1e2740a11102313be88de1/string_to_int.go#L68 +// alterations are: errors are swallowed, map[string]interface{} is returned in order to enable cast.ToStringMap +func stringToIntConv(val string) interface{} { + val = strings.Trim(val, "[]") + // An empty string would cause an empty map + if len(val) == 0 { + return map[string]interface{}{} + } + ss := strings.Split(val, ",") + out := make(map[string]interface{}, len(ss)) + for _, pair := range ss { + kv := strings.SplitN(pair, "=", 2) + if len(kv) != 2 { + return nil + } + var err error + out[kv[0]], err = strconv.Atoi(kv[1]) + if err != nil { + return nil + } + } + return out +} + // IsSet checks to see if the key has been set in any of the data locations. // IsSet is case-insensitive for a key. func IsSet(key string) bool { return v.IsSet(key) } diff --git a/viper_test.go b/viper_test.go index 8283b5c..b48c950 100644 --- a/viper_test.go +++ b/viper_test.go @@ -1255,6 +1255,53 @@ func TestBindPFlagStringToString(t *testing.T) { } } +func TestBindPFlagStringToInt(t *testing.T) { + tests := []struct { + Expected map[string]int + Value string + }{ + {map[string]int{"yo": 1, "oh": 21}, "yo=1,oh=21"}, + {map[string]int{"yo": 100000000, "oh": 0}, "yo=100000000,oh=0"}, + {map[string]int{}, "yo=2,oh=21.0"}, + {map[string]int{}, "yo=,oh=20.99"}, + {map[string]int{}, "yo=,oh="}, + } + + v := New() // create independent Viper object + defaultVal := map[string]int{} + v.SetDefault("stringtoint", defaultVal) + + for _, testValue := range tests { + flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError) + flagSet.StringToInt("stringtoint", testValue.Expected, "test") + + for _, changed := range []bool{true, false} { + flagSet.VisitAll(func(f *pflag.Flag) { + f.Value.Set(testValue.Value) + f.Changed = changed + }) + + err := v.BindPFlags(flagSet) + if err != nil { + t.Fatalf("error binding flag set, %v", err) + } + + type TestMap struct { + StringToInt map[string]int + } + val := &TestMap{} + if err := v.Unmarshal(val); err != nil { + t.Fatalf("%+#v cannot unmarshal: %s", testValue.Value, err) + } + if changed { + assert.Equal(t, testValue.Expected, val.StringToInt) + } else { + assert.Equal(t, defaultVal, val.StringToInt) + } + } + } +} + func TestBoundCaseSensitivity(t *testing.T) { assert.Equal(t, "brown", Get("eyes"))