diff --git a/util.go b/util.go index 3ebada9..9d7d6d5 100644 --- a/util.go +++ b/util.go @@ -21,12 +21,14 @@ import ( "strings" "unicode" + yaml "gopkg.in/yaml.v2" + "github.com/hashicorp/hcl" "github.com/magiconair/properties" toml "github.com/pelletier/go-toml" + "github.com/spf13/afero" "github.com/spf13/cast" jww "github.com/spf13/jwalterweatherman" - "gopkg.in/yaml.v2" ) // ConfigParseError denotes failing to parse configuration file. @@ -39,6 +41,16 @@ func (pe ConfigParseError) Error() string { return fmt.Sprintf("While parsing config: %s", pe.err.Error()) } +// ConfigMarshalError happens when failing to marshal the configuration. +type ConfigMarshalError struct { + err error +} + +// Error returns the formatted configuration error. +func (e ConfigMarshalError) Error() string { + return fmt.Sprintf("While marshaling config: %s", e.err.Error()) +} + // toCaseInsensitiveValue checks if the value is a map; // if so, create a copy and lower-case the keys recursively. func toCaseInsensitiveValue(value interface{}) interface{} { @@ -152,6 +164,39 @@ func userHomeDir() string { return os.Getenv("HOME") } +func marshalConfigWriter(out afero.File, c map[string]interface{}, configType string) error { + switch configType { + case "json": + b, err := json.MarshalIndent(v.AllSettings(), "", " ") + if err != nil { + return ConfigMarshalError{err} + } + _, err = out.WriteString(string(b)) + if err != nil { + return ConfigMarshalError{err} + } + + // case "toml": + // w := bufio.NewWriter(out) + // if err := toml.NewEncoder(w).Encode(v.AllSettings()); err != nil { + // return ConfigMarshalError{err} + // } + // w.Flush() + + case "yaml", "yml": + b, err := yaml.Marshal(v.AllSettings()) + if err != nil { + return ConfigMarshalError{err} + } + _, err = out.WriteString(string(b)) + if err != nil { + return ConfigMarshalError{err} + } + } + + return nil +} + func unmarshallConfigReader(in io.Reader, c map[string]interface{}, configType string) error { buf := new(bytes.Buffer) buf.ReadFrom(in) diff --git a/viper.go b/viper.go index 04f55ea..a07d8f5 100644 --- a/viper.go +++ b/viper.go @@ -20,9 +20,7 @@ package viper import ( - "bufio" "bytes" - "encoding/json" "fmt" "io" "log" @@ -32,8 +30,6 @@ import ( "strings" "time" - "github.com/BurntSushi/toml" - "github.com/fsnotify/fsnotify" "github.com/mitchellh/mapstructure" "github.com/spf13/afero" "github.com/spf13/cast" @@ -1044,50 +1040,6 @@ func (v *Viper) InConfig(key string) bool { return exists } -// Save configuration to file -func SaveConfig() error { return v.SaveConfig() } -func (v *Viper) SaveConfig() error { - - jww.INFO.Println("Attempting to write config into the file.") - if !stringInSlice(v.getConfigType(), SupportedExts) { - return UnsupportedConfigError(v.getConfigType()) - } - - f, err := os.Create(v.getConfigFile()) - if err != nil { - return err - } - defer f.Close() - - switch v.getConfigType() { - case "json": - - b, err := json.MarshalIndent(v.AllSettings(), "", " ") - if err != nil { - jww.FATAL.Println("Panic while encoding into JSON format.") - } - f.WriteString(string(b)) - - case "toml": - - w := bufio.NewWriter(f) - if err := toml.NewEncoder(w).Encode(v.AllSettings()); err != nil { - jww.FATAL.Println("Panic while encoding into TOML format.") - } - w.Flush() - - case "yaml", "yml": - - b, err := yaml.Marshal(v.AllSettings()) - if err != nil { - jww.FATAL.Println("Panic while encoding into YAML format.") - } - f.WriteString(string(b)) - } - - return nil -} - // SetDefault sets the default value for this key. // SetDefault is case-insensitive for a key. // Default only used when no value is provided by the user via flag, config or ENV. @@ -1190,6 +1142,31 @@ func (v *Viper) MergeConfig(in io.Reader) error { return nil } +// WriteConfig writes the current configuration to a file. +func WriteConfig() error { return v.WriteConfig() } +func (v *Viper) WriteConfig() error { + + jww.INFO.Println("Attempting to write config into the file.") + // filename, err := v.getConfigFile() + filename := "out.yaml" + // if err != nil { + // return err + // } + if !stringInSlice(v.getConfigType(), SupportedExts) { + return UnsupportedConfigError(v.getConfigType()) + } + if v.config == nil { + v.config = make(map[string]interface{}) + } + + file, err := v.fs.OpenFile(filename, os.O_CREATE|os.O_WRONLY, os.FileMode(0644)) + if err != nil { + return err + } + + return v.marshalWriter(file, v.config) +} + func keyExists(k string, m map[string]interface{}) string { lk := strings.ToLower(k) for mk := range m { @@ -1308,6 +1285,15 @@ func (v *Viper) unmarshalReader(in io.Reader, c map[string]interface{}) error { return unmarshallConfigReader(in, c, v.getConfigType()) } +// Marshal a map into Writer. +func marshalWriter(out afero.File, c map[string]interface{}) error { + return v.marshalWriter(out, c) +} + +func (v *Viper) marshalWriter(out afero.File, c map[string]interface{}) error { + return marshalConfigWriter(out, c, v.getConfigType()) +} + func (v *Viper) insensitiviseMaps() { insensitiviseMap(v.config) insensitiviseMap(v.defaults)