Properly handle string slice values

This commit is contained in:
Paweł Szczur 2017-04-17 10:08:15 +02:00 committed by Bjørn Erik Pedersen
parent 5d46e70da8
commit 0967fc9ace
2 changed files with 54 additions and 2 deletions

View file

@ -21,6 +21,7 @@ package viper
import ( import (
"bytes" "bytes"
"encoding/csv"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -894,7 +895,9 @@ func (v *Viper) find(lcaseKey string) interface{} {
return cast.ToBool(flag.ValueString()) return cast.ToBool(flag.ValueString())
case "stringSlice": case "stringSlice":
s := strings.TrimPrefix(flag.ValueString(), "[") s := strings.TrimPrefix(flag.ValueString(), "[")
return strings.TrimSuffix(s, "]") s = strings.TrimSuffix(s, "]")
res, _ := readAsCSV(s)
return res
default: default:
return flag.ValueString() return flag.ValueString()
} }
@ -961,7 +964,9 @@ func (v *Viper) find(lcaseKey string) interface{} {
return cast.ToBool(flag.ValueString()) return cast.ToBool(flag.ValueString())
case "stringSlice": case "stringSlice":
s := strings.TrimPrefix(flag.ValueString(), "[") s := strings.TrimPrefix(flag.ValueString(), "[")
return strings.TrimSuffix(s, "]") s = strings.TrimSuffix(s, "]")
res, _ := readAsCSV(s)
return res
default: default:
return flag.ValueString() return flag.ValueString()
} }
@ -971,6 +976,15 @@ func (v *Viper) find(lcaseKey string) interface{} {
return nil return nil
} }
func readAsCSV(val string) ([]string, error) {
if val == "" {
return []string{}, nil
}
stringReader := strings.NewReader(val)
csvReader := csv.NewReader(stringReader)
return csvReader.Read()
}
// IsSet checks to see if the key has been set in any of the data locations. // IsSet checks to see if the key has been set in any of the data locations.
// IsSet is case-insensitive for a key. // IsSet is case-insensitive for a key.
func IsSet(key string) bool { return v.IsSet(key) } func IsSet(key string) bool { return v.IsSet(key) }

View file

@ -538,6 +538,44 @@ func TestBindPFlags(t *testing.T) {
} }
func TestBindPFlagsStringSlice(t *testing.T) {
for _, testValue := range []struct {
Expected []string
Value string
}{
{[]string{}, ""},
{[]string{"jeden"}, "jeden"},
{[]string{"dwa", "trzy"}, "dwa,trzy"},
{[]string{"cztery", "piec , szesc"}, "cztery,\"piec , szesc\""}} {
for _, changed := range []bool{true, false} {
v := New() // create independent Viper object
flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError)
flagSet.StringSlice("stringslice", testValue.Expected, "test")
flagSet.Visit(func(f *pflag.Flag) {
if len(testValue.Value) > 0 {
f.Value.Set(testValue.Value)
f.Changed = changed
}
})
err := v.BindPFlags(flagSet)
if err != nil {
t.Fatalf("error binding flag set, %v", err)
}
type TestStr struct {
StringSlice []string
}
val := &TestStr{}
if err := v.Unmarshal(val); err != nil {
t.Fatalf("%+#v cannot unmarshal: %s", testValue.Value, err)
}
assert.Equal(t, testValue.Expected, val.StringSlice)
}
}
}
func TestBindPFlag(t *testing.T) { func TestBindPFlag(t *testing.T) {
var testString = "testing" var testString = "testing"
var testValue = newStringValue(testString, &testString) var testValue = newStringValue(testString, &testString)