diff --git a/flags.go b/flags.go new file mode 100644 index 0000000..dd32f4e --- /dev/null +++ b/flags.go @@ -0,0 +1,57 @@ +package viper + +import "github.com/spf13/pflag" + +// FlagValueSet is an interface that users can implement +// to bind a set of flags to viper. +type FlagValueSet interface { + VisitAll(fn func(FlagValue)) +} + +// FlagValue is an interface that users can implement +// to bind different flags to viper. +type FlagValue interface { + HasChanged() bool + Name() string + ValueString() string + ValueType() string +} + +// pflagValueSet is a wrapper around *pflag.ValueSet +// that implements FlagValueSet. +type pflagValueSet struct { + flags *pflag.FlagSet +} + +// VisitAll iterates over all *pflag.Flag inside the *pflag.FlagSet. +func (p pflagValueSet) VisitAll(fn func(flag FlagValue)) { + p.flags.VisitAll(func(flag *pflag.Flag) { + fn(pflagValue{flag}) + }) +} + +// pflagValue is a wrapper aroung *pflag.flag +// that implements FlagValue +type pflagValue struct { + flag *pflag.Flag +} + +// HasChanges returns whether the flag has changes or not. +func (p pflagValue) HasChanged() bool { + return p.flag.Changed +} + +// Name returns the name of the flag. +func (p pflagValue) Name() string { + return p.flag.Name +} + +// ValueString returns the value of the flag as a string. +func (p pflagValue) ValueString() string { + return p.flag.Value.String() +} + +// ValueType returns the type of the flag as a string. +func (p pflagValue) ValueType() string { + return p.flag.Value.Type() +} diff --git a/flags_test.go b/flags_test.go new file mode 100644 index 0000000..5489278 --- /dev/null +++ b/flags_test.go @@ -0,0 +1,66 @@ +package viper + +import ( + "testing" + + "github.com/spf13/pflag" + "github.com/stretchr/testify/assert" +) + +func TestBindFlagValueSet(t *testing.T) { + flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError) + + var testValues = map[string]*string{ + "host": nil, + "port": nil, + "endpoint": nil, + } + + var mutatedTestValues = map[string]string{ + "host": "localhost", + "port": "6060", + "endpoint": "/public", + } + + for name, _ := range testValues { + testValues[name] = flagSet.String(name, "", "test") + } + + flagValueSet := pflagValueSet{flagSet} + + err := BindFlagValues(flagValueSet) + if err != nil { + t.Fatalf("error binding flag set, %v", err) + } + + flagSet.VisitAll(func(flag *pflag.Flag) { + flag.Value.Set(mutatedTestValues[flag.Name]) + flag.Changed = true + }) + + for name, expected := range mutatedTestValues { + assert.Equal(t, Get(name), expected) + } +} + +func TestBindFlagValue(t *testing.T) { + var testString = "testing" + var testValue = newStringValue(testString, &testString) + + flag := &pflag.Flag{ + Name: "testflag", + Value: testValue, + Changed: false, + } + + flagValue := pflagValue{flag} + BindFlagValue("testvalue", flagValue) + + assert.Equal(t, testString, Get("testvalue")) + + flag.Value.Set("testing_mutate") + flag.Changed = true //hack for pflag usage + + assert.Equal(t, "testing_mutate", Get("testvalue")) + +} diff --git a/viper.go b/viper.go index a671c1f..b253c75 100644 --- a/viper.go +++ b/viper.go @@ -149,7 +149,7 @@ type Viper struct { override map[string]interface{} defaults map[string]interface{} kvstore map[string]interface{} - pflags map[string]*pflag.Flag + pflags map[string]FlagValue env map[string]string aliases map[string]string typeByDefValue bool @@ -166,7 +166,7 @@ func New() *Viper { v.override = make(map[string]interface{}) v.defaults = make(map[string]interface{}) v.kvstore = make(map[string]interface{}) - v.pflags = make(map[string]*pflag.Flag) + v.pflags = make(map[string]FlagValue) v.env = make(map[string]string) v.aliases = make(map[string]string) v.typeByDefValue = false @@ -467,13 +467,13 @@ func (v *Viper) Get(key string) interface{} { if val == nil { if flag, exists := v.pflags[lcaseKey]; exists { jww.TRACE.Println(key, "get pflag default", val) - switch flag.Value.Type() { + switch flag.ValueType() { case "int", "int8", "int16", "int32", "int64": - val = cast.ToInt(flag.Value.String()) + val = cast.ToInt(flag.ValueString()) case "bool": - val = cast.ToBool(flag.Value.String()) + val = cast.ToBool(flag.ValueString()) default: - val = flag.Value.String() + val = flag.ValueString() } } } @@ -618,15 +618,10 @@ func (v *Viper) Unmarshal(rawVal interface{}) error { // name as the config key. func BindPFlags(flags *pflag.FlagSet) (err error) { return v.BindPFlags(flags) } func (v *Viper) BindPFlags(flags *pflag.FlagSet) (err error) { - flags.VisitAll(func(flag *pflag.Flag) { - if err = v.BindPFlag(flag.Name, flag); err != nil { - return - } - }) - return nil + return v.BindFlagValues(pflagValueSet{flags}) } -// Bind a specific key to a flag (as used by cobra) +// Bind a specific key to a pflag (as used by cobra) // Example(where serverCmd is a Cobra instance): // // serverCmd.Flags().Int("port", 1138, "Port to run Application server on") @@ -634,6 +629,29 @@ func (v *Viper) BindPFlags(flags *pflag.FlagSet) (err error) { // func BindPFlag(key string, flag *pflag.Flag) (err error) { return v.BindPFlag(key, flag) } func (v *Viper) BindPFlag(key string, flag *pflag.Flag) (err error) { + return v.BindFlagValue(key, pflagValue{flag}) +} + +// Bind a full FlagValue set to the configuration, using each flag's long +// name as the config key. +func BindFlagValues(flags FlagValueSet) (err error) { return v.BindFlagValues(flags) } +func (v *Viper) BindFlagValues(flags FlagValueSet) (err error) { + flags.VisitAll(func(flag FlagValue) { + if err = v.BindFlagValue(flag.Name(), flag); err != nil { + return + } + }) + return nil +} + +// Bind a specific key to a FlagValue. +// Example(where serverCmd is a Cobra instance): +// +// serverCmd.Flags().Int("port", 1138, "Port to run Application server on") +// Viper.BindFlagValue("port", serverCmd.Flags().Lookup("port")) +// +func BindFlagValue(key string, flag FlagValue) (err error) { return v.BindFlagValue(key, flag) } +func (v *Viper) BindFlagValue(key string, flag FlagValue) (err error) { if flag == nil { return fmt.Errorf("flag for %q is nil", key) } @@ -678,15 +696,15 @@ func (v *Viper) find(key string) interface{} { // PFlag Override first flag, exists := v.pflags[key] - if exists && flag.Changed { - jww.TRACE.Println(key, "found in override (via pflag):", flag.Value) - switch flag.Value.Type() { + if exists && flag.HasChanged() { + jww.TRACE.Println(key, "found in override (via pflag):", flag.ValueString()) + switch flag.ValueType() { case "int", "int8", "int16", "int32", "int64": - return cast.ToInt(flag.Value.String()) + return cast.ToInt(flag.ValueString()) case "bool": - return cast.ToBool(flag.Value.String()) + return cast.ToBool(flag.ValueString()) default: - return flag.Value.String() + return flag.ValueString() } }