diff --git a/viper.go b/viper.go index 799983f..f46787c 100644 --- a/viper.go +++ b/viper.go @@ -79,6 +79,16 @@ func (rce RemoteConfigError) Error() string { return fmt.Sprintf("Remote Configurations Error: %s", string(rce)) } +// Denotes failing to find configuration file. +type ConfigFileNotFoundError struct { + name, locations string +} + +// Returns the formatted configuration error. +func (fnfe ConfigFileNotFoundError) Error() string { + return fmt.Sprintf("Config File %q Not Found in %q", fnfe.name, fnfe.locations) +} + // Viper is a prioritized configuration registry. It // maintains a set of configuration sources, fetches // values to populate those, and provides them according @@ -937,9 +947,22 @@ func (v *Viper) searchInPath(in string) (filename string) { return "" } -// search all configPaths for any config file. -// Returns the first path that exists (and is a config file) +// Choose where to look for a config file: either +// in provided directories or in the working directory func (v *Viper) findConfigFile() (string, error) { + + if len(v.configPaths) > 0 { + return v.findConfigInPaths() + } else { + return v.findConfigInCWD() + } + +} + +// Search all configPaths for any config file. +// Returns the first path that exists (and has a config file) +func (v *Viper) findConfigInPaths() (string, error) { + jww.INFO.Println("Searching for config in ", v.configPaths) for _, cp := range v.configPaths { @@ -948,14 +971,20 @@ func (v *Viper) findConfigFile() (string, error) { return file, nil } } + return "", ConfigFileNotFoundError{v.configName, fmt.Sprintf("%s", v.configPaths)} +} + +// Search the current working directory for any config file. +func (v *Viper) findConfigInCWD() (string, error) { - // try the current working directory wd, _ := os.Getwd() + jww.INFO.Println("Searching for config in ", wd) + file := v.searchInPath(wd) if file != "" { return file, nil } - return "", fmt.Errorf("config file not found in: %s", v.configPaths) + return "", ConfigFileNotFoundError{v.configName, wd} } // Prints all configuration registries for debugging diff --git a/viper_test.go b/viper_test.go index 7ad0245..60d2f22 100644 --- a/viper_test.go +++ b/viper_test.go @@ -8,7 +8,10 @@ package viper import ( "bytes" "fmt" + "io/ioutil" "os" + "path" + "reflect" "sort" "strings" "testing" @@ -124,6 +127,47 @@ func initTOML() { marshalReader(r, v.config) } +// make directories for testing +func initDirs(t *testing.T) (string, string, func()) { + + var ( + testDirs = []string{`a a`, `b`, `c\c`, `D:`} + config = `improbable` + ) + + root, err := ioutil.TempDir("", "") + + cleanup := true + defer func() { + if cleanup { + os.Chdir("..") + os.RemoveAll(root) + } + }() + + assert.Nil(t, err) + + err = os.Chdir(root) + assert.Nil(t, err) + + err = ioutil.WriteFile(path.Join(root, config+".toml"), []byte("key = \"root\"\n"), 0640) + assert.Nil(t, err) + + for _, dir := range testDirs { + err = os.Mkdir(dir, 0750) + assert.Nil(t, err) + + err = ioutil.WriteFile(path.Join(dir, config+".toml"), []byte("key = \"value is "+dir+"\"\n"), 0640) + assert.Nil(t, err) + } + + cleanup = false + return root, config, func() { + os.Chdir("..") + os.RemoveAll(root) + } +} + //stubs for PFlag Values type stringValue string @@ -551,3 +595,100 @@ func TestReadBufConfig(t *testing.T) { assert.Equal(t, map[interface{}]interface{}{"jacket": "leather", "trousers": "denim"}, v.Get("clothing")) assert.Equal(t, 35, v.Get("age")) } + +func TestCWDSearch(t *testing.T) { + + _, config, cleanup := initDirs(t) + defer cleanup() + + v := New() + v.SetConfigName(config) + v.SetDefault(`key`, `default`) + + err := v.ReadInConfig() + assert.Nil(t, err) + + assert.Equal(t, `root`, v.GetString(`key`)) +} + +func TestCWDSearchNoConfig(t *testing.T) { + + _, config, cleanup := initDirs(t) + defer cleanup() + + // Remove the config file in CWD + os.Remove(config + ".toml") + + v := New() + v.SetConfigName(config) + v.SetDefault(`key`, `default`) + + err := v.ReadInConfig() + assert.Equal(t, reflect.TypeOf(UnsupportedConfigError("")), reflect.TypeOf(err)) + + assert.Equal(t, `default`, v.GetString(`key`)) +} + +func TestDirsSearch(t *testing.T) { + + root, config, cleanup := initDirs(t) + defer cleanup() + + v := New() + v.SetConfigName(config) + v.SetDefault(`key`, `default`) + + entries, err := ioutil.ReadDir(root) + for _, e := range entries { + if e.IsDir() { + v.AddConfigPath(e.Name()) + } + } + + err = v.ReadInConfig() + assert.Nil(t, err) + + assert.Equal(t, `value is `+path.Base(v.configPaths[0]), v.GetString(`key`)) +} + +func TestWrongDirsSearchNotFoundHasCWDConfig(t *testing.T) { + + _, config, cleanup := initDirs(t) + defer cleanup() + + v := New() + v.SetConfigName(config) + v.SetDefault(`key`, `default`) + + v.AddConfigPath(`whattayoutalkingbout`) + v.AddConfigPath(`thispathaintthere`) + + err := v.ReadInConfig() + assert.Equal(t, reflect.TypeOf(UnsupportedConfigError("")), reflect.TypeOf(err)) + + // Should not see the value "root" which comes from config in CWD + assert.Equal(t, `default`, v.GetString(`key`)) +} + +func TestWrongDirsSearchNotFoundNoCWDConfig(t *testing.T) { + + _, config, cleanup := initDirs(t) + defer cleanup() + + // Remove the config file in CWD + os.Remove(config + ".toml") + + v := New() + v.SetConfigName(config) + v.SetDefault(`key`, `default`) + + v.AddConfigPath(`whattayoutalkingbout`) + v.AddConfigPath(`thispathaintthere`) + + err := v.ReadInConfig() + assert.Equal(t, reflect.TypeOf(UnsupportedConfigError("")), reflect.TypeOf(err)) + + // Even though config did not load and the error might have + // been ignored by the client, the default still loads + assert.Equal(t, `default`, v.GetString(`key`)) +}