diff --git a/viper.go b/viper.go new file mode 100644 index 0000000..bdfa5aa --- /dev/null +++ b/viper.go @@ -0,0 +1,348 @@ +// Copyright © 2014 Steve Francia . +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package viper + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "os" + "path" + "path/filepath" + "time" + + "github.com/BurntSushi/toml" + "github.com/kr/pretty" + "github.com/spf13/cast" + jww "github.com/spf13/jwalterweatherman" + "gopkg.in/yaml.v1" +) + +// A set of paths to look for the config file in +var configPaths []string + +// Name of file to look for inside the path +var configName string = "config" + +// extensions Supported +var SupportedExts []string = []string{"json", "toml", "yaml"} +var configFile string +var configType string + +var config map[string]interface{} = make(map[string]interface{}) +var override map[string]interface{} = make(map[string]interface{}) +var defaults map[string]interface{} = make(map[string]interface{}) +var aliases map[string]string = make(map[string]string) + +func SetConfigFile(in string) { + if in != "" { + configFile = in + } +} + +func AddConfigPath(in string) { + if in != "" { + absin := absPathify(in) + jww.INFO.Println("adding", absin, "to paths to search") + if !stringInSlice(absin, configPaths) { + configPaths = append(configPaths, absin) + } + } +} + +func GetString(key string) string { + return cast.ToString(Get(key)) +} + +func GetBool(key string) bool { + return cast.ToBool(Get(key)) +} + +func GetInt(key string) int { + return cast.ToInt(Get(key)) +} + +func GetFloat64(key string) float64 { + return cast.ToFloat64(Get(key)) +} + +func GetTime(key string) time.Time { + return cast.ToTime(Get(key)) +} + +func GetStringArray(key string) []string { + return cast.ToStringArray(Get(key)) +} + +func find(key string) interface{} { + var val interface{} + var exists bool + + // If the requested key is an alias, then return the proper key + newkey, exists := aliases[key] + if exists { + return find(newkey) + } + + val, exists = override[key] + if exists { + jww.TRACE.Println(key, "found in override:", val) + return val + } + + val, exists = config[key] + if exists { + jww.TRACE.Println(key, "found in config:", val) + return val + } + + val, exists = defaults[key] + if exists { + jww.TRACE.Println(key, "found in defaults:", val) + return val + } + + return nil +} + +func Get(key string) interface{} { + v := find(key) + + if v == nil { + return nil + } + + switch v.(type) { + case bool: + return cast.ToBool(v) + case string: + return cast.ToString(v) + case int64, int32, int16, int8, int: + return cast.ToInt(v) + case float64, float32: + return cast.ToFloat64(v) + case time.Time: + return cast.ToTime(v) + case []string: + return v + } + return v +} + +func IsSet(key string) bool { + t := Get(key) + return t != nil +} + +func RegisterAlias(alias string, key string) { + aliases[alias] = key +} + +func InConfig(key string) bool { + _, exists := config[key] + return exists +} + +func SetDefault(key string, value interface{}) { + // If alias passed in, then set the proper default + newkey, exists := aliases[key] + if exists { + defaults[newkey] = value + } else { + defaults[key] = value + } +} + +func Set(key string, value interface{}) { + // If alias passed in, then set the proper override + newkey, exists := aliases[key] + if exists { + override[newkey] = value + } else { + override[key] = value + } +} + +func ReadInConfig() { + jww.INFO.Println("Attempting to read in config file") + if !stringInSlice(getConfigType(), SupportedExts) { + jww.ERROR.Fatalf("Unsupported Config Type %q", getConfigType()) + } + + file, err := ioutil.ReadFile(getConfigFile()) + if err == nil { + MarshallReader(bytes.NewReader(file)) + } +} + +func MarshallReader(in io.Reader) { + buf := new(bytes.Buffer) + buf.ReadFrom(in) + + switch getConfigType() { + case "yaml": + if err := yaml.Unmarshal(buf.Bytes(), &config); err != nil { + jww.ERROR.Fatalf("Error parsing config: %s", err) + } + + case "json": + if err := json.Unmarshal(buf.Bytes(), &config); err != nil { + jww.ERROR.Fatalf("Error parsing config: %s", err) + } + + case "toml": + if _, err := toml.Decode(buf.String(), &config); err != nil { + jww.ERROR.Fatalf("Error parsing config: %s", err) + } + } +} + +func SetConfigName(in string) { + if in != "" { + configName = in + } +} + +func SetConfigType(in string) { + if in != "" { + configType = in + } +} + +func getConfigType() string { + if configType != "" { + return configType + } + + cf := getConfigFile() + ext := path.Ext(cf) + + if len(ext) > 1 { + return ext[1:] + } else { + return "" + } +} + +func getConfigFile() string { + // if explicitly set, then use it + if configFile != "" { + return configFile + } + + cf, err := findConfigFile() + if err != nil { + jww.ERROR.Println(err) + } else { + configFile = cf + return getConfigFile() + } + return "" +} + +func searchInPath(in string) (filename string) { + jww.DEBUG.Println("Searching for config in ", in) + for _, ext := range SupportedExts { + + jww.DEBUG.Println("Checking for", path.Join(in, configName+"."+ext)) + if b, _ := exists(path.Join(in, configName+"."+ext)); b { + jww.DEBUG.Println("Found: ", path.Join(in, configName+"."+ext)) + return path.Join(in, configName+"."+ext) + } + } + + return "" +} + +func findConfigFile() (string, error) { + jww.INFO.Println("Searching for config in ", configPaths) + + for _, cp := range configPaths { + file := searchInPath(cp) + if file != "" { + return file, nil + } + } + cwd, _ := findCWD() + + file := searchInPath(cwd) + if file != "" { + return file, nil + } + + return "", fmt.Errorf("config file not found in: %s", configPaths) +} + +func findCWD() (string, error) { + serverFile, err := filepath.Abs(os.Args[0]) + + if err != nil { + return "", fmt.Errorf("Can't get absolute path for executable: %v", err) + } + + path := filepath.Dir(serverFile) + realFile, err := filepath.EvalSymlinks(serverFile) + + if err != nil { + if _, err = os.Stat(serverFile + ".exe"); err == nil { + realFile = filepath.Clean(serverFile + ".exe") + } + } + + if err == nil && realFile != serverFile { + path = filepath.Dir(realFile) + } + + return path, nil +} + +// Check if File / Directory Exists +func exists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} + +func stringInSlice(a string, list []string) bool { + for _, b := range list { + if b == a { + return true + } + } + return false +} + +func absPathify(inPath string) string { + jww.INFO.Println("Trying to resolve absolute path to", inPath) + if filepath.IsAbs(inPath) { + return filepath.Clean(inPath) + } + + p, err := filepath.Abs(inPath) + if err == nil { + return filepath.Clean(p) + } else { + jww.ERROR.Println("Couldn't discover absolute path") + jww.ERROR.Println(err) + } + return "" +} + +func Debug() { + fmt.Println("Config:") + pretty.Println(config) + fmt.Println("Defaults:") + pretty.Println(defaults) + fmt.Println("Override:") + pretty.Println(override) +} diff --git a/viper_test.go b/viper_test.go new file mode 100644 index 0000000..1a4b624 --- /dev/null +++ b/viper_test.go @@ -0,0 +1,119 @@ +// Copyright © 2014 Steve Francia . +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package viper + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +var yamlExample = []byte(`Hacker: true +name: steve +hobbies: +- skateboarding +- snowboarding +- go +age: 35`) + +var tomlExample = []byte(` +title = "TOML Example" + +[owner] +organization = "MongoDB" +bio = "MongoDB Chief Developer Advocate & Hacker at Large" +dob = 1979-05-27T07:32:00Z # First class dates? Why not?`) + +var jsonExample = []byte(`{ +"id": "0001", +"type": "donut", +"name": "Cake", +"ppu": 0.55, +"batters": { + "batter": [ + { "type": "Regular" }, + { "type": "Chocolate" }, + { "type": "Blueberry" }, + { "type": "Devil's Food" } + ] + } +}`) + +func reset() { + configPaths = nil + configName = "config" + + // extensions Supported + SupportedExts = []string{"json", "toml", "yaml"} + configFile = "" + configType = "" + + config = make(map[string]interface{}) + override = make(map[string]interface{}) + defaults = make(map[string]interface{}) + aliases = make(map[string]string) +} + +func TestBasics(t *testing.T) { + SetConfigFile("/tmp/config.yaml") + assert.Equal(t, "/tmp/config.yaml", getConfigFile()) +} + +func TestDefault(t *testing.T) { + SetDefault("age", 45) + assert.Equal(t, 45, Get("age")) +} + +func TestMarshalling(t *testing.T) { + SetConfigType("yaml") + r := bytes.NewReader(yamlExample) + + MarshallReader(r) + assert.True(t, InConfig("name")) + assert.False(t, InConfig("state")) + assert.Equal(t, "steve", Get("name")) + assert.Equal(t, []interface{}{"skateboarding", "snowboarding", "go"}, Get("hobbies")) + assert.Equal(t, 35, Get("age")) +} + +func TestOverrides(t *testing.T) { + Set("age", 40) + assert.Equal(t, 40, Get("age")) +} + +func TestDefaultPost(t *testing.T) { + assert.NotEqual(t, "NYC", Get("state")) + SetDefault("state", "NYC") + assert.Equal(t, "NYC", Get("state")) +} + +func TestAliases(t *testing.T) { + RegisterAlias("years", "age") + assert.Equal(t, 40, Get("years")) + Set("years", 45) + assert.Equal(t, 45, Get("age")) +} + +func TestJSON(t *testing.T) { + SetConfigType("json") + r := bytes.NewReader(jsonExample) + + MarshallReader(r) + assert.Equal(t, "0001", Get("id")) +} + +func TestTOML(t *testing.T) { + SetConfigType("toml") + r := bytes.NewReader(tomlExample) + + MarshallReader(r) + assert.Equal(t, "TOML Example", Get("title")) +} + +func TestCaseSensitive(t *testing.T) { + assert.Equal(t, true, Get("Hacker")) +}