Flag Binding Refactor

This patch alters the way flags are handled to coincide with the
documentation on the Viper README. The documentation indicated that flag
bindings were late, when in fact they were very, very early. This patch
changes flag bindings to behave as late bindings.
This commit is contained in:
akutz 2015-11-09 16:58:46 -06:00
parent a7ef020a9a
commit 690c1c1ef0

View file

@ -405,14 +405,32 @@ func (v *Viper) Get(key string) interface{} {
if val == nil {
source := v.find(path[0])
if source == nil {
return nil
}
if source != nil {
if reflect.TypeOf(source).Kind() == reflect.Map {
val = v.searchMap(cast.ToStringMap(source), path[1:])
}
}
}
// if no other value is returned and a flag does exist for the value,
// get the flag's value even if the flag's value has not changed
if val == nil {
if flag, exists := v.pflags[lcaseKey]; exists {
jww.TRACE.Println(key, "get pflag default", val)
switch flag.Value.Type() {
case "int", "int8", "int16", "int32", "int64":
val = cast.ToInt(flag.Value.String())
case "bool":
val = cast.ToBool(flag.Value.String())
default:
val = flag.Value.String()
}
}
}
if val == nil {
return nil
}
var valType interface{}
if !v.typeByDefValue {
@ -539,22 +557,11 @@ func (v *Viper) Unmarshal(rawVal interface{}) error {
func BindPFlags(flags *pflag.FlagSet) (err error) { return v.BindPFlags(flags) }
func (v *Viper) BindPFlags(flags *pflag.FlagSet) (err error) {
flags.VisitAll(func(flag *pflag.Flag) {
if err != nil {
// an error has been encountered in one of the previous flags
if err = v.BindPFlag(flag.Name, flag); err != nil {
return
}
err = v.BindPFlag(flag.Name, flag)
switch flag.Value.Type() {
case "int", "int8", "int16", "int32", "int64":
v.SetDefault(flag.Name, cast.ToInt(flag.Value.String()))
case "bool":
v.SetDefault(flag.Name, cast.ToBool(flag.Value.String()))
default:
v.SetDefault(flag.Name, flag.Value.String())
}
})
return
return nil
}
// Bind a specific key to a flag (as used by cobra)
@ -569,15 +576,6 @@ func (v *Viper) BindPFlag(key string, flag *pflag.Flag) (err error) {
return fmt.Errorf("flag for %q is nil", key)
}
v.pflags[strings.ToLower(key)] = flag
switch flag.Value.Type() {
case "int", "int8", "int16", "int32", "int64":
v.SetDefault(key, cast.ToInt(flag.Value.String()))
case "bool":
v.SetDefault(key, cast.ToBool(flag.Value.String()))
default:
v.SetDefault(key, flag.Value.String())
}
return nil
}
@ -618,9 +616,14 @@ func (v *Viper) find(key string) interface{} {
// PFlag Override first
flag, exists := v.pflags[key]
if exists {
if flag.Changed {
jww.TRACE.Println(key, "found in override (via pflag):", val)
if exists && flag.Changed {
jww.TRACE.Println(key, "found in override (via pflag):", flag.Value)
switch flag.Value.Type() {
case "int", "int8", "int16", "int32", "int64":
return cast.ToInt(flag.Value.String())
case "bool":
return cast.ToBool(flag.Value.String())
default:
return flag.Value.String()
}
}