Check for nil before binding pflag(s)

* When passing nil to BindPFlag or BindPFlags, the value is set to
  a struct and passed as an interface. That struct never checks for the
  flag(set) being nil. Thus, it makes sense to check before it's set to
  the struct.
* fixes #422
This commit is contained in:
Harley Laue 2018-04-23 10:24:03 -07:00
parent 8dc2790b02
commit a7a5948b15
2 changed files with 22 additions and 0 deletions

View file

@ -811,6 +811,9 @@ func (v *Viper) UnmarshalExact(rawVal interface{}) error {
// name as the config key. // name as the config key.
func BindPFlags(flags *pflag.FlagSet) error { return v.BindPFlags(flags) } func BindPFlags(flags *pflag.FlagSet) error { return v.BindPFlags(flags) }
func (v *Viper) BindPFlags(flags *pflag.FlagSet) error { func (v *Viper) BindPFlags(flags *pflag.FlagSet) error {
if flags == nil {
return fmt.Errorf("FlagSet cannot be nil")
}
return v.BindFlagValues(pflagValueSet{flags}) return v.BindFlagValues(pflagValueSet{flags})
} }
@ -822,6 +825,9 @@ func (v *Viper) BindPFlags(flags *pflag.FlagSet) error {
// //
func BindPFlag(key string, flag *pflag.Flag) error { return v.BindPFlag(key, flag) } func BindPFlag(key string, flag *pflag.Flag) error { return v.BindPFlag(key, flag) }
func (v *Viper) BindPFlag(key string, flag *pflag.Flag) error { func (v *Viper) BindPFlag(key string, flag *pflag.Flag) error {
if flag == nil {
return fmt.Errorf("flag for %q is nil", key)
}
return v.BindFlagValue(key, pflagValue{flag}) return v.BindFlagValue(key, pflagValue{flag})
} }

View file

@ -577,6 +577,14 @@ func TestBindPFlagsStringSlice(t *testing.T) {
} }
} }
func TestBindPFlagsNil(t *testing.T) {
v := New()
err := v.BindPFlags(nil)
if err == nil {
t.Fatalf("expected error when passing nil to BindPFlags")
}
}
func TestBindPFlag(t *testing.T) { func TestBindPFlag(t *testing.T) {
var testString = "testing" var testString = "testing"
var testValue = newStringValue(testString, &testString) var testValue = newStringValue(testString, &testString)
@ -598,6 +606,14 @@ func TestBindPFlag(t *testing.T) {
} }
func TestBindPFlagNil(t *testing.T) {
v := New()
err := v.BindPFlag("any", nil)
if err == nil {
t.Fatalf("expected error when passing nil to BindPFlag")
}
}
func TestBoundCaseSensitivity(t *testing.T) { func TestBoundCaseSensitivity(t *testing.T) {
assert.Equal(t, "brown", Get("eyes")) assert.Equal(t, "brown", Get("eyes"))