From 24b9be48054e0fe8b97672f3c59a891c1904e229 Mon Sep 17 00:00:00 2001 From: Bill Robbins Date: Fri, 6 Feb 2015 14:35:00 -0600 Subject: [PATCH] adding cascading file support, off by default --- .gitignore | 1 + viper.go | 85 ++++++++++++++++++++++++++++++++++++++++++++--- viper_test.go | 91 +++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 173 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 8365624..31ec60d 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ # Folders _obj _test +.idea # Architecture specific extensions/prefixes *.[568vq] diff --git a/viper.go b/viper.go index 89834b7..a4baa45 100644 --- a/viper.go +++ b/viper.go @@ -77,11 +77,13 @@ type viper struct { envPrefix string automaticEnvApplied bool + cascadeConfigurations bool config map[string]interface{} override map[string]interface{} defaults map[string]interface{} kvstore map[string]interface{} + cascadingConfigs map[string]map[string]interface{} pflags map[string]*pflag.Flag env map[string]string aliases map[string]string @@ -98,6 +100,7 @@ func New() *viper { v.pflags = make(map[string]*pflag.Flag) v.env = make(map[string]string) v.aliases = make(map[string]string) + v.cascadeConfigurations = false return v } @@ -136,6 +139,12 @@ func (v *viper) SetEnvPrefix(in string) { } } +// Enable cascading configuration values for files. Will traverse down +// ConfigPaths in an attempt to find keys +func (v *viper) EnableCascading(enable bool){ + v.cascadeConfigurations = enable; +} + func (v *viper) mergeWithEnvPrefix(in string) string { if v.envPrefix != "" { return strings.ToUpper(v.envPrefix + "_" + in) @@ -389,6 +398,7 @@ func (v *viper) BindEnv(input ...string) (err error) { func (v *viper) find(key string) interface{} { var val interface{} var exists bool + var file string // if the requested key is an alias, then return the proper key key = v.realKey(key) @@ -434,6 +444,15 @@ func (v *viper) find(key string) interface{} { return val } + if( v.cascadeConfigurations) { + //cascade down the rest of the files + val, exists, file = v.findCascading(key) + if exists { + jww.TRACE.Printf("%s found in config: %s (%s)", key, val, file) + return val + } + } + val, exists = v.kvstore[key] if exists { jww.TRACE.Println(key, "found in key/value store:", val) @@ -449,6 +468,49 @@ func (v *viper) find(key string) interface{} { return nil } +func (v *viper) findCascading(key string) (interface{}, bool, string) { + + if( v.cascadingConfigs != nil ){ + for file,config := range v.cascadingConfigs { + result := config[key] + if( result != nil ){ + return result,true,file + } + } + } + + v.cascadingConfigs = make(map[string]map[string]interface{}) + + configFiles := v.findAllConfigFiles() + for _, configFile := range configFiles { + + if(v.cascadingConfigs[configFile] != nil){ + //already cached + continue + } + + jww.TRACE.Printf("Looking in %s for key %s",configFile,key) + file, err := ioutil.ReadFile(configFile) + if err != nil { + jww.ERROR.Print(err) + continue + } + + jww.TRACE.Printf("marshalling %s for cascading",configFile) + var config = make(map[string]interface{}) + + marshallConfigReader(bytes.NewReader(file), config, filepath.Ext(configFile)[1:]) + v.cascadingConfigs[configFile] = config + + result := config[key] + + if( result != nil){ + return result,true,configFile + } + } + return "", false, "" +} + // Check to see if the key has been set in any of the data locations func IsSet(key string) bool { return v.IsSet(key) } func (v *viper) IsSet(key string) bool { @@ -733,26 +795,41 @@ func (v *viper) searchInPath(in string) (filename string) { func (v *viper) findConfigFile() (string, error) { jww.INFO.Println("Searching for config in ", v.configPaths) + var validFiles = v.findAllConfigFiles() + + if len(validFiles) == 0 { + return "", fmt.Errorf("config file not found in: %s", v.configPaths) + } + + return validFiles[0], nil +} + +func (v *viper) findAllConfigFiles() []string { + + var validFiles []string for _, cp := range v.configPaths { file := v.searchInPath(cp) if file != "" { - return file, nil + jww.TRACE.Println("Found config file in: %s",file) + validFiles = append(validFiles, file) } } cwd, _ := findCWD() file := v.searchInPath(cwd) if file != "" { - return file, nil + validFiles = append(validFiles, file) } // try the current working directory wd, _ := os.Getwd() file = v.searchInPath(wd) if file != "" { - return file, nil + validFiles = append(validFiles, file) } - return "", fmt.Errorf("config file not found in: %s", v.configPaths) + + + return validFiles } func Debug() { v.Debug() } diff --git a/viper_test.go b/viper_test.go index ce1aa3c..0bbe84f 100644 --- a/viper_test.go +++ b/viper_test.go @@ -12,6 +12,9 @@ import ( "sort" "testing" "time" + "os/exec" + "path" + "io/ioutil" "github.com/spf13/pflag" "github.com/stretchr/testify/assert" @@ -352,3 +355,91 @@ func TestBoundCaseSensitivity(t *testing.T) { assert.Equal(t, "green", Get("eyes")) } + +func TestCanCascadeConfigurationValues(t *testing.T) { + + v2 := New() + + generateCascadingTests(v2,"cascading") + + v2.ReadInConfig() + v2.EnableCascading(true) + + assert.Equal(t,"high",v2.GetString("0"),"Key 0 should be high") + assert.Equal(t,"med",v2.GetString("1"),"Key 1 should be med") + assert.Equal(t,"low",v2.GetString("2"),"key 2 should be low") + + v2.EnableCascading(false) + + assert.Nil(t,v2.Get("1"),"With enable cascading disabled, no value for 1 should exist") + assert.Nil(t,v2.Get("2"),"With enable cascading disabled, no value for 2 should exist") +} + +func TestFindAllConfigPaths(t *testing.T){ + + v2 := New() + + file := "viper_test" + + var expected = generateCascadingTests(v2,file) + + found := v2.findAllConfigFiles() + + for _,fp := range expected{ + command := exec.Command("rm",fp) + command.Run() + } + + assert.Equal(t,expected,found,"All files should exist") +} + +func generateCascadingTests(v2 *viper, file_name string) []string { + + v2.SetConfigName(file_name) + + tmp := os.Getenv("TMPDIR") + // $TMPDIR/a > $TMPDIR/b > %TMPDIR + paths := []string{path.Join(tmp,"a"),path.Join(tmp,"b"),tmp} + + v2.SetConfigName(file_name) + + var expected []string + + for idx,fp := range paths { + v2.AddConfigPath(fp) + + exec.Command("mkdir","-m","777",fp).Run() + + full_path := path.Join(fp,file_name + ".json") + + var val string + switch idx{ + case 0 : + val = "high" + break + case 1 : + val = "med" + break + case 2 : + val = "low" + } + + config := "{" + for i := 0; i <= idx; i++ { + config += fmt.Sprintf("\"%d\": \"%s\"",i,val) + if( i == idx) { + config += "\n" + }else{ + config += ",\n" + } + } + + config += "}" + + ioutil.WriteFile(full_path,[]byte(config),0777) + + expected = append(expected,full_path) + } + + return expected +}