adding cascading file support, off by default

This commit is contained in:
Bill Robbins 2015-02-06 14:35:00 -06:00
parent d8f2aa78d4
commit 24b9be4805
3 changed files with 173 additions and 4 deletions

1
.gitignore vendored
View file

@ -6,6 +6,7 @@
# Folders # Folders
_obj _obj
_test _test
.idea
# Architecture specific extensions/prefixes # Architecture specific extensions/prefixes
*.[568vq] *.[568vq]

View file

@ -77,11 +77,13 @@ type viper struct {
envPrefix string envPrefix string
automaticEnvApplied bool automaticEnvApplied bool
cascadeConfigurations bool
config map[string]interface{} config map[string]interface{}
override map[string]interface{} override map[string]interface{}
defaults map[string]interface{} defaults map[string]interface{}
kvstore map[string]interface{} kvstore map[string]interface{}
cascadingConfigs map[string]map[string]interface{}
pflags map[string]*pflag.Flag pflags map[string]*pflag.Flag
env map[string]string env map[string]string
aliases map[string]string aliases map[string]string
@ -98,6 +100,7 @@ func New() *viper {
v.pflags = make(map[string]*pflag.Flag) v.pflags = make(map[string]*pflag.Flag)
v.env = make(map[string]string) v.env = make(map[string]string)
v.aliases = make(map[string]string) v.aliases = make(map[string]string)
v.cascadeConfigurations = false
return v return v
} }
@ -136,6 +139,12 @@ func (v *viper) SetEnvPrefix(in string) {
} }
} }
// Enable cascading configuration values for files. Will traverse down
// ConfigPaths in an attempt to find keys
func (v *viper) EnableCascading(enable bool){
v.cascadeConfigurations = enable;
}
func (v *viper) mergeWithEnvPrefix(in string) string { func (v *viper) mergeWithEnvPrefix(in string) string {
if v.envPrefix != "" { if v.envPrefix != "" {
return strings.ToUpper(v.envPrefix + "_" + in) return strings.ToUpper(v.envPrefix + "_" + in)
@ -389,6 +398,7 @@ func (v *viper) BindEnv(input ...string) (err error) {
func (v *viper) find(key string) interface{} { func (v *viper) find(key string) interface{} {
var val interface{} var val interface{}
var exists bool var exists bool
var file string
// if the requested key is an alias, then return the proper key // if the requested key is an alias, then return the proper key
key = v.realKey(key) key = v.realKey(key)
@ -434,6 +444,15 @@ func (v *viper) find(key string) interface{} {
return val return val
} }
if( v.cascadeConfigurations) {
//cascade down the rest of the files
val, exists, file = v.findCascading(key)
if exists {
jww.TRACE.Printf("%s found in config: %s (%s)", key, val, file)
return val
}
}
val, exists = v.kvstore[key] val, exists = v.kvstore[key]
if exists { if exists {
jww.TRACE.Println(key, "found in key/value store:", val) jww.TRACE.Println(key, "found in key/value store:", val)
@ -449,6 +468,49 @@ func (v *viper) find(key string) interface{} {
return nil return nil
} }
func (v *viper) findCascading(key string) (interface{}, bool, string) {
if( v.cascadingConfigs != nil ){
for file,config := range v.cascadingConfigs {
result := config[key]
if( result != nil ){
return result,true,file
}
}
}
v.cascadingConfigs = make(map[string]map[string]interface{})
configFiles := v.findAllConfigFiles()
for _, configFile := range configFiles {
if(v.cascadingConfigs[configFile] != nil){
//already cached
continue
}
jww.TRACE.Printf("Looking in %s for key %s",configFile,key)
file, err := ioutil.ReadFile(configFile)
if err != nil {
jww.ERROR.Print(err)
continue
}
jww.TRACE.Printf("marshalling %s for cascading",configFile)
var config = make(map[string]interface{})
marshallConfigReader(bytes.NewReader(file), config, filepath.Ext(configFile)[1:])
v.cascadingConfigs[configFile] = config
result := config[key]
if( result != nil){
return result,true,configFile
}
}
return "", false, ""
}
// Check to see if the key has been set in any of the data locations // Check to see if the key has been set in any of the data locations
func IsSet(key string) bool { return v.IsSet(key) } func IsSet(key string) bool { return v.IsSet(key) }
func (v *viper) IsSet(key string) bool { func (v *viper) IsSet(key string) bool {
@ -733,26 +795,41 @@ func (v *viper) searchInPath(in string) (filename string) {
func (v *viper) findConfigFile() (string, error) { func (v *viper) findConfigFile() (string, error) {
jww.INFO.Println("Searching for config in ", v.configPaths) jww.INFO.Println("Searching for config in ", v.configPaths)
var validFiles = v.findAllConfigFiles()
if len(validFiles) == 0 {
return "", fmt.Errorf("config file not found in: %s", v.configPaths)
}
return validFiles[0], nil
}
func (v *viper) findAllConfigFiles() []string {
var validFiles []string
for _, cp := range v.configPaths { for _, cp := range v.configPaths {
file := v.searchInPath(cp) file := v.searchInPath(cp)
if file != "" { if file != "" {
return file, nil jww.TRACE.Println("Found config file in: %s",file)
validFiles = append(validFiles, file)
} }
} }
cwd, _ := findCWD() cwd, _ := findCWD()
file := v.searchInPath(cwd) file := v.searchInPath(cwd)
if file != "" { if file != "" {
return file, nil validFiles = append(validFiles, file)
} }
// try the current working directory // try the current working directory
wd, _ := os.Getwd() wd, _ := os.Getwd()
file = v.searchInPath(wd) file = v.searchInPath(wd)
if file != "" { if file != "" {
return file, nil validFiles = append(validFiles, file)
} }
return "", fmt.Errorf("config file not found in: %s", v.configPaths)
return validFiles
} }
func Debug() { v.Debug() } func Debug() { v.Debug() }

View file

@ -12,6 +12,9 @@ import (
"sort" "sort"
"testing" "testing"
"time" "time"
"os/exec"
"path"
"io/ioutil"
"github.com/spf13/pflag" "github.com/spf13/pflag"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -352,3 +355,91 @@ func TestBoundCaseSensitivity(t *testing.T) {
assert.Equal(t, "green", Get("eyes")) assert.Equal(t, "green", Get("eyes"))
} }
func TestCanCascadeConfigurationValues(t *testing.T) {
v2 := New()
generateCascadingTests(v2,"cascading")
v2.ReadInConfig()
v2.EnableCascading(true)
assert.Equal(t,"high",v2.GetString("0"),"Key 0 should be high")
assert.Equal(t,"med",v2.GetString("1"),"Key 1 should be med")
assert.Equal(t,"low",v2.GetString("2"),"key 2 should be low")
v2.EnableCascading(false)
assert.Nil(t,v2.Get("1"),"With enable cascading disabled, no value for 1 should exist")
assert.Nil(t,v2.Get("2"),"With enable cascading disabled, no value for 2 should exist")
}
func TestFindAllConfigPaths(t *testing.T){
v2 := New()
file := "viper_test"
var expected = generateCascadingTests(v2,file)
found := v2.findAllConfigFiles()
for _,fp := range expected{
command := exec.Command("rm",fp)
command.Run()
}
assert.Equal(t,expected,found,"All files should exist")
}
func generateCascadingTests(v2 *viper, file_name string) []string {
v2.SetConfigName(file_name)
tmp := os.Getenv("TMPDIR")
// $TMPDIR/a > $TMPDIR/b > %TMPDIR
paths := []string{path.Join(tmp,"a"),path.Join(tmp,"b"),tmp}
v2.SetConfigName(file_name)
var expected []string
for idx,fp := range paths {
v2.AddConfigPath(fp)
exec.Command("mkdir","-m","777",fp).Run()
full_path := path.Join(fp,file_name + ".json")
var val string
switch idx{
case 0 :
val = "high"
break
case 1 :
val = "med"
break
case 2 :
val = "low"
}
config := "{"
for i := 0; i <= idx; i++ {
config += fmt.Sprintf("\"%d\": \"%s\"",i,val)
if( i == idx) {
config += "\n"
}else{
config += ",\n"
}
}
config += "}"
ioutil.WriteFile(full_path,[]byte(config),0777)
expected = append(expected,full_path)
}
return expected
}