diff --git a/viper.go b/viper.go index d9c5e67..af6e5ed 100644 --- a/viper.go +++ b/viper.go @@ -405,15 +405,33 @@ func (v *Viper) Get(key string) interface{} { if val == nil { source := v.find(path[0]) - if source == nil { - return nil + if source != nil { + if reflect.TypeOf(source).Kind() == reflect.Map { + val = v.searchMap(cast.ToStringMap(source), path[1:]) + } } + } - if reflect.TypeOf(source).Kind() == reflect.Map { - val = v.searchMap(cast.ToStringMap(source), path[1:]) + // if no other value is returned and a flag does exist for the value, + // get the flag's value even if the flag's value has not changed + if val == nil { + if flag, exists := v.pflags[lcaseKey]; exists { + jww.TRACE.Println(key, "get pflag default", val) + switch flag.Value.Type() { + case "int", "int8", "int16", "int32", "int64": + val = cast.ToInt(flag.Value.String()) + case "bool": + val = cast.ToBool(flag.Value.String()) + default: + val = flag.Value.String() + } } } + if val == nil { + return nil + } + var valType interface{} if !v.typeByDefValue { valType = val @@ -539,22 +557,11 @@ func (v *Viper) Unmarshal(rawVal interface{}) error { func BindPFlags(flags *pflag.FlagSet) (err error) { return v.BindPFlags(flags) } func (v *Viper) BindPFlags(flags *pflag.FlagSet) (err error) { flags.VisitAll(func(flag *pflag.Flag) { - if err != nil { - // an error has been encountered in one of the previous flags + if err = v.BindPFlag(flag.Name, flag); err != nil { return } - - err = v.BindPFlag(flag.Name, flag) - switch flag.Value.Type() { - case "int", "int8", "int16", "int32", "int64": - v.SetDefault(flag.Name, cast.ToInt(flag.Value.String())) - case "bool": - v.SetDefault(flag.Name, cast.ToBool(flag.Value.String())) - default: - v.SetDefault(flag.Name, flag.Value.String()) - } }) - return + return nil } // Bind a specific key to a flag (as used by cobra) @@ -569,15 +576,6 @@ func (v *Viper) BindPFlag(key string, flag *pflag.Flag) (err error) { return fmt.Errorf("flag for %q is nil", key) } v.pflags[strings.ToLower(key)] = flag - - switch flag.Value.Type() { - case "int", "int8", "int16", "int32", "int64": - v.SetDefault(key, cast.ToInt(flag.Value.String())) - case "bool": - v.SetDefault(key, cast.ToBool(flag.Value.String())) - default: - v.SetDefault(key, flag.Value.String()) - } return nil } @@ -618,9 +616,14 @@ func (v *Viper) find(key string) interface{} { // PFlag Override first flag, exists := v.pflags[key] - if exists { - if flag.Changed { - jww.TRACE.Println(key, "found in override (via pflag):", val) + if exists && flag.Changed { + jww.TRACE.Println(key, "found in override (via pflag):", flag.Value) + switch flag.Value.Type() { + case "int", "int8", "int16", "int32", "int64": + return cast.ToInt(flag.Value.String()) + case "bool": + return cast.ToBool(flag.Value.String()) + default: return flag.Value.String() } }