From 9e8198962c30ba5df5822a060ea68ac1d03ff35d Mon Sep 17 00:00:00 2001 From: David Calavera Date: Thu, 10 Dec 2015 13:14:17 -0500 Subject: [PATCH] Add FlagValue interface to support other flag systems. Using an interface allows people to use their favourite flag system with viper without being restricted to the semantics of pflag or the standard library. This change introduce two new functions `BindFlagValues` and `BindFlagValue` that behave like `BindFlags` and `BindFlag` but using the new interface as values. This change also introduces two internal structures to transform `*pflag.FlagSet` and `*pflag.Flag` into the new interface. This way, viper keeps working as expected for people that are currently using the pflag package without breaking backwards compatibility. Signed-off-by: David Calavera --- flags.go | 57 ++++++++++++++++++++++++++++++++++++++++++++ flags_test.go | 66 +++++++++++++++++++++++++++++++++++++++++++++++++++ viper.go | 56 ++++++++++++++++++++++++++++--------------- 3 files changed, 160 insertions(+), 19 deletions(-) create mode 100644 flags.go create mode 100644 flags_test.go diff --git a/flags.go b/flags.go new file mode 100644 index 0000000..dd32f4e --- /dev/null +++ b/flags.go @@ -0,0 +1,57 @@ +package viper + +import "github.com/spf13/pflag" + +// FlagValueSet is an interface that users can implement +// to bind a set of flags to viper. +type FlagValueSet interface { + VisitAll(fn func(FlagValue)) +} + +// FlagValue is an interface that users can implement +// to bind different flags to viper. +type FlagValue interface { + HasChanged() bool + Name() string + ValueString() string + ValueType() string +} + +// pflagValueSet is a wrapper around *pflag.ValueSet +// that implements FlagValueSet. +type pflagValueSet struct { + flags *pflag.FlagSet +} + +// VisitAll iterates over all *pflag.Flag inside the *pflag.FlagSet. +func (p pflagValueSet) VisitAll(fn func(flag FlagValue)) { + p.flags.VisitAll(func(flag *pflag.Flag) { + fn(pflagValue{flag}) + }) +} + +// pflagValue is a wrapper aroung *pflag.flag +// that implements FlagValue +type pflagValue struct { + flag *pflag.Flag +} + +// HasChanges returns whether the flag has changes or not. +func (p pflagValue) HasChanged() bool { + return p.flag.Changed +} + +// Name returns the name of the flag. +func (p pflagValue) Name() string { + return p.flag.Name +} + +// ValueString returns the value of the flag as a string. +func (p pflagValue) ValueString() string { + return p.flag.Value.String() +} + +// ValueType returns the type of the flag as a string. +func (p pflagValue) ValueType() string { + return p.flag.Value.Type() +} diff --git a/flags_test.go b/flags_test.go new file mode 100644 index 0000000..5489278 --- /dev/null +++ b/flags_test.go @@ -0,0 +1,66 @@ +package viper + +import ( + "testing" + + "github.com/spf13/pflag" + "github.com/stretchr/testify/assert" +) + +func TestBindFlagValueSet(t *testing.T) { + flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError) + + var testValues = map[string]*string{ + "host": nil, + "port": nil, + "endpoint": nil, + } + + var mutatedTestValues = map[string]string{ + "host": "localhost", + "port": "6060", + "endpoint": "/public", + } + + for name, _ := range testValues { + testValues[name] = flagSet.String(name, "", "test") + } + + flagValueSet := pflagValueSet{flagSet} + + err := BindFlagValues(flagValueSet) + if err != nil { + t.Fatalf("error binding flag set, %v", err) + } + + flagSet.VisitAll(func(flag *pflag.Flag) { + flag.Value.Set(mutatedTestValues[flag.Name]) + flag.Changed = true + }) + + for name, expected := range mutatedTestValues { + assert.Equal(t, Get(name), expected) + } +} + +func TestBindFlagValue(t *testing.T) { + var testString = "testing" + var testValue = newStringValue(testString, &testString) + + flag := &pflag.Flag{ + Name: "testflag", + Value: testValue, + Changed: false, + } + + flagValue := pflagValue{flag} + BindFlagValue("testvalue", flagValue) + + assert.Equal(t, testString, Get("testvalue")) + + flag.Value.Set("testing_mutate") + flag.Changed = true //hack for pflag usage + + assert.Equal(t, "testing_mutate", Get("testvalue")) + +} diff --git a/viper.go b/viper.go index 6f204b1..3313f11 100644 --- a/viper.go +++ b/viper.go @@ -149,7 +149,7 @@ type Viper struct { override map[string]interface{} defaults map[string]interface{} kvstore map[string]interface{} - pflags map[string]*pflag.Flag + pflags map[string]FlagValue env map[string]string aliases map[string]string typeByDefValue bool @@ -166,7 +166,7 @@ func New() *Viper { v.override = make(map[string]interface{}) v.defaults = make(map[string]interface{}) v.kvstore = make(map[string]interface{}) - v.pflags = make(map[string]*pflag.Flag) + v.pflags = make(map[string]FlagValue) v.env = make(map[string]string) v.aliases = make(map[string]string) v.typeByDefValue = false @@ -467,13 +467,13 @@ func (v *Viper) Get(key string) interface{} { if val == nil { if flag, exists := v.pflags[lcaseKey]; exists { jww.TRACE.Println(key, "get pflag default", val) - switch flag.Value.Type() { + switch flag.ValueType() { case "int", "int8", "int16", "int32", "int64": - val = cast.ToInt(flag.Value.String()) + val = cast.ToInt(flag.ValueString()) case "bool": - val = cast.ToBool(flag.Value.String()) + val = cast.ToBool(flag.ValueString()) default: - val = flag.Value.String() + val = flag.ValueString() } } } @@ -606,15 +606,10 @@ func (v *Viper) Unmarshal(rawVal interface{}) error { // name as the config key. func BindPFlags(flags *pflag.FlagSet) (err error) { return v.BindPFlags(flags) } func (v *Viper) BindPFlags(flags *pflag.FlagSet) (err error) { - flags.VisitAll(func(flag *pflag.Flag) { - if err = v.BindPFlag(flag.Name, flag); err != nil { - return - } - }) - return nil + return v.BindFlagValues(pflagValueSet{flags}) } -// Bind a specific key to a flag (as used by cobra) +// Bind a specific key to a pflag (as used by cobra) // Example(where serverCmd is a Cobra instance): // // serverCmd.Flags().Int("port", 1138, "Port to run Application server on") @@ -622,6 +617,29 @@ func (v *Viper) BindPFlags(flags *pflag.FlagSet) (err error) { // func BindPFlag(key string, flag *pflag.Flag) (err error) { return v.BindPFlag(key, flag) } func (v *Viper) BindPFlag(key string, flag *pflag.Flag) (err error) { + return v.BindFlagValue(key, pflagValue{flag}) +} + +// Bind a full FlagValue set to the configuration, using each flag's long +// name as the config key. +func BindFlagValues(flags FlagValueSet) (err error) { return v.BindFlagValues(flags) } +func (v *Viper) BindFlagValues(flags FlagValueSet) (err error) { + flags.VisitAll(func(flag FlagValue) { + if err = v.BindFlagValue(flag.Name(), flag); err != nil { + return + } + }) + return nil +} + +// Bind a specific key to a FlagValue. +// Example(where serverCmd is a Cobra instance): +// +// serverCmd.Flags().Int("port", 1138, "Port to run Application server on") +// Viper.BindFlagValue("port", serverCmd.Flags().Lookup("port")) +// +func BindFlagValue(key string, flag FlagValue) (err error) { return v.BindFlagValue(key, flag) } +func (v *Viper) BindFlagValue(key string, flag FlagValue) (err error) { if flag == nil { return fmt.Errorf("flag for %q is nil", key) } @@ -666,15 +684,15 @@ func (v *Viper) find(key string) interface{} { // PFlag Override first flag, exists := v.pflags[key] - if exists && flag.Changed { - jww.TRACE.Println(key, "found in override (via pflag):", flag.Value) - switch flag.Value.Type() { + if exists && flag.HasChanged() { + jww.TRACE.Println(key, "found in override (via pflag):", flag.ValueString()) + switch flag.ValueType() { case "int", "int8", "int16", "int32", "int64": - return cast.ToInt(flag.Value.String()) + return cast.ToInt(flag.ValueString()) case "bool": - return cast.ToBool(flag.Value.String()) + return cast.ToBool(flag.ValueString()) default: - return flag.Value.String() + return flag.ValueString() } }