From 29f1858f8782bd55fa432c459f7e64b7032b3c33 Mon Sep 17 00:00:00 2001 From: spf13 Date: Fri, 5 Dec 2014 03:55:51 +0100 Subject: [PATCH] Viper now supports multiple vipers. No API changes. --- util.go | 112 ++++++++++ viper.go | 549 +++++++++++++++++++++++--------------------------- viper_test.go | 16 +- 3 files changed, 373 insertions(+), 304 deletions(-) create mode 100644 util.go diff --git a/util.go b/util.go new file mode 100644 index 0000000..1875cc5 --- /dev/null +++ b/util.go @@ -0,0 +1,112 @@ +// Copyright © 2014 Steve Francia . +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Viper is a application configuration system. +// It believes that applications can be configured a variety of ways +// via flags, ENVIRONMENT variables, configuration files retrieved +// from the file system, or a remote key/value store. + +package viper + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + + jww "github.com/spf13/jwalterweatherman" +) + +func insensativiseMap(m map[string]interface{}) { + for key, val := range m { + lower := strings.ToLower(key) + if key != lower { + delete(m, key) + m[lower] = val + } + } +} + +func absPathify(inPath string) string { + jww.INFO.Println("Trying to resolve absolute path to", inPath) + + if strings.HasPrefix(inPath, "$HOME") { + inPath = userHomeDir() + inPath[5:] + } + + if strings.HasPrefix(inPath, "$") { + end := strings.Index(inPath, string(os.PathSeparator)) + inPath = os.Getenv(inPath[1:end]) + inPath[end:] + } + + if filepath.IsAbs(inPath) { + return filepath.Clean(inPath) + } + + p, err := filepath.Abs(inPath) + if err == nil { + return filepath.Clean(p) + } else { + jww.ERROR.Println("Couldn't discover absolute path") + jww.ERROR.Println(err) + } + return "" +} + +// Check if File / Directory Exists +func exists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} + +func stringInSlice(a string, list []string) bool { + for _, b := range list { + if b == a { + return true + } + } + return false +} + +func userHomeDir() string { + if runtime.GOOS == "windows" { + home := os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH") + if home == "" { + home = os.Getenv("USERPROFILE") + } + return home + } + return os.Getenv("HOME") +} + +func findCWD() (string, error) { + serverFile, err := filepath.Abs(os.Args[0]) + + if err != nil { + return "", fmt.Errorf("Can't get absolute path for executable: %v", err) + } + + path := filepath.Dir(serverFile) + realFile, err := filepath.EvalSymlinks(serverFile) + + if err != nil { + if _, err = os.Stat(serverFile + ".exe"); err == nil { + realFile = filepath.Clean(serverFile + ".exe") + } + } + + if err == nil && realFile != serverFile { + path = filepath.Dir(realFile) + } + + return path, nil +} diff --git a/viper.go b/viper.go index 3a55f23..800230f 100644 --- a/viper.go +++ b/viper.go @@ -27,7 +27,6 @@ import ( "os" "path/filepath" "reflect" - "runtime" "strings" "time" @@ -41,6 +40,65 @@ import ( "gopkg.in/yaml.v1" ) +var v *viper + +func init() { + v = New() +} + +type UnsupportedConfigError string + +func (str UnsupportedConfigError) Error() string { + return fmt.Sprintf("Unsupported Config Type %q", string(str)) +} + +type UnsupportedRemoteProviderError string + +func (str UnsupportedRemoteProviderError) Error() string { + return fmt.Sprintf("Unsupported Remote Provider Type %q", string(str)) +} + +type RemoteConfigError string + +func (rce RemoteConfigError) Error() string { + return fmt.Sprintf("Remote Configurations Error: %s", string(rce)) +} + +type viper struct { + // A set of paths to look for the config file in + configPaths []string + + // A set of remote providers to search for the configuration + remoteProviders []*remoteProvider + + // Name of file to look for inside the path + configName string + configFile string + configType string + + config map[string]interface{} + override map[string]interface{} + defaults map[string]interface{} + kvstore map[string]interface{} + pflags map[string]*pflag.Flag + env map[string]string + aliases map[string]string +} + +func New() *viper { + v := new(viper) + v.configName = "config" + v.config = make(map[string]interface{}) + v.override = make(map[string]interface{}) + v.defaults = make(map[string]interface{}) + v.kvstore = make(map[string]interface{}) + v.pflags = make(map[string]*pflag.Flag) + v.env = make(map[string]string) + v.aliases = make(map[string]string) + + return v +} + // remoteProvider stores the configuration necessary // to connect to a remote key/value store. // Optional secretKeyring to unencrypt encrypted values @@ -52,50 +110,34 @@ type remoteProvider struct { secretKeyring string } -// A set of paths to look for the config file in -var configPaths []string - -// A set of remote providers to search for the configuration -var remoteProviders []*remoteProvider - -// Name of file to look for inside the path -var configName string = "config" - -// extensions Supported +// universally supported extensions var SupportedExts []string = []string{"json", "toml", "yaml", "yml"} -var SupportedRemoteProviders []string = []string{"etcd", "consul"} -var configFile string -var configType string -var config map[string]interface{} = make(map[string]interface{}) -var override map[string]interface{} = make(map[string]interface{}) -var env map[string]string = make(map[string]string) -var defaults map[string]interface{} = make(map[string]interface{}) -var kvstore map[string]interface{} = make(map[string]interface{}) -var pflags map[string]*pflag.Flag = make(map[string]*pflag.Flag) -var aliases map[string]string = make(map[string]string) +// universally supported remote providers +var SupportedRemoteProviders []string = []string{"etcd", "consul"} // Explicitly define the path, name and extension of the config file // Viper will use this and not check any of the config paths -func SetConfigFile(in string) { +func SetConfigFile(in string) { v.SetConfigFile(in) } +func (v *viper) SetConfigFile(in string) { if in != "" { - configFile = in + v.configFile = in } } -func ConfigFileUsed() string { - return configFile -} +// Return the config file used +func ConfigFileUsed() string { return v.ConfigFileUsed() } +func (v *viper) ConfigFileUsed() string { return v.configFile } // Add a path for viper to search for the config file in. // Can be called multiple times to define multiple search paths. - -func AddConfigPath(in string) { +func AddConfigPath(in string) { v.AddConfigPath(in) } +func (v *viper) AddConfigPath(in string) { if in != "" { absin := absPathify(in) jww.INFO.Println("adding", absin, "to paths to search") - if !stringInSlice(absin, configPaths) { - configPaths = append(configPaths, absin) + if !stringInSlice(absin, v.configPaths) { + v.configPaths = append(v.configPaths, absin) } } } @@ -109,6 +151,9 @@ func AddConfigPath(in string) { // you should set path to /configs and set config name (SetConfigName()) to // "myapp" func AddRemoteProvider(provider, endpoint, path string) error { + return v.AddRemoteProvider(provider, endpoint, path) +} +func (v *viper) AddRemoteProvider(provider, endpoint, path string) error { if !stringInSlice(provider, SupportedRemoteProviders) { return UnsupportedRemoteProviderError(provider) } @@ -119,8 +164,8 @@ func AddRemoteProvider(provider, endpoint, path string) error { provider: provider, path: path, } - if !providerPathExists(rp) { - remoteProviders = append(remoteProviders, rp) + if !v.providerPathExists(rp) { + v.remoteProviders = append(v.remoteProviders, rp) } } return nil @@ -137,6 +182,10 @@ func AddRemoteProvider(provider, endpoint, path string) error { // "myapp" // Secure Remote Providers are implemented with github.com/xordataexchange/crypt func AddSecureRemoteProvider(provider, endpoint, path, secretkeyring string) error { + return v.AddSecureRemoteProvider(provider, endpoint, path, secretkeyring) +} + +func (v *viper) AddSecureRemoteProvider(provider, endpoint, path, secretkeyring string) error { if !stringInSlice(provider, SupportedRemoteProviders) { return UnsupportedRemoteProviderError(provider) } @@ -147,16 +196,15 @@ func AddSecureRemoteProvider(provider, endpoint, path, secretkeyring string) err provider: provider, path: path, } - if !providerPathExists(rp) { - remoteProviders = append(remoteProviders, rp) + if !v.providerPathExists(rp) { + v.remoteProviders = append(v.remoteProviders, rp) } } return nil } -func providerPathExists(p *remoteProvider) bool { - - for _, y := range remoteProviders { +func (v *viper) providerPathExists(p *remoteProvider) bool { + for _, y := range v.remoteProviders { if reflect.DeepEqual(y, p) { return true } @@ -164,69 +212,73 @@ func providerPathExists(p *remoteProvider) bool { return false } -type UnsupportedRemoteProviderError string - -func (str UnsupportedRemoteProviderError) Error() string { - return fmt.Sprintf("Unsupported Remote Provider Type %q", string(str)) -} - -func GetString(key string) string { +func GetString(key string) string { return v.GetString(key) } +func (v *viper) GetString(key string) string { return cast.ToString(Get(key)) } -func GetBool(key string) bool { +func GetBool(key string) bool { return v.GetBool(key) } +func (v *viper) GetBool(key string) bool { return cast.ToBool(Get(key)) } -func GetInt(key string) int { +func GetInt(key string) int { return v.GetInt(key) } +func (v *viper) GetInt(key string) int { return cast.ToInt(Get(key)) } -func GetFloat64(key string) float64 { +func GetFloat64(key string) float64 { return v.GetFloat64(key) } +func (v *viper) GetFloat64(key string) float64 { return cast.ToFloat64(Get(key)) } -func GetTime(key string) time.Time { +func GetTime(key string) time.Time { return v.GetTime(key) } +func (v *viper) GetTime(key string) time.Time { return cast.ToTime(Get(key)) } -func GetStringSlice(key string) []string { +func GetStringSlice(key string) []string { return v.GetStringSlice(key) } +func (v *viper) GetStringSlice(key string) []string { return cast.ToStringSlice(Get(key)) } -func GetStringMap(key string) map[string]interface{} { +func GetStringMap(key string) map[string]interface{} { return v.GetStringMap(key) } +func (v *viper) GetStringMap(key string) map[string]interface{} { return cast.ToStringMap(Get(key)) } -func GetStringMapString(key string) map[string]string { +func GetStringMapString(key string) map[string]string { return v.GetStringMapString(key) } +func (v *viper) GetStringMapString(key string) map[string]string { return cast.ToStringMapString(Get(key)) } // Takes a single key and marshals it into a Struct -func MarshalKey(key string, rawVal interface{}) error { +func MarshalKey(key string, rawVal interface{}) error { return v.MarshalKey(key, rawVal) } +func (v *viper) MarshalKey(key string, rawVal interface{}) error { return mapstructure.Decode(Get(key), rawVal) } // Marshals the config into a Struct -func Marshal(rawVal interface{}) error { - err := mapstructure.Decode(defaults, rawVal) +func Marshal(rawVal interface{}) error { return v.Marshal(rawVal) } +func (v *viper) Marshal(rawVal interface{}) error { + err := mapstructure.Decode(v.defaults, rawVal) if err != nil { return err } - err = mapstructure.Decode(config, rawVal) + err = mapstructure.Decode(v.config, rawVal) if err != nil { return err } - err = mapstructure.Decode(override, rawVal) + err = mapstructure.Decode(v.override, rawVal) if err != nil { return err } - err = mapstructure.Decode(kvstore, rawVal) + err = mapstructure.Decode(v.kvstore, rawVal) if err != nil { return err } - insensativiseMaps() + v.insensativiseMaps() return nil } @@ -236,11 +288,12 @@ func Marshal(rawVal interface{}) error { // serverCmd.Flags().Int("port", 1138, "Port to run Application server on") // viper.BindPFlag("port", serverCmd.Flags().Lookup("port")) // -func BindPFlag(key string, flag *pflag.Flag) (err error) { +func BindPFlag(key string, flag *pflag.Flag) (err error) { return v.BindPFlag(key, flag) } +func (v *viper) BindPFlag(key string, flag *pflag.Flag) (err error) { if flag == nil { return fmt.Errorf("flag for %q is nil", key) } - pflags[strings.ToLower(key)] = flag + v.pflags[strings.ToLower(key)] = flag switch flag.Value.Type() { case "int", "int8", "int16", "int32", "int64": @@ -256,7 +309,8 @@ func BindPFlag(key string, flag *pflag.Flag) (err error) { // Binds a viper key to a ENV variable // ENV variables are case sensitive // If only a key is provided, it will use the env key matching the key, uppercased. -func BindEnv(input ...string) (err error) { +func BindEnv(input ...string) (err error) { return v.BindEnv(input...) } +func (v *viper) BindEnv(input ...string) (err error) { var key, envkey string if len(input) == 0 { return fmt.Errorf("BindEnv missing key to bind to") @@ -270,20 +324,20 @@ func BindEnv(input ...string) (err error) { envkey = input[1] } - env[key] = envkey + v.env[key] = envkey return nil } -func find(key string) interface{} { +func (v *viper) find(key string) interface{} { var val interface{} var exists bool // if the requested key is an alias, then return the proper key - key = realKey(key) + key = v.realKey(key) // PFlag Override first - flag, exists := pflags[key] + flag, exists := v.pflags[key] if exists { if flag.Changed { jww.TRACE.Println(key, "found in override (via pflag):", val) @@ -291,13 +345,13 @@ func find(key string) interface{} { } } - val, exists = override[key] + val, exists = v.override[key] if exists { jww.TRACE.Println(key, "found in override:", val) return val } - envkey, exists := env[key] + envkey, exists := v.env[key] if exists { jww.TRACE.Println(key, "registered as env var", envkey) if val = os.Getenv(envkey); val != "" { @@ -308,19 +362,19 @@ func find(key string) interface{} { } } - val, exists = config[key] + val, exists = v.config[key] if exists { jww.TRACE.Println(key, "found in config:", val) return val } - val, exists = kvstore[key] + val, exists = v.kvstore[key] if exists { jww.TRACE.Println(key, "found in key/value store:", val) return val } - val, exists = defaults[key] + val, exists = v.defaults[key] if exists { jww.TRACE.Println(key, "found in defaults:", val) return val @@ -331,151 +385,157 @@ func find(key string) interface{} { // Get returns an interface.. // Must be typecast or used by something that will typecast -func Get(key string) interface{} { +func Get(key string) interface{} { return v.Get(key) } +func (v *viper) Get(key string) interface{} { key = strings.ToLower(key) - v := find(key) + val := v.find(key) - if v == nil { + if val == nil { return nil } - switch v.(type) { + switch val.(type) { case bool: - return cast.ToBool(v) + return cast.ToBool(val) case string: - return cast.ToString(v) + return cast.ToString(val) case int64, int32, int16, int8, int: - return cast.ToInt(v) + return cast.ToInt(val) case float64, float32: - return cast.ToFloat64(v) + return cast.ToFloat64(val) case time.Time: - return cast.ToTime(v) + return cast.ToTime(val) case []string: - return v + return val } - return v + return val } -func IsSet(key string) bool { - t := Get(key) +func IsSet(key string) bool { return v.IsSet(key) } +func (v *viper) IsSet(key string) bool { + t := v.Get(key) return t != nil } // Have viper check ENV variables for all // keys set in config, default & flags -func AutomaticEnv() { - for _, x := range AllKeys() { - BindEnv(x) +func AutomaticEnv() { v.AutomaticEnv() } +func (v *viper) AutomaticEnv() { + for _, x := range v.AllKeys() { + v.BindEnv(x) } } // Aliases provide another accessor for the same key. // This enables one to change a name without breaking the application -func RegisterAlias(alias string, key string) { - registerAlias(alias, strings.ToLower(key)) +func RegisterAlias(alias string, key string) { v.RegisterAlias(alias, key) } +func (v *viper) RegisterAlias(alias string, key string) { + v.registerAlias(alias, strings.ToLower(key)) } -func registerAlias(alias string, key string) { +func (v *viper) registerAlias(alias string, key string) { alias = strings.ToLower(alias) - if alias != key && alias != realKey(key) { - _, exists := aliases[alias] + if alias != key && alias != v.realKey(key) { + _, exists := v.aliases[alias] if !exists { // if we alias something that exists in one of the maps to another // name, we'll never be able to get that value using the original // name, so move the config value to the new realkey. - if val, ok := config[alias]; ok { - delete(config, alias) - config[key] = val + if val, ok := v.config[alias]; ok { + delete(v.config, alias) + v.config[key] = val } - if val, ok := kvstore[alias]; ok { - delete(kvstore, alias) - kvstore[key] = val + if val, ok := v.kvstore[alias]; ok { + delete(v.kvstore, alias) + v.kvstore[key] = val } - if val, ok := defaults[alias]; ok { - delete(defaults, alias) - defaults[key] = val + if val, ok := v.defaults[alias]; ok { + delete(v.defaults, alias) + v.defaults[key] = val } - if val, ok := override[alias]; ok { - delete(override, alias) - override[key] = val + if val, ok := v.override[alias]; ok { + delete(v.override, alias) + v.override[key] = val } - aliases[alias] = key + v.aliases[alias] = key } } else { - jww.WARN.Println("Creating circular reference alias", alias, key, realKey(key)) + jww.WARN.Println("Creating circular reference alias", alias, key, v.realKey(key)) } } -func realKey(key string) string { - newkey, exists := aliases[key] +func (v *viper) realKey(key string) string { + newkey, exists := v.aliases[key] if exists { jww.DEBUG.Println("Alias", key, "to", newkey) - return realKey(newkey) + return v.realKey(newkey) } else { return key } } -func InConfig(key string) bool { +func InConfig(key string) bool { return v.InConfig(key) } +func (v *viper) InConfig(key string) bool { // if the requested key is an alias, then return the proper key - key = realKey(key) + key = v.realKey(key) - _, exists := config[key] + _, exists := v.config[key] return exists } // Set the default value for this key. // Default only used when no value is provided by the user via flag, config or ENV. -func SetDefault(key string, value interface{}) { +func SetDefault(key string, value interface{}) { v.SetDefault(key, value) } +func (v *viper) SetDefault(key string, value interface{}) { // If alias passed in, then set the proper default - key = realKey(strings.ToLower(key)) - defaults[key] = value + key = v.realKey(strings.ToLower(key)) + v.defaults[key] = value } // The user provided value (via flag) // Will be used instead of values obtained via // config file, ENV, default, or key/value store -func Set(key string, value interface{}) { +func Set(key string, value interface{}) { v.Set(key, value) } +func (v *viper) Set(key string, value interface{}) { // If alias passed in, then set the proper override - key = realKey(strings.ToLower(key)) - override[key] = value -} - -type UnsupportedConfigError string - -func (str UnsupportedConfigError) Error() string { - return fmt.Sprintf("Unsupported Config Type %q", string(str)) + key = v.realKey(strings.ToLower(key)) + v.override[key] = value } // Viper will discover and load the configuration file from disk // and key/value stores, searching in one of the defined paths. -func ReadInConfig() error { +func ReadInConfig() error { return v.ReadInConfig() } +func (v *viper) ReadInConfig() error { jww.INFO.Println("Attempting to read in config file") - if !stringInSlice(getConfigType(), SupportedExts) { - return UnsupportedConfigError(getConfigType()) + if !stringInSlice(v.getConfigType(), SupportedExts) { + return UnsupportedConfigError(v.getConfigType()) } - file, err := ioutil.ReadFile(getConfigFile()) + file, err := ioutil.ReadFile(v.getConfigFile()) if err != nil { return err } - MarshallReader(bytes.NewReader(file), config) + v.MarshallReader(bytes.NewReader(file), v.config) return nil } -func ReadRemoteConfig() error { - err := getKeyValueConfig() + +func ReadRemoteConfig() error { return v.ReadRemoteConfig() } +func (v *viper) ReadRemoteConfig() error { + err := v.getKeyValueConfig() if err != nil { return err } return nil } -func MarshallReader(in io.Reader, c map[string]interface{}) { + +func MarshallReader(in io.Reader, c map[string]interface{}) { v.MarshallReader(in, c) } +func (v *viper) MarshallReader(in io.Reader, c map[string]interface{}) { buf := new(bytes.Buffer) buf.ReadFrom(in) - switch getConfigType() { + switch v.getConfigType() { case "yaml", "yml": if err := yaml.Unmarshal(buf.Bytes(), &c); err != nil { jww.ERROR.Fatalf("Error parsing config: %s", err) @@ -495,33 +555,27 @@ func MarshallReader(in io.Reader, c map[string]interface{}) { insensativiseMap(c) } -func insensativiseMaps() { - insensativiseMap(config) - insensativiseMap(defaults) - insensativiseMap(override) - insensativiseMap(kvstore) +func (v *viper) insensativiseMaps() { + insensativiseMap(v.config) + insensativiseMap(v.defaults) + insensativiseMap(v.override) + insensativiseMap(v.kvstore) } // retrieve the first found remote configuration -func getKeyValueConfig() error { - for _, rp := range remoteProviders { - val, err := getRemoteConfig(rp) +func (v *viper) getKeyValueConfig() error { + for _, rp := range v.remoteProviders { + val, err := v.getRemoteConfig(rp) if err != nil { continue } - kvstore = val + v.kvstore = val return nil } return RemoteConfigError("No Files Found") } -type RemoteConfigError string - -func (rce RemoteConfigError) Error() string { - return fmt.Sprintf("Remote Configurations Error: %s", string(rce)) -} - -func getRemoteConfig(provider *remoteProvider) (map[string]interface{}, error) { +func (v *viper) getRemoteConfig(provider *remoteProvider) (map[string]interface{}, error) { var cm crypt.ConfigManager var err error @@ -551,36 +605,27 @@ func getRemoteConfig(provider *remoteProvider) (map[string]interface{}, error) { return nil, err } reader := bytes.NewReader(b) - MarshallReader(reader, kvstore) - return kvstore, err + v.MarshallReader(reader, v.kvstore) + return v.kvstore, err } -func insensativiseMap(m map[string]interface{}) { - for key, val := range m { - lower := strings.ToLower(key) - if key != lower { - delete(m, key) - m[lower] = val - } - } -} - -func AllKeys() []string { +func AllKeys() []string { return v.AllKeys() } +func (v *viper) AllKeys() []string { m := map[string]struct{}{} - for key, _ := range defaults { + for key, _ := range v.defaults { m[key] = struct{}{} } - for key, _ := range config { + for key, _ := range v.config { m[key] = struct{}{} } - for key, _ := range kvstore { + for key, _ := range v.kvstore { m[key] = struct{}{} } - for key, _ := range override { + for key, _ := range v.override { m[key] = struct{}{} } @@ -592,10 +637,11 @@ func AllKeys() []string { return a } -func AllSettings() map[string]interface{} { +func AllSettings() map[string]interface{} { return v.AllSettings() } +func (v *viper) AllSettings() map[string]interface{} { m := map[string]interface{}{} - for _, x := range AllKeys() { - m[x] = Get(x) + for _, x := range v.AllKeys() { + m[x] = v.Get(x) } return m @@ -603,24 +649,26 @@ func AllSettings() map[string]interface{} { // Name for the config file. // Does not include extension. -func SetConfigName(in string) { +func SetConfigName(in string) { v.SetConfigName(in) } +func (v *viper) SetConfigName(in string) { if in != "" { - configName = in + v.configName = in } } -func SetConfigType(in string) { +func SetConfigType(in string) { v.SetConfigType(in) } +func (v *viper) SetConfigType(in string) { if in != "" { - configType = in + v.configType = in } } -func getConfigType() string { - if configType != "" { - return configType +func (v *viper) getConfigType() string { + if v.configType != "" { + return v.configType } - cf := getConfigFile() + cf := v.getConfigFile() ext := filepath.Ext(cf) if len(ext) > 1 { @@ -630,169 +678,78 @@ func getConfigType() string { } } -func getConfigFile() string { +func (v *viper) getConfigFile() string { // if explicitly set, then use it - if configFile != "" { - return configFile + if v.configFile != "" { + return v.configFile } - cf, err := findConfigFile() + cf, err := v.findConfigFile() if err != nil { return "" } - configFile = cf - return getConfigFile() + v.configFile = cf + return v.getConfigFile() } -func searchInPath(in string) (filename string) { +func (v *viper) searchInPath(in string) (filename string) { jww.DEBUG.Println("Searching for config in ", in) for _, ext := range SupportedExts { - - jww.DEBUG.Println("Checking for", filepath.Join(in, configName+"."+ext)) - if b, _ := exists(filepath.Join(in, configName+"."+ext)); b { - jww.DEBUG.Println("Found: ", filepath.Join(in, configName+"."+ext)) - return filepath.Join(in, configName+"."+ext) + jww.DEBUG.Println("Checking for", filepath.Join(in, v.configName+"."+ext)) + if b, _ := exists(filepath.Join(in, v.configName+"."+ext)); b { + jww.DEBUG.Println("Found: ", filepath.Join(in, v.configName+"."+ext)) + return filepath.Join(in, v.configName+"."+ext) } } return "" } -func findConfigFile() (string, error) { - jww.INFO.Println("Searching for config in ", configPaths) +func (v *viper) findConfigFile() (string, error) { + jww.INFO.Println("Searching for config in ", v.configPaths) - for _, cp := range configPaths { - file := searchInPath(cp) + for _, cp := range v.configPaths { + file := v.searchInPath(cp) if file != "" { return file, nil } } cwd, _ := findCWD() - file := searchInPath(cwd) + file := v.searchInPath(cwd) if file != "" { return file, nil } // try the current working directory wd, _ := os.Getwd() - file = searchInPath(wd) + file = v.searchInPath(wd) if file != "" { return file, nil } - return "", fmt.Errorf("config file not found in: %s", configPaths) + return "", fmt.Errorf("config file not found in: %s", v.configPaths) } -func findCWD() (string, error) { - serverFile, err := filepath.Abs(os.Args[0]) - - if err != nil { - return "", fmt.Errorf("Can't get absolute path for executable: %v", err) - } - - path := filepath.Dir(serverFile) - realFile, err := filepath.EvalSymlinks(serverFile) - - if err != nil { - if _, err = os.Stat(serverFile + ".exe"); err == nil { - realFile = filepath.Clean(serverFile + ".exe") - } - } - - if err == nil && realFile != serverFile { - path = filepath.Dir(realFile) - } - - return path, nil -} - -// Check if File / Directory Exists -func exists(path string) (bool, error) { - _, err := os.Stat(path) - if err == nil { - return true, nil - } - if os.IsNotExist(err) { - return false, nil - } - return false, err -} - -func stringInSlice(a string, list []string) bool { - for _, b := range list { - if b == a { - return true - } - } - return false -} - -func userHomeDir() string { - if runtime.GOOS == "windows" { - home := os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH") - if home == "" { - home = os.Getenv("USERPROFILE") - } - return home - } - return os.Getenv("HOME") -} - -func absPathify(inPath string) string { - jww.INFO.Println("Trying to resolve absolute path to", inPath) - - if strings.HasPrefix(inPath, "$HOME") { - inPath = userHomeDir() + inPath[5:] - } - - if strings.HasPrefix(inPath, "$") { - end := strings.Index(inPath, string(os.PathSeparator)) - inPath = os.Getenv(inPath[1:end]) + inPath[end:] - } - - if filepath.IsAbs(inPath) { - return filepath.Clean(inPath) - } - - p, err := filepath.Abs(inPath) - if err == nil { - return filepath.Clean(p) - } else { - jww.ERROR.Println("Couldn't discover absolute path") - jww.ERROR.Println(err) - } - return "" -} - -func Debug() { +func Debug() { v.Debug() } +func (v *viper) Debug() { fmt.Println("Config:") - pretty.Println(config) + pretty.Println(v.config) fmt.Println("Key/Value Store:") - pretty.Println(kvstore) + pretty.Println(v.kvstore) fmt.Println("Env:") - pretty.Println(env) + pretty.Println(v.env) fmt.Println("Defaults:") - pretty.Println(defaults) + pretty.Println(v.defaults) fmt.Println("Override:") - pretty.Println(override) + pretty.Println(v.override) fmt.Println("Aliases:") - pretty.Println(aliases) + pretty.Println(v.aliases) } +// Intended for testing, will reset all to default settings. func Reset() { - configPaths = nil - configName = "config" - - // extensions Supported + v = New() SupportedExts = []string{"json", "toml", "yaml", "yml"} - configFile = "" - configType = "" - - kvstore = make(map[string]interface{}) - config = make(map[string]interface{}) - override = make(map[string]interface{}) - env = make(map[string]string) - defaults = make(map[string]interface{}) - aliases = make(map[string]string) + SupportedRemoteProviders = []string{"etcd", "consul"} } diff --git a/viper_test.go b/viper_test.go index bc4059f..3795ed9 100644 --- a/viper_test.go +++ b/viper_test.go @@ -83,7 +83,7 @@ func (s *stringValue) String() string { func TestBasics(t *testing.T) { SetConfigFile("/tmp/config.yaml") - assert.Equal(t, "/tmp/config.yaml", getConfigFile()) + assert.Equal(t, "/tmp/config.yaml", v.getConfigFile()) } func TestDefault(t *testing.T) { @@ -95,7 +95,7 @@ func TestMarshalling(t *testing.T) { SetConfigType("yaml") r := bytes.NewReader(yamlExample) - MarshallReader(r, config) + MarshallReader(r, v.config) assert.True(t, InConfig("name")) assert.False(t, InConfig("state")) assert.Equal(t, "steve", Get("name")) @@ -136,7 +136,7 @@ func TestYML(t *testing.T) { SetConfigType("yml") r := bytes.NewReader(yamlExample) - MarshallReader(r, config) + MarshallReader(r, v.config) assert.Equal(t, "steve", Get("name")) } @@ -144,7 +144,7 @@ func TestJSON(t *testing.T) { SetConfigType("json") r := bytes.NewReader(jsonExample) - MarshallReader(r, config) + MarshallReader(r, v.config) assert.Equal(t, "0001", Get("id")) } @@ -152,17 +152,17 @@ func TestTOML(t *testing.T) { SetConfigType("toml") r := bytes.NewReader(tomlExample) - MarshallReader(r, config) + MarshallReader(r, v.config) assert.Equal(t, "TOML Example", Get("title")) } func TestRemotePrecedence(t *testing.T) { SetConfigType("json") r := bytes.NewReader(jsonExample) - MarshallReader(r, config) + MarshallReader(r, v.config) remote := bytes.NewReader(remoteExample) assert.Equal(t, "0001", Get("id")) - MarshallReader(remote, kvstore) + MarshallReader(remote, v.kvstore) assert.Equal(t, "0001", Get("id")) assert.NotEqual(t, "cronut", Get("type")) assert.Equal(t, "remote", Get("newkey")) @@ -175,7 +175,7 @@ func TestRemotePrecedence(t *testing.T) { func TestEnv(t *testing.T) { SetConfigType("json") r := bytes.NewReader(jsonExample) - MarshallReader(r, config) + MarshallReader(r, v.config) BindEnv("id") BindEnv("f", "FOOD")