diff --git a/flags_test.go b/flags_test.go index 0b976b6..1dd7b0d 100644 --- a/flags_test.go +++ b/flags_test.go @@ -45,7 +45,7 @@ func TestBindFlagValueSet(t *testing.T) { func TestBindFlagValue(t *testing.T) { var testString = "testing" - var testValue = newStringValue(testString, &testString) + var testValue = newStringValue(testString) flag := &pflag.Flag{ Name: "testflag", diff --git a/viper.go b/viper.go index c166e9f..104a657 100644 --- a/viper.go +++ b/viper.go @@ -612,7 +612,7 @@ func GetViper() *Viper { func Get(key string) interface{} { return v.Get(key) } func (v *Viper) Get(key string) interface{} { lcaseKey := strings.ToLower(key) - val := v.find(lcaseKey) + val := v.find(lcaseKey, true) if val == nil { return nil } @@ -885,9 +885,12 @@ func (v *Viper) BindEnv(input ...string) error { // Given a key, find the value. // Viper will check in the following order: // flag, env, config file, key/value store, default. -// Viper will check to see if an alias exists first. -// Note: this assumes a lower-cased key given. -func (v *Viper) find(lcaseKey string) interface{} { +// Viper will then check in the following order: +// flag, env, config file, key/value store. +// Lastly, if no value was found and flagDefault is true, and if the key +// corresponds to a flag, the flag's default value is returned. +// +func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} { var ( val interface{} @@ -1003,6 +1006,18 @@ func (v *Viper) find(lcaseKey string) interface{} { } // last item, no need to check shadowing + // it could also be a key prefix, search for that prefix to get the values from + // pflags that match it + sub := make(map[string]interface{}) + for key, val := range v.pflags { + if flagDefault && strings.HasPrefix(key, lcaseKey) { + sub[strings.TrimPrefix(key, lcaseKey+".")] = val.ValueString() + } + } + if len(sub) != 0 { + return sub + } + return nil } @@ -1020,7 +1035,7 @@ func readAsCSV(val string) ([]string, error) { func IsSet(key string) bool { return v.IsSet(key) } func (v *Viper) IsSet(key string) bool { lcaseKey := strings.ToLower(key) - val := v.find(lcaseKey) + val := v.find(lcaseKey, false) return val != nil } diff --git a/viper_test.go b/viper_test.go index 443345e..84326f5 100644 --- a/viper_test.go +++ b/viper_test.go @@ -224,9 +224,8 @@ func initDirs(t *testing.T) (string, string, func()) { //stubs for PFlag Values type stringValue string -func newStringValue(val string, p *string) *stringValue { - *p = val - return (*stringValue)(p) +func newStringValue(val string) *stringValue { + return (*stringValue)(&val) } func (s *stringValue) Set(val string) error { @@ -587,7 +586,7 @@ func TestBindPFlagsNil(t *testing.T) { func TestBindPFlag(t *testing.T) { var testString = "testing" - var testValue = newStringValue(testString, &testString) + var testValue = newStringValue(testString) flag := &pflag.Flag{ Name: "testflag", @@ -623,7 +622,7 @@ func TestBoundCaseSensitivity(t *testing.T) { assert.Equal(t, "blue", Get("eyes")) var testString = "green" - var testValue = newStringValue(testString, &testString) + var testValue = newStringValue(testString) flag := &pflag.Flag{ Name: "eyeballs", @@ -864,6 +863,31 @@ func TestSub(t *testing.T) { assert.Equal(t, (*Viper)(nil), subv) } +func TestSubPflags(t *testing.T) { + v := New() + + // same as yamlExample, without hobbies + v.BindPFlag("name", &pflag.Flag{Value: newStringValue("steve"), Changed: true}) + v.BindPFlag("clothing.jacket", &pflag.Flag{Value: newStringValue("leather"), Changed: true}) + v.BindPFlag("clothing.trousers", &pflag.Flag{Value: newStringValue("denim"), Changed: true}) + v.BindPFlag("clothing.pants.size", &pflag.Flag{Value: newStringValue("large"), Changed: true}) + v.BindPFlag("age", &pflag.Flag{Value: newStringValue("35"), Changed: true}) + v.BindPFlag("eyes", &pflag.Flag{Value: newStringValue("brown"), Changed: true}) + v.BindPFlag("beard", &pflag.Flag{Value: newStringValue("yes"), Changed: true}) + + subv := v.Sub("clothing") + assert.Equal(t, v.Get("clothing.pants.size"), subv.Get("pants.size")) + + subv = v.Sub("clothing.pants") + assert.Equal(t, v.Get("clothing.pants.size"), subv.Get("size")) + + subv = v.Sub("clothing.pants.size") + assert.Equal(t, (*Viper)(nil), subv) + + subv = v.Sub("missing.key") + assert.Equal(t, (*Viper)(nil), subv) +} + var hclWriteExpected = []byte(`"foos" = { "foo" = { "key" = 1