Feature/write config (#287)

* Added method to write into TOML file.

* Added functionality to export configuration based on config type. The feature supports JSON and TOML.

* Added method to write into YAML file.

* Fixed the issue of incorrect defer and error checking order. The error checking must be first otherwise it will cause panic.

* Add WriteConfig methods

* Add support for toml

* Add shared write function and safe methods

* Fix incorrectly modified imports

* Remove extra comments

* Fix spelling

* Make marshal spelling consistent throughout

* Add support for remaining configuration types

This commit moves a significant portion of the code back to viper.go to
facilitate having access to the object when reading the files. The purpose is to
add properties to the viper object at read time, so that we can add the comments
back to the file when writing.

* Add tests for each written file type

* Modify test for updated HCL specification

* Modify to only support HCL write in Go 1.7

* Revert "Modify to only support HCL write in Go 1.7"

This reverts commit 12b34bc4eb92cbf8ebfd56b79519f448607e3e51.

* Need to truncate the file before writing

* Write all settings including overrides

* Use filename variable

* Lint remote.go

* Fix toml return count error
This commit is contained in:
Adam Sherwood 2017-12-06 20:26:31 -08:00 committed by Brian Ketelsen
parent 4dddf7c62e
commit 1a0c4a370c
4 changed files with 401 additions and 76 deletions

View file

@ -8,10 +8,11 @@ package remote
import ( import (
"bytes" "bytes"
"github.com/spf13/viper"
crypt "github.com/xordataexchange/crypt/config"
"io" "io"
"os" "os"
"github.com/spf13/viper"
crypt "github.com/xordataexchange/crypt/config"
) )
type remoteConfigProvider struct{} type remoteConfigProvider struct{}

62
util.go
View file

@ -11,23 +11,16 @@
package viper package viper
import ( import (
"bytes"
"encoding/json"
"fmt" "fmt"
"io"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings" "strings"
"unicode" "unicode"
"github.com/hashicorp/hcl"
"github.com/magiconair/properties"
toml "github.com/pelletier/go-toml"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cast" "github.com/spf13/cast"
jww "github.com/spf13/jwalterweatherman" jww "github.com/spf13/jwalterweatherman"
"gopkg.in/yaml.v2"
) )
// ConfigParseError denotes failing to parse configuration file. // ConfigParseError denotes failing to parse configuration file.
@ -153,61 +146,6 @@ func userHomeDir() string {
return os.Getenv("HOME") return os.Getenv("HOME")
} }
func unmarshallConfigReader(in io.Reader, c map[string]interface{}, configType string) error {
buf := new(bytes.Buffer)
buf.ReadFrom(in)
switch strings.ToLower(configType) {
case "yaml", "yml":
if err := yaml.Unmarshal(buf.Bytes(), &c); err != nil {
return ConfigParseError{err}
}
case "json":
if err := json.Unmarshal(buf.Bytes(), &c); err != nil {
return ConfigParseError{err}
}
case "hcl":
obj, err := hcl.Parse(string(buf.Bytes()))
if err != nil {
return ConfigParseError{err}
}
if err = hcl.DecodeObject(&c, obj); err != nil {
return ConfigParseError{err}
}
case "toml":
tree, err := toml.LoadReader(buf)
if err != nil {
return ConfigParseError{err}
}
tmap := tree.ToMap()
for k, v := range tmap {
c[k] = v
}
case "properties", "props", "prop":
var p *properties.Properties
var err error
if p, err = properties.Load(buf.Bytes(), properties.UTF8); err != nil {
return ConfigParseError{err}
}
for _, key := range p.Keys() {
value, _ := p.Get(key)
// recursively build nested maps
path := strings.Split(key, ".")
lastKey := strings.ToLower(path[len(path)-1])
deepestMap := deepSearch(c, path[0:len(path)-1])
// set innermost value
deepestMap[lastKey] = value
}
}
insensitiviseMap(c)
return nil
}
func safeMul(a, b uint) uint { func safeMul(a, b uint) uint {
c := a * b c := a * b
if a > 1 && b > 1 && c/b != a { if a > 1 && b > 1 && c/b != a {

223
viper.go
View file

@ -22,6 +22,7 @@ package viper
import ( import (
"bytes" "bytes"
"encoding/csv" "encoding/csv"
"encoding/json"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -31,14 +32,30 @@ import (
"strings" "strings"
"time" "time"
yaml "gopkg.in/yaml.v2"
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
"github.com/hashicorp/hcl"
"github.com/hashicorp/hcl/hcl/printer"
"github.com/magiconair/properties"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
toml "github.com/pelletier/go-toml"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cast" "github.com/spf13/cast"
jww "github.com/spf13/jwalterweatherman" jww "github.com/spf13/jwalterweatherman"
"github.com/spf13/pflag" "github.com/spf13/pflag"
) )
// ConfigMarshalError happens when failing to marshal the configuration.
type ConfigMarshalError struct {
err error
}
// Error returns the formatted configuration error.
func (e ConfigMarshalError) Error() string {
return fmt.Sprintf("While marshaling config: %s", e.err.Error())
}
var v *Viper var v *Viper
type RemoteResponse struct { type RemoteResponse struct {
@ -162,6 +179,10 @@ type Viper struct {
aliases map[string]string aliases map[string]string
typeByDefValue bool typeByDefValue bool
// Store read properties on the object so that we can write back in order with comments.
// This will only be used if the configuration read is a properties file.
properties *properties.Properties
onConfigChange func(fsnotify.Event) onConfigChange func(fsnotify.Event)
} }
@ -188,7 +209,7 @@ func New() *Viper {
// can use it in their testing as well. // can use it in their testing as well.
func Reset() { func Reset() {
v = New() v = New()
SupportedExts = []string{"json", "toml", "yaml", "yml", "hcl"} SupportedExts = []string{"json", "toml", "yaml", "yml", "properties", "props", "prop", "hcl"}
SupportedRemoteProviders = []string{"etcd", "consul"} SupportedRemoteProviders = []string{"etcd", "consul"}
} }
@ -1119,6 +1140,7 @@ func (v *Viper) ReadInConfig() error {
return UnsupportedConfigError(v.getConfigType()) return UnsupportedConfigError(v.getConfigType())
} }
jww.DEBUG.Println("Reading file: ", filename)
file, err := afero.ReadFile(v.fs, filename) file, err := afero.ReadFile(v.fs, filename)
if err != nil { if err != nil {
return err return err
@ -1178,6 +1200,195 @@ func (v *Viper) MergeConfig(in io.Reader) error {
return nil return nil
} }
// WriteConfig writes the current configuration to a file.
func WriteConfig() error { return v.WriteConfig() }
func (v *Viper) WriteConfig() error {
filename, err := v.getConfigFile()
if err != nil {
return err
}
return v.writeConfig(filename, true)
}
// SafeWriteConfig writes current configuration to file only if the file does not exist.
func SafeWriteConfig() error { return v.SafeWriteConfig() }
func (v *Viper) SafeWriteConfig() error {
filename, err := v.getConfigFile()
if err != nil {
return err
}
return v.writeConfig(filename, false)
}
// WriteConfigAs writes current configuration to a given filename.
func WriteConfigAs(filename string) error { return v.WriteConfigAs(filename) }
func (v *Viper) WriteConfigAs(filename string) error {
return v.writeConfig(filename, true)
}
// SafeWriteConfigAs writes current configuration to a given filename if it does not exist.
func SafeWriteConfigAs(filename string) error { return v.SafeWriteConfigAs(filename) }
func (v *Viper) SafeWriteConfigAs(filename string) error {
return v.writeConfig(filename, false)
}
func writeConfig(filename string, force bool) error { return v.writeConfig(filename, force) }
func (v *Viper) writeConfig(filename string, force bool) error {
jww.INFO.Println("Attempting to write configuration to file.")
ext := filepath.Ext(filename)
if len(ext) <= 1 {
return fmt.Errorf("Filename: %s requires valid extension.", filename)
}
configType := ext[1:]
if !stringInSlice(configType, SupportedExts) {
return UnsupportedConfigError(configType)
}
if v.config == nil {
v.config = make(map[string]interface{})
}
var flags int
if force == true {
flags = os.O_CREATE | os.O_TRUNC | os.O_WRONLY
} else {
if _, err := os.Stat(filename); os.IsNotExist(err) {
flags = os.O_WRONLY
} else {
return fmt.Errorf("File: %s exists. Use WriteConfig to overwrite.", filename)
}
}
f, err := v.fs.OpenFile(filename, flags, os.FileMode(0644))
if err != nil {
return err
}
return v.marshalWriter(f, configType)
}
// Unmarshal a Reader into a map.
// Should probably be an unexported function.
func unmarshalReader(in io.Reader, c map[string]interface{}) error {
return v.unmarshalReader(in, c)
}
func (v *Viper) unmarshalReader(in io.Reader, c map[string]interface{}) error {
buf := new(bytes.Buffer)
buf.ReadFrom(in)
switch strings.ToLower(v.getConfigType()) {
case "yaml", "yml":
if err := yaml.Unmarshal(buf.Bytes(), &c); err != nil {
return ConfigParseError{err}
}
case "json":
if err := json.Unmarshal(buf.Bytes(), &c); err != nil {
return ConfigParseError{err}
}
case "hcl":
obj, err := hcl.Parse(string(buf.Bytes()))
if err != nil {
return ConfigParseError{err}
}
if err = hcl.DecodeObject(&c, obj); err != nil {
return ConfigParseError{err}
}
case "toml":
tree, err := toml.LoadReader(buf)
if err != nil {
return ConfigParseError{err}
}
tmap := tree.ToMap()
for k, v := range tmap {
c[k] = v
}
case "properties", "props", "prop":
v.properties = properties.NewProperties()
var err error
if v.properties, err = properties.Load(buf.Bytes(), properties.UTF8); err != nil {
return ConfigParseError{err}
}
for _, key := range v.properties.Keys() {
value, _ := v.properties.Get(key)
// recursively build nested maps
path := strings.Split(key, ".")
lastKey := strings.ToLower(path[len(path)-1])
deepestMap := deepSearch(c, path[0:len(path)-1])
// set innermost value
deepestMap[lastKey] = value
}
}
insensitiviseMap(c)
return nil
}
// Marshal a map into Writer.
func marshalWriter(f afero.File, configType string) error {
return v.marshalWriter(f, configType)
}
func (v *Viper) marshalWriter(f afero.File, configType string) error {
c := v.AllSettings()
switch configType {
case "json":
b, err := json.MarshalIndent(c, "", " ")
if err != nil {
return ConfigMarshalError{err}
}
_, err = f.WriteString(string(b))
if err != nil {
return ConfigMarshalError{err}
}
case "hcl":
b, err := json.Marshal(c)
ast, err := hcl.Parse(string(b))
if err != nil {
return ConfigMarshalError{err}
}
err = printer.Fprint(f, ast.Node)
if err != nil {
return ConfigMarshalError{err}
}
case "prop", "props", "properties":
if v.properties == nil {
v.properties = properties.NewProperties()
}
p := v.properties
for _, key := range v.AllKeys() {
_, _, err := p.Set(key, v.GetString(key))
if err != nil {
return ConfigMarshalError{err}
}
}
_, err := p.WriteComment(f, "#", properties.UTF8)
if err != nil {
return ConfigMarshalError{err}
}
case "toml":
t, err := toml.TreeFromMap(c)
if err != nil {
return ConfigMarshalError{err}
}
s := t.String()
if _, err := f.WriteString(s); err != nil {
return ConfigMarshalError{err}
}
case "yaml", "yml":
b, err := yaml.Marshal(c)
if err != nil {
return ConfigMarshalError{err}
}
if _, err = f.WriteString(string(b)); err != nil {
return ConfigMarshalError{err}
}
}
return nil
}
func keyExists(k string, m map[string]interface{}) string { func keyExists(k string, m map[string]interface{}) string {
lk := strings.ToLower(k) lk := strings.ToLower(k)
for mk := range m { for mk := range m {
@ -1290,16 +1501,6 @@ func (v *Viper) WatchRemoteConfigOnChannel() error {
return v.watchKeyValueConfigOnChannel() return v.watchKeyValueConfigOnChannel()
} }
// Unmarshal a Reader into a map.
// Should probably be an unexported function.
func unmarshalReader(in io.Reader, c map[string]interface{}) error {
return v.unmarshalReader(in, c)
}
func (v *Viper) unmarshalReader(in io.Reader, c map[string]interface{}) error {
return unmarshallConfigReader(in, c, v.getConfigType())
}
func (v *Viper) insensitiviseMaps() { func (v *Viper) insensitiviseMaps() {
insensitiviseMap(v.config) insensitiviseMap(v.config)
insensitiviseMap(v.defaults) insensitiviseMap(v.defaults)

View file

@ -18,6 +18,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/spf13/afero"
"github.com/spf13/cast" "github.com/spf13/cast"
"github.com/spf13/pflag" "github.com/spf13/pflag"
@ -262,7 +263,7 @@ func TestDefault(t *testing.T) {
assert.Equal(t, "leather", Get("clothing.jacket")) assert.Equal(t, "leather", Get("clothing.jacket"))
} }
func TestUnmarshalling(t *testing.T) { func TestUnmarshaling(t *testing.T) {
SetConfigType("yaml") SetConfigType("yaml")
r := bytes.NewReader(yamlExample) r := bytes.NewReader(yamlExample)
@ -847,6 +848,190 @@ func TestSub(t *testing.T) {
assert.Equal(t, (*Viper)(nil), subv) assert.Equal(t, (*Viper)(nil), subv)
} }
var hclWriteExpected = []byte(`"foos" = {
"foo" = {
"key" = 1
}
"foo" = {
"key" = 2
}
"foo" = {
"key" = 3
}
"foo" = {
"key" = 4
}
}
"id" = "0001"
"name" = "Cake"
"ppu" = 0.55
"type" = "donut"`)
func TestWriteConfigHCL(t *testing.T) {
v := New()
fs := afero.NewMemMapFs()
v.SetFs(fs)
v.SetConfigName("c")
v.SetConfigType("hcl")
err := v.ReadConfig(bytes.NewBuffer(hclExample))
if err != nil {
t.Fatal(err)
}
if err := v.WriteConfigAs("c.hcl"); err != nil {
t.Fatal(err)
}
read, err := afero.ReadFile(fs, "c.hcl")
if err != nil {
t.Fatal(err)
}
assert.Equal(t, hclWriteExpected, read)
}
var jsonWriteExpected = []byte(`{
"batters": {
"batter": [
{
"type": "Regular"
},
{
"type": "Chocolate"
},
{
"type": "Blueberry"
},
{
"type": "Devil's Food"
}
]
},
"id": "0001",
"name": "Cake",
"ppu": 0.55,
"type": "donut"
}`)
func TestWriteConfigJson(t *testing.T) {
v := New()
fs := afero.NewMemMapFs()
v.SetFs(fs)
v.SetConfigName("c")
v.SetConfigType("json")
err := v.ReadConfig(bytes.NewBuffer(jsonExample))
if err != nil {
t.Fatal(err)
}
if err := v.WriteConfigAs("c.json"); err != nil {
t.Fatal(err)
}
read, err := afero.ReadFile(fs, "c.json")
if err != nil {
t.Fatal(err)
}
assert.Equal(t, jsonWriteExpected, read)
}
var propertiesWriteExpected = []byte(`p_id = 0001
p_type = donut
p_name = Cake
p_ppu = 0.55
p_batters.batter.type = Regular
`)
func TestWriteConfigProperties(t *testing.T) {
v := New()
fs := afero.NewMemMapFs()
v.SetFs(fs)
v.SetConfigName("c")
v.SetConfigType("properties")
err := v.ReadConfig(bytes.NewBuffer(propertiesExample))
if err != nil {
t.Fatal(err)
}
if err := v.WriteConfigAs("c.properties"); err != nil {
t.Fatal(err)
}
read, err := afero.ReadFile(fs, "c.properties")
if err != nil {
t.Fatal(err)
}
assert.Equal(t, propertiesWriteExpected, read)
}
func TestWriteConfigTOML(t *testing.T) {
fs := afero.NewMemMapFs()
v := New()
v.SetFs(fs)
v.SetConfigName("c")
v.SetConfigType("toml")
err := v.ReadConfig(bytes.NewBuffer(tomlExample))
if err != nil {
t.Fatal(err)
}
if err := v.WriteConfigAs("c.toml"); err != nil {
t.Fatal(err)
}
// The TOML String method does not order the contents.
// Therefore, we must read the generated file and compare the data.
v2 := New()
v2.SetFs(fs)
v2.SetConfigName("c")
v2.SetConfigType("toml")
v2.SetConfigFile("c.toml")
err = v2.ReadInConfig()
if err != nil {
t.Fatal(err)
}
assert.Equal(t, v.GetString("title"), v2.GetString("title"))
assert.Equal(t, v.GetString("owner.bio"), v2.GetString("owner.bio"))
assert.Equal(t, v.GetString("owner.dob"), v2.GetString("owner.dob"))
assert.Equal(t, v.GetString("owner.organization"), v2.GetString("owner.organization"))
}
var yamlWriteExpected = []byte(`age: 35
beard: true
clothing:
jacket: leather
pants:
size: large
trousers: denim
eyes: brown
hacker: true
hobbies:
- skateboarding
- snowboarding
- go
name: steve
`)
func TestWriteConfigYAML(t *testing.T) {
v := New()
fs := afero.NewMemMapFs()
v.SetFs(fs)
v.SetConfigName("c")
v.SetConfigType("yaml")
err := v.ReadConfig(bytes.NewBuffer(yamlExample))
if err != nil {
t.Fatal(err)
}
if err := v.WriteConfigAs("c.yaml"); err != nil {
t.Fatal(err)
}
read, err := afero.ReadFile(fs, "c.yaml")
if err != nil {
t.Fatal(err)
}
assert.Equal(t, yamlWriteExpected, read)
}
var yamlMergeExampleTgt = []byte(` var yamlMergeExampleTgt = []byte(`
hello: hello:
pop: 37890 pop: 37890