mirror of
https://github.com/spf13/viper
synced 2024-12-22 19:47:01 +00:00
Flag Binding Refactor
This patch alters the way flags are handled to coincide with the documentation on the Viper README. The documentation indicated that flag bindings were late, when in fact they were very, very early. This patch changes flag bindings to behave as late bindings.
This commit is contained in:
parent
a7ef020a9a
commit
690c1c1ef0
1 changed files with 32 additions and 29 deletions
61
viper.go
61
viper.go
|
@ -405,15 +405,33 @@ func (v *Viper) Get(key string) interface{} {
|
||||||
|
|
||||||
if val == nil {
|
if val == nil {
|
||||||
source := v.find(path[0])
|
source := v.find(path[0])
|
||||||
if source == nil {
|
if source != nil {
|
||||||
return nil
|
if reflect.TypeOf(source).Kind() == reflect.Map {
|
||||||
|
val = v.searchMap(cast.ToStringMap(source), path[1:])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if reflect.TypeOf(source).Kind() == reflect.Map {
|
// if no other value is returned and a flag does exist for the value,
|
||||||
val = v.searchMap(cast.ToStringMap(source), path[1:])
|
// get the flag's value even if the flag's value has not changed
|
||||||
|
if val == nil {
|
||||||
|
if flag, exists := v.pflags[lcaseKey]; exists {
|
||||||
|
jww.TRACE.Println(key, "get pflag default", val)
|
||||||
|
switch flag.Value.Type() {
|
||||||
|
case "int", "int8", "int16", "int32", "int64":
|
||||||
|
val = cast.ToInt(flag.Value.String())
|
||||||
|
case "bool":
|
||||||
|
val = cast.ToBool(flag.Value.String())
|
||||||
|
default:
|
||||||
|
val = flag.Value.String()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if val == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var valType interface{}
|
var valType interface{}
|
||||||
if !v.typeByDefValue {
|
if !v.typeByDefValue {
|
||||||
valType = val
|
valType = val
|
||||||
|
@ -539,22 +557,11 @@ func (v *Viper) Unmarshal(rawVal interface{}) error {
|
||||||
func BindPFlags(flags *pflag.FlagSet) (err error) { return v.BindPFlags(flags) }
|
func BindPFlags(flags *pflag.FlagSet) (err error) { return v.BindPFlags(flags) }
|
||||||
func (v *Viper) BindPFlags(flags *pflag.FlagSet) (err error) {
|
func (v *Viper) BindPFlags(flags *pflag.FlagSet) (err error) {
|
||||||
flags.VisitAll(func(flag *pflag.Flag) {
|
flags.VisitAll(func(flag *pflag.Flag) {
|
||||||
if err != nil {
|
if err = v.BindPFlag(flag.Name, flag); err != nil {
|
||||||
// an error has been encountered in one of the previous flags
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = v.BindPFlag(flag.Name, flag)
|
|
||||||
switch flag.Value.Type() {
|
|
||||||
case "int", "int8", "int16", "int32", "int64":
|
|
||||||
v.SetDefault(flag.Name, cast.ToInt(flag.Value.String()))
|
|
||||||
case "bool":
|
|
||||||
v.SetDefault(flag.Name, cast.ToBool(flag.Value.String()))
|
|
||||||
default:
|
|
||||||
v.SetDefault(flag.Name, flag.Value.String())
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bind a specific key to a flag (as used by cobra)
|
// Bind a specific key to a flag (as used by cobra)
|
||||||
|
@ -569,15 +576,6 @@ func (v *Viper) BindPFlag(key string, flag *pflag.Flag) (err error) {
|
||||||
return fmt.Errorf("flag for %q is nil", key)
|
return fmt.Errorf("flag for %q is nil", key)
|
||||||
}
|
}
|
||||||
v.pflags[strings.ToLower(key)] = flag
|
v.pflags[strings.ToLower(key)] = flag
|
||||||
|
|
||||||
switch flag.Value.Type() {
|
|
||||||
case "int", "int8", "int16", "int32", "int64":
|
|
||||||
v.SetDefault(key, cast.ToInt(flag.Value.String()))
|
|
||||||
case "bool":
|
|
||||||
v.SetDefault(key, cast.ToBool(flag.Value.String()))
|
|
||||||
default:
|
|
||||||
v.SetDefault(key, flag.Value.String())
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -618,9 +616,14 @@ func (v *Viper) find(key string) interface{} {
|
||||||
|
|
||||||
// PFlag Override first
|
// PFlag Override first
|
||||||
flag, exists := v.pflags[key]
|
flag, exists := v.pflags[key]
|
||||||
if exists {
|
if exists && flag.Changed {
|
||||||
if flag.Changed {
|
jww.TRACE.Println(key, "found in override (via pflag):", flag.Value)
|
||||||
jww.TRACE.Println(key, "found in override (via pflag):", val)
|
switch flag.Value.Type() {
|
||||||
|
case "int", "int8", "int16", "int32", "int64":
|
||||||
|
return cast.ToInt(flag.Value.String())
|
||||||
|
case "bool":
|
||||||
|
return cast.ToBool(flag.Value.String())
|
||||||
|
default:
|
||||||
return flag.Value.String()
|
return flag.Value.String()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue