diff --git a/viper.go b/viper.go index f61f4ed..87cc8c6 100644 --- a/viper.go +++ b/viper.go @@ -190,8 +190,11 @@ type Viper struct { remoteProviders []*defaultRemoteProvider // Name of file to look for inside the path - configName string - configFile string + configName string + configFile struct { + path string + explicitlySet bool + } configType string configPermissions os.FileMode envPrefix string @@ -411,7 +414,8 @@ func (v *Viper) WatchConfig() { func SetConfigFile(in string) { v.SetConfigFile(in) } func (v *Viper) SetConfigFile(in string) { if in != "" { - v.configFile = in + v.configFile.path = in + v.configFile.explicitlySet = true } } @@ -460,7 +464,7 @@ func (v *Viper) getEnv(key string) (string, bool) { // ConfigFileUsed returns the file used to populate the config registry. func ConfigFileUsed() string { return v.ConfigFileUsed() } -func (v *Viper) ConfigFileUsed() string { return v.configFile } +func (v *Viper) ConfigFileUsed() string { return v.configFile.path } // AddConfigPath adds a path for Viper to search for the config file in. // Can be called multiple times to define multiple search paths. @@ -1407,20 +1411,32 @@ func (v *Viper) MergeConfigMap(cfg map[string]interface{}) error { // WriteConfig writes the current configuration to a file. func WriteConfig() error { return v.WriteConfig() } func (v *Viper) WriteConfig() error { - filename, err := v.getConfigFile() - if err != nil { - return err + if !v.configFile.explicitlySet { + _, err := v.getConfigFile() + if err != nil { + if _, ok := err.(ConfigFileNotFoundError); ok { + if len(v.configPaths) < 1 { + return errors.New("missing configuration for 'configPath'") + } + v.configFile.path = filepath.Join(v.configPaths[0], v.configName+"."+v.configType) + } else { + return err + } + } } - return v.writeConfig(filename, true) + return v.WriteConfigAs(v.configFile.path) } // SafeWriteConfig writes current configuration to file only if the file does not exist. func SafeWriteConfig() error { return v.SafeWriteConfig() } func (v *Viper) SafeWriteConfig() error { - if len(v.configPaths) < 1 { - return errors.New("missing configuration for 'configPath'") + if !v.configFile.explicitlySet { + if len(v.configPaths) < 1 { + return errors.New("missing configuration for 'configPath'") + } + v.configFile.path = filepath.Join(v.configPaths[0], v.configName+"."+v.configType) } - return v.SafeWriteConfigAs(filepath.Join(v.configPaths[0], v.configName+"."+v.configType)) + return v.SafeWriteConfigAs(v.configFile.path) } // WriteConfigAs writes current configuration to a given filename. @@ -1947,7 +1963,8 @@ func SetConfigName(in string) { v.SetConfigName(in) } func (v *Viper) SetConfigName(in string) { if in != "" { v.configName = in - v.configFile = "" + v.configFile.path = "" + v.configFile.explicitlySet = false } } @@ -1986,14 +2003,14 @@ func (v *Viper) getConfigType() string { } func (v *Viper) getConfigFile() (string, error) { - if v.configFile == "" { + if v.configFile.path == "" { cf, err := v.findConfigFile() if err != nil { return "", err } - v.configFile = cf + v.configFile.path = cf } - return v.configFile, nil + return v.configFile.path, nil } func (v *Viper) searchInPath(in string) (filename string) { diff --git a/viper_test.go b/viper_test.go index fe942de..f90885b 100644 --- a/viper_test.go +++ b/viper_test.go @@ -1372,7 +1372,7 @@ hobbies: name: steve `) -func TestWriteConfig(t *testing.T) { +func TestWriteConfigAs(t *testing.T) { fs := afero.NewMemMapFs() testCases := map[string]struct { configName string @@ -1495,7 +1495,7 @@ func TestWriteConfig(t *testing.T) { } } -func TestWriteConfigTOML(t *testing.T) { +func TestWriteConfigAsTOML(t *testing.T) { fs := afero.NewMemMapFs() testCases := map[string]struct { @@ -1551,7 +1551,7 @@ func TestWriteConfigTOML(t *testing.T) { } } -func TestWriteConfigDotEnv(t *testing.T) { +func TestWriteConfigAsDotEnv(t *testing.T) { fs := afero.NewMemMapFs() testCases := map[string]struct { configName string @@ -1605,6 +1605,56 @@ func TestWriteConfigDotEnv(t *testing.T) { } } +func TestWriteConfig(t *testing.T) { + v := New() + fs := afero.NewMemMapFs() + v.SetFs(fs) + v.AddConfigPath("/test") + v.SetConfigName("c") + v.SetConfigType("yaml") + require.NoError(t, v.ReadConfig(bytes.NewBuffer(yamlExample))) + require.NoError(t, v.WriteConfig()) + read, err := afero.ReadFile(fs, "/test/c.yaml") + require.NoError(t, err) + assert.Equal(t, yamlWriteExpected, read) +} + +func TestWriteConfigWithExplicitlySetFile(t *testing.T) { + v := New() + fs := afero.NewMemMapFs() + v.SetFs(fs) + v.AddConfigPath("/test1") + v.SetConfigName("c1") + v.SetConfigType("yaml") + v.SetConfigFile("/test2/c2.yaml") + require.NoError(t, v.ReadConfig(bytes.NewBuffer(yamlExample))) + require.NoError(t, v.WriteConfig()) + read, err := afero.ReadFile(fs, "/test2/c2.yaml") + require.NoError(t, err) + assert.Equal(t, yamlWriteExpected, read) +} + +func TestWriteConfigWithMissingConfigPath(t *testing.T) { + v := New() + fs := afero.NewMemMapFs() + v.SetFs(fs) + v.SetConfigName("c") + v.SetConfigType("yaml") + require.EqualError(t, v.WriteConfig(), "missing configuration for 'configPath'") +} + +func TestWriteConfigWithExistingFile(t *testing.T) { + v := New() + fs := afero.NewMemMapFs() + fs.Create("/test/c.yaml") + v.SetFs(fs) + v.AddConfigPath("/test") + v.SetConfigName("c") + v.SetConfigType("yaml") + err := v.WriteConfig() + require.NoError(t, err) +} + func TestSafeWriteConfig(t *testing.T) { v := New() fs := afero.NewMemMapFs() @@ -1619,6 +1669,21 @@ func TestSafeWriteConfig(t *testing.T) { assert.Equal(t, yamlWriteExpected, read) } +func TestSafeWriteConfigWithExplicitlySetFile(t *testing.T) { + v := New() + fs := afero.NewMemMapFs() + v.SetFs(fs) + v.AddConfigPath("/test1") + v.SetConfigName("c1") + v.SetConfigType("yaml") + v.SetConfigFile("/test2/c2.yaml") + require.NoError(t, v.ReadConfig(bytes.NewBuffer(yamlExample))) + require.NoError(t, v.SafeWriteConfig()) + read, err := afero.ReadFile(fs, "/test2/c2.yaml") + require.NoError(t, err) + assert.Equal(t, yamlWriteExpected, read) +} + func TestSafeWriteConfigWithMissingConfigPath(t *testing.T) { v := New() fs := afero.NewMemMapFs() @@ -1642,7 +1707,7 @@ func TestSafeWriteConfigWithExistingFile(t *testing.T) { assert.True(t, ok, "Expected ConfigFileAlreadyExistsError") } -func TestSafeWriteAsConfig(t *testing.T) { +func TestSafeWriteConfigAs(t *testing.T) { v := New() fs := afero.NewMemMapFs() v.SetFs(fs)