From ad5ed02fa46d33a367ae1e563a3579f1a42e867e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1rk=20S=C3=A1gi-Kaz=C3=A1r?= Date: Tue, 11 Jun 2019 22:51:57 +0200 Subject: [PATCH] Add support for int slice flags (#637) * Add support for int slice flags * Add int slice test to unmarshal --- viper.go | 12 +++++++++ viper_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/viper.go b/viper.go index a7fa988..0699c4c 100644 --- a/viper.go +++ b/viper.go @@ -706,6 +706,8 @@ func (v *Viper) Get(key string) interface{} { return cast.ToDuration(val) case []string: return cast.ToStringSlice(val) + case []int: + return cast.ToIntSlice(val) } } @@ -1013,6 +1015,11 @@ func (v *Viper) find(lcaseKey string) interface{} { s = strings.TrimSuffix(s, "]") res, _ := readAsCSV(s) return res + case "intSlice": + s := strings.TrimPrefix(flag.ValueString(), "[") + s = strings.TrimSuffix(s, "]") + res, _ := readAsCSV(s) + return cast.ToIntSlice(res) default: return flag.ValueString() } @@ -1082,6 +1089,11 @@ func (v *Viper) find(lcaseKey string) interface{} { s = strings.TrimSuffix(s, "]") res, _ := readAsCSV(s) return res + case "intSlice": + s := strings.TrimPrefix(flag.ValueString(), "[") + s = strings.TrimSuffix(s, "]") + res, _ := readAsCSV(s) + return cast.ToIntSlice(res) default: return flag.ValueString() } diff --git a/viper_test.go b/viper_test.go index 60d912a..c40c971 100644 --- a/viper_test.go +++ b/viper_test.go @@ -540,11 +540,13 @@ func TestUnmarshal(t *testing.T) { SetDefault("port", 1313) Set("name", "Steve") Set("duration", "1s1ms") + Set("modes", []int{1, 2, 3}) type config struct { Port int Name string Duration time.Duration + Modes []int } var C config @@ -554,14 +556,33 @@ func TestUnmarshal(t *testing.T) { t.Fatalf("unable to decode into struct, %v", err) } - assert.Equal(t, &config{Name: "Steve", Port: 1313, Duration: time.Second + time.Millisecond}, &C) + assert.Equal( + t, + &config{ + Name: "Steve", + Port: 1313, + Duration: time.Second + time.Millisecond, + Modes: []int{1, 2, 3}, + }, + &C, + ) Set("port", 1234) err = Unmarshal(&C) if err != nil { t.Fatalf("unable to decode into struct, %v", err) } - assert.Equal(t, &config{Name: "Steve", Port: 1234, Duration: time.Second + time.Millisecond}, &C) + + assert.Equal( + t, + &config{ + Name: "Steve", + Port: 1234, + Duration: time.Second + time.Millisecond, + Modes: []int{1, 2, 3}, + }, + &C, + ) } func TestUnmarshalWithDecoderOptions(t *testing.T) { @@ -682,6 +703,51 @@ func TestBindPFlagsStringSlice(t *testing.T) { } } +func TestBindPFlagsIntSlice(t *testing.T) { + tests := []struct { + Expected []int + Value string + }{ + {nil, ""}, + {[]int{1}, "1"}, + {[]int{2, 3}, "2,3"}, + } + + v := New() // create independent Viper object + defaultVal := []int{0} + v.SetDefault("intslice", defaultVal) + + for _, testValue := range tests { + flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError) + flagSet.IntSlice("intslice", testValue.Expected, "test") + + for _, changed := range []bool{true, false} { + flagSet.VisitAll(func(f *pflag.Flag) { + f.Value.Set(testValue.Value) + f.Changed = changed + }) + + err := v.BindPFlags(flagSet) + if err != nil { + t.Fatalf("error binding flag set, %v", err) + } + + type TestInt struct { + IntSlice []int + } + val := &TestInt{} + if err := v.Unmarshal(val); err != nil { + t.Fatalf("%+#v cannot unmarshal: %s", testValue.Value, err) + } + if changed { + assert.Equal(t, testValue.Expected, val.IntSlice) + } else { + assert.Equal(t, defaultVal, val.IntSlice) + } + } + } +} + func TestBindPFlag(t *testing.T) { var testString = "testing" var testValue = newStringValue(testString, &testString)