diff --git a/.travis.yml b/.travis.yml index e793edb..fa39805 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,9 +2,8 @@ go_import_path: github.com/spf13/viper language: go go: - - 1.5.4 - - 1.6.3 - - 1.7 + - 1.9.x + - 1.10.x - tip os: @@ -18,6 +17,7 @@ matrix: script: - go install ./... + - diff -u <(echo -n) <(gofmt -d .) - go test -v ./... after_success: diff --git a/README.md b/README.md index 25181df..64bf474 100644 --- a/README.md +++ b/README.md @@ -6,18 +6,19 @@ Many Go projects are built using Viper including: * [Hugo](http://gohugo.io) * [EMC RexRay](http://rexray.readthedocs.org/en/stable/) -* [Imgur's Incus](https://github.com/Imgur/incus) +* [Imgur’s Incus](https://github.com/Imgur/incus) * [Nanobox](https://github.com/nanobox-io/nanobox)/[Nanopack](https://github.com/nanopack) * [Docker Notary](https://github.com/docker/Notary) * [BloomApi](https://www.bloomapi.com/) * [doctl](https://github.com/digitalocean/doctl) +* [Clairctl](https://github.com/jgsqware/clairctl) [![Build Status](https://travis-ci.org/spf13/viper.svg)](https://travis-ci.org/spf13/viper) [![Join the chat at https://gitter.im/spf13/viper](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/spf13/viper?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![GoDoc](https://godoc.org/github.com/spf13/viper?status.svg)](https://godoc.org/github.com/spf13/viper) ## What is Viper? -Viper is a complete configuration solution for go applications including 12 factor apps. It is designed +Viper is a complete configuration solution for Go applications including 12-Factor apps. It is designed to work within an application, and can handle all types of configuration needs and formats. It supports: @@ -68,7 +69,7 @@ Viper configuration keys are case insensitive. ### Establishing Defaults A good configuration system will support default values. A default value is not -required for a key, but it's useful in the event that a key hasn’t been set via +required for a key, but it’s useful in the event that a key hasn’t been set via config file, environment variable, remote configuration or flag. Examples: @@ -116,10 +117,10 @@ Optionally you can provide a function for Viper to run each time a change occurs **Make sure you add all of the configPaths prior to calling `WatchConfig()`** ```go - viper.WatchConfig() - viper.OnConfigChange(func(e fsnotify.Event) { - fmt.Println("Config file changed:", e.Name) - }) +viper.WatchConfig() +viper.OnConfigChange(func(e fsnotify.Event) { + fmt.Println("Config file changed:", e.Name) +}) ``` ### Reading Config from io.Reader @@ -184,7 +185,7 @@ with ENV: * `AutomaticEnv()` * `BindEnv(string...) : error` * `SetEnvPrefix(string)` - * `SetEnvReplacer(string...) *strings.Replacer` + * `SetEnvKeyReplacer(string...) *strings.Replacer` _When working with ENV variables, it’s important to recognize that Viper treats ENV variables as case sensitive._ @@ -211,7 +212,7 @@ time a `viper.Get` request is made. It will apply the following rules. It will check for a environment variable with a name matching the key uppercased and prefixed with the `EnvPrefix` if set. -`SetEnvReplacer` allows you to use a `strings.Replacer` object to rewrite Env +`SetEnvKeyReplacer` allows you to use a `strings.Replacer` object to rewrite Env keys to an extent. This is useful if you want to use `-` or something in your `Get()` calls, but want your environmental variables to use `_` delimiters. An example of using it can be found in `viper_test.go`. @@ -236,7 +237,7 @@ Like `BindEnv`, the value is not set when the binding method is called, but when it is accessed. This means you can bind as early as you want, even in an `init()` function. -The `BindPFlag()` method provides this functionality. +For individual flags, the `BindPFlag()` method provides this functionality. Example: @@ -245,6 +246,19 @@ serverCmd.Flags().Int("port", 1138, "Port to run Application server on") viper.BindPFlag("port", serverCmd.Flags().Lookup("port")) ``` +You can also bind an existing set of pflags (pflag.FlagSet): + +Example: + +```go +pflag.Int("flagname", 1234, "help message for flagname") + +pflag.Parse() +viper.BindPFlags(pflag.CommandLine) + +i := viper.GetInt("flagname") // retrieve values from viper instead of pflag +``` + The use of [pflag](https://github.com/spf13/pflag/) in Viper does not preclude the use of other packages that use the [flag](https://golang.org/pkg/flag/) package from the standard library. The pflag package can handle the flags @@ -263,15 +277,23 @@ import ( ) func main() { + + // using standard library "flag" package + flag.Int("flagname", 1234, "help message for flagname") + pflag.CommandLine.AddGoFlagSet(flag.CommandLine) pflag.Parse() - ... + viper.BindPFlags(pflag.CommandLine) + + i := viper.GetInt("flagname") // retrieve value from viper + + ... } ``` #### Flag interfaces -Viper provides two Go interfaces to bind other flag systems if you don't use `Pflags`. +Viper provides two Go interfaces to bind other flag systems if you don’t use `Pflags`. `FlagValue` represents a single flag. This is a very simple example on how to implement this interface: @@ -401,7 +423,7 @@ go func(){ ## Getting Values From Viper -In Viper, there are a few ways to get a value depending on the value's type. +In Viper, there are a few ways to get a value depending on the value’s type. The following functions and methods exist: * `Get(key string) : interface{}` @@ -531,7 +553,7 @@ func NewCache(cfg *Viper) *Cache {...} ``` which creates a cache based on config information formatted as `subv`. -Now it's easy to create these 2 caches separately as: +Now it’s easy to create these 2 caches separately as: ```go cfg1 := viper.Sub("app.cache1") @@ -575,13 +597,13 @@ initialization needed to begin using Viper. Since most applications will want to use a single central repository for their configuration, the viper package provides this. It is similar to a singleton. -In all of the examples above, they demonstrate using viper in it's singleton +In all of the examples above, they demonstrate using viper in its singleton style approach. ### Working with multiple vipers You can also create many different vipers for use in your application. Each will -have it’s own unique set of configurations and values. Each can read from a +have its own unique set of configurations and values. Each can read from a different config file, key value store, etc. All of the functions that viper package supports are mirrored as methods on a viper. diff --git a/flags_test.go b/flags_test.go index 5bffca3..0b976b6 100644 --- a/flags_test.go +++ b/flags_test.go @@ -62,5 +62,4 @@ func TestBindFlagValue(t *testing.T) { flag.Changed = true //hack for pflag usage assert.Equal(t, "testing_mutate", Get("testvalue")) - } diff --git a/nohup.out b/nohup.out deleted file mode 100644 index 8973bf2..0000000 --- a/nohup.out +++ /dev/null @@ -1 +0,0 @@ -QProcess::start: Process is already running diff --git a/remote/remote.go b/remote/remote.go index faaf3b3..810d070 100644 --- a/remote/remote.go +++ b/remote/remote.go @@ -8,10 +8,11 @@ package remote import ( "bytes" - "github.com/spf13/viper" - crypt "github.com/xordataexchange/crypt/config" "io" "os" + + "github.com/spf13/viper" + crypt "github.com/xordataexchange/crypt/config" ) type remoteConfigProvider struct{} @@ -33,17 +34,45 @@ func (rc remoteConfigProvider) Watch(rp viper.RemoteProvider) (io.Reader, error) if err != nil { return nil, err } - resp := <-cm.Watch(rp.Path(), nil) - err = resp.Error + resp, err := cm.Get(rp.Path()) if err != nil { return nil, err } - return bytes.NewReader(resp.Value), nil + return bytes.NewReader(resp), nil +} + +func (rc remoteConfigProvider) WatchChannel(rp viper.RemoteProvider) (<-chan *viper.RemoteResponse, chan bool) { + cm, err := getConfigManager(rp) + if err != nil { + return nil, nil + } + quit := make(chan bool) + quitwc := make(chan bool) + viperResponsCh := make(chan *viper.RemoteResponse) + cryptoResponseCh := cm.Watch(rp.Path(), quit) + // need this function to convert the Channel response form crypt.Response to viper.Response + go func(cr <-chan *crypt.Response, vr chan<- *viper.RemoteResponse, quitwc <-chan bool, quit chan<- bool) { + for { + select { + case <-quitwc: + quit <- true + return + case resp := <-cr: + vr <- &viper.RemoteResponse{ + Error: resp.Error, + Value: resp.Value, + } + + } + + } + }(cryptoResponseCh, viperResponsCh, quitwc, quit) + + return viperResponsCh, quitwc } func getConfigManager(rp viper.RemoteProvider) (crypt.ConfigManager, error) { - var cm crypt.ConfigManager var err error @@ -69,7 +98,6 @@ func getConfigManager(rp viper.RemoteProvider) (crypt.ConfigManager, error) { return nil, err } return cm, nil - } func init() { diff --git a/util.go b/util.go index 34062e2..975d494 100644 --- a/util.go +++ b/util.go @@ -11,22 +11,16 @@ package viper import ( - "bytes" - "encoding/json" "fmt" - "io" "os" "path/filepath" "runtime" "strings" "unicode" - "github.com/hashicorp/hcl" - "github.com/magiconair/properties" - toml "github.com/pelletier/go-toml" + "github.com/spf13/afero" "github.com/spf13/cast" jww "github.com/spf13/jwalterweatherman" - "gopkg.in/yaml.v2" ) // ConfigParseError denotes failing to parse configuration file. @@ -138,8 +132,8 @@ func absPathify(inPath string) string { } // Check if File / Directory Exists -func exists(path string) (bool, error) { - _, err := v.fs.Stat(path) +func exists(fs afero.Fs, path string) (bool, error) { + _, err := fs.Stat(path) if err == nil { return true, nil } @@ -213,6 +207,17 @@ func unmarshallConfigReader(in io.Reader, c map[string]interface{}, configType s return nil } +func userHomeDir() string { + if runtime.GOOS == "windows" { + home := os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH") + if home == "" { + home = os.Getenv("USERPROFILE") + } + return home + } + return os.Getenv("HOME") +} + func safeMul(a, b uint) uint { c := a * b if a > 1 && b > 1 && c/b != a { diff --git a/util_test.go b/util_test.go index 5949e09..0af80bb 100644 --- a/util_test.go +++ b/util_test.go @@ -16,7 +16,6 @@ import ( ) func TestCopyAndInsensitiviseMap(t *testing.T) { - var ( given = map[string]interface{}{ "Foo": 32, diff --git a/viper.go b/viper.go index 2603c78..907a102 100644 --- a/viper.go +++ b/viper.go @@ -21,6 +21,8 @@ package viper import ( "bytes" + "encoding/csv" + "encoding/json" "fmt" "io" "log" @@ -30,16 +32,37 @@ import ( "strings" "time" + yaml "gopkg.in/yaml.v2" + "github.com/fsnotify/fsnotify" + "github.com/hashicorp/hcl" + "github.com/hashicorp/hcl/hcl/printer" + "github.com/magiconair/properties" "github.com/mitchellh/mapstructure" + toml "github.com/pelletier/go-toml" "github.com/spf13/afero" "github.com/spf13/cast" jww "github.com/spf13/jwalterweatherman" "github.com/spf13/pflag" ) +// ConfigMarshalError happens when failing to marshal the configuration. +type ConfigMarshalError struct { + err error +} + +// Error returns the formatted configuration error. +func (e ConfigMarshalError) Error() string { + return fmt.Sprintf("While marshaling config: %s", e.err.Error()) +} + var v *Viper +type RemoteResponse struct { + Value []byte + Error error +} + func init() { v = New() } @@ -47,6 +70,7 @@ func init() { type remoteConfigFactory interface { Get(rp RemoteProvider) (io.Reader, error) Watch(rp RemoteProvider) (io.Reader, error) + WatchChannel(rp RemoteProvider) (<-chan *RemoteResponse, chan bool) } // RemoteConfig is optional, see the remote package @@ -62,8 +86,7 @@ func (str UnsupportedConfigError) Error() string { } // UnsupportedRemoteProviderError denotes encountering an unsupported remote -// provider. Currently only etcd and Consul are -// supported. +// provider. Currently only etcd and Consul are supported. type UnsupportedRemoteProviderError string // Error returns the formatted remote provider error. @@ -156,6 +179,10 @@ type Viper struct { aliases map[string]string typeByDefValue bool + // Store read properties on the object so that we can write back in order with comments. + // This will only be used if the configuration read is a properties file. + properties *properties.Properties + onConfigChange func(fsnotify.Event) } @@ -182,7 +209,7 @@ func New() *Viper { // can use it in their testing as well. func Reset() { v = New() - SupportedExts = []string{"json", "toml", "yaml", "yml", "hcl"} + SupportedExts = []string{"json", "toml", "yaml", "yml", "properties", "props", "prop", "hcl"} SupportedRemoteProviders = []string{"etcd", "consul"} } @@ -276,8 +303,8 @@ func (v *Viper) WatchConfig() { }() } -// SetConfigFile explicitly defines the path, name and extension of the config file -// Viper will use this and not check any of the config paths +// SetConfigFile explicitly defines the path, name and extension of the config file. +// Viper will use this and not check any of the config paths. func SetConfigFile(in string) { v.SetConfigFile(in) } func (v *Viper) SetConfigFile(in string) { if in != "" { @@ -286,8 +313,8 @@ func (v *Viper) SetConfigFile(in string) { } // SetEnvPrefix defines a prefix that ENVIRONMENT variables will use. -// E.g. if your prefix is "spf", the env registry -// will look for env. variables that start with "SPF_" +// E.g. if your prefix is "spf", the env registry will look for env +// variables that start with "SPF_". func SetEnvPrefix(in string) { v.SetEnvPrefix(in) } func (v *Viper) SetEnvPrefix(in string) { if in != "" { @@ -305,11 +332,11 @@ func (v *Viper) mergeWithEnvPrefix(in string) string { // TODO: should getEnv logic be moved into find(). Can generalize the use of // rewriting keys many things, Ex: Get('someKey') -> some_key -// (cammel case to snake case for JSON keys perhaps) +// (camel case to snake case for JSON keys perhaps) // getEnv is a wrapper around os.Getenv which replaces characters in the original -// key. This allows env vars which have different keys then the config object -// keys +// key. This allows env vars which have different keys than the config object +// keys. func (v *Viper) getEnv(key string) string { if v.envKeyReplacer != nil { key = v.envKeyReplacer.Replace(key) @@ -317,7 +344,7 @@ func (v *Viper) getEnv(key string) string { return os.Getenv(key) } -// ConfigFileUsed returns the file used to populate the config registry +// ConfigFileUsed returns the file used to populate the config registry. func ConfigFileUsed() string { return v.ConfigFileUsed() } func (v *Viper) ConfigFileUsed() string { return v.configFile } @@ -590,32 +617,33 @@ func (v *Viper) Get(key string) interface{} { return nil } - valType := val if v.typeByDefValue { // TODO(bep) this branch isn't covered by a single test. + valType := val path := strings.Split(lcaseKey, v.keyDelim) defVal := v.searchMap(v.defaults, path) if defVal != nil { valType = defVal } + + switch valType.(type) { + case bool: + return cast.ToBool(val) + case string: + return cast.ToString(val) + case int64, int32, int16, int8, int: + return cast.ToInt(val) + case float64, float32: + return cast.ToFloat64(val) + case time.Time: + return cast.ToTime(val) + case time.Duration: + return cast.ToDuration(val) + case []string: + return cast.ToStringSlice(val) + } } - switch valType.(type) { - case bool: - return cast.ToBool(val) - case string: - return cast.ToString(val) - case int64, int32, int16, int8, int: - return cast.ToInt(val) - case float64, float32: - return cast.ToFloat64(val) - case time.Time: - return cast.ToTime(val) - case time.Duration: - return cast.ToDuration(val) - case []string: - return cast.ToStringSlice(val) - } return val } @@ -654,6 +682,12 @@ func (v *Viper) GetInt(key string) int { return cast.ToInt(v.Get(key)) } +// GetInt32 returns the value associated with the key as an integer. +func GetInt32(key string) int32 { return v.GetInt32(key) } +func (v *Viper) GetInt32(key string) int32 { + return cast.ToInt32(v.Get(key)) +} + // GetInt64 returns the value associated with the key as an integer. func GetInt64(key string) int64 { return v.GetInt64(key) } func (v *Viper) GetInt64(key string) int64 { @@ -713,7 +747,15 @@ func (v *Viper) GetSizeInBytes(key string) uint { // UnmarshalKey takes a single key and unmarshals it into a Struct. func UnmarshalKey(key string, rawVal interface{}) error { return v.UnmarshalKey(key, rawVal) } func (v *Viper) UnmarshalKey(key string, rawVal interface{}) error { - return mapstructure.Decode(v.Get(key), rawVal) + err := decode(v.Get(key), defaultDecoderConfig(rawVal)) + + if err != nil { + return err + } + + v.insensitiviseMaps() + + return nil } // Unmarshal unmarshals the config into a Struct. Make sure that the tags @@ -732,13 +774,16 @@ func (v *Viper) Unmarshal(rawVal interface{}) error { } // defaultDecoderConfig returns default mapsstructure.DecoderConfig with suppot -// of time.Duration values +// of time.Duration values & string slices func defaultDecoderConfig(output interface{}) *mapstructure.DecoderConfig { return &mapstructure.DecoderConfig{ Metadata: nil, Result: output, WeaklyTypedInput: true, - DecodeHook: mapstructure.StringToTimeDurationHookFunc(), + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + ), } } @@ -799,7 +844,7 @@ func (v *Viper) BindFlagValues(flags FlagValueSet) (err error) { } // BindFlagValue binds a specific key to a FlagValue. -// Example(where serverCmd is a Cobra instance): +// Example (where serverCmd is a Cobra instance): // // serverCmd.Flags().Int("port", 1138, "Port to run Application server on") // Viper.BindFlagValue("port", serverCmd.Flags().Lookup("port")) @@ -880,7 +925,9 @@ func (v *Viper) find(lcaseKey string) interface{} { return cast.ToBool(flag.ValueString()) case "stringSlice": s := strings.TrimPrefix(flag.ValueString(), "[") - return strings.TrimSuffix(s, "]") + s = strings.TrimSuffix(s, "]") + res, _ := readAsCSV(s) + return res default: return flag.ValueString() } @@ -947,7 +994,9 @@ func (v *Viper) find(lcaseKey string) interface{} { return cast.ToBool(flag.ValueString()) case "stringSlice": s := strings.TrimPrefix(flag.ValueString(), "[") - return strings.TrimSuffix(s, "]") + s = strings.TrimSuffix(s, "]") + res, _ := readAsCSV(s) + return res default: return flag.ValueString() } @@ -957,6 +1006,15 @@ func (v *Viper) find(lcaseKey string) interface{} { return nil } +func readAsCSV(val string) ([]string, error) { + if val == "" { + return []string{}, nil + } + stringReader := strings.NewReader(val) + csvReader := csv.NewReader(stringReader) + return csvReader.Read() +} + // IsSet checks to see if the key has been set in any of the data locations. // IsSet is case-insensitive for a key. func IsSet(key string) bool { return v.IsSet(key) } @@ -1088,14 +1146,21 @@ func (v *Viper) ReadInConfig() error { return UnsupportedConfigError(v.getConfigType()) } + jww.DEBUG.Println("Reading file: ", filename) file, err := afero.ReadFile(v.fs, filename) if err != nil { return err } - v.config = make(map[string]interface{}) + config := make(map[string]interface{}) - return v.unmarshalReader(bytes.NewReader(file), v.config) + err = v.unmarshalReader(bytes.NewReader(file), config) + if err != nil { + return err + } + + v.config = config + return nil } // MergeInConfig merges a new configuration with an existing config. @@ -1141,6 +1206,195 @@ func (v *Viper) MergeConfig(in io.Reader) error { return nil } +// WriteConfig writes the current configuration to a file. +func WriteConfig() error { return v.WriteConfig() } +func (v *Viper) WriteConfig() error { + filename, err := v.getConfigFile() + if err != nil { + return err + } + return v.writeConfig(filename, true) +} + +// SafeWriteConfig writes current configuration to file only if the file does not exist. +func SafeWriteConfig() error { return v.SafeWriteConfig() } +func (v *Viper) SafeWriteConfig() error { + filename, err := v.getConfigFile() + if err != nil { + return err + } + return v.writeConfig(filename, false) +} + +// WriteConfigAs writes current configuration to a given filename. +func WriteConfigAs(filename string) error { return v.WriteConfigAs(filename) } +func (v *Viper) WriteConfigAs(filename string) error { + return v.writeConfig(filename, true) +} + +// SafeWriteConfigAs writes current configuration to a given filename if it does not exist. +func SafeWriteConfigAs(filename string) error { return v.SafeWriteConfigAs(filename) } +func (v *Viper) SafeWriteConfigAs(filename string) error { + return v.writeConfig(filename, false) +} + +func writeConfig(filename string, force bool) error { return v.writeConfig(filename, force) } +func (v *Viper) writeConfig(filename string, force bool) error { + jww.INFO.Println("Attempting to write configuration to file.") + ext := filepath.Ext(filename) + if len(ext) <= 1 { + return fmt.Errorf("Filename: %s requires valid extension.", filename) + } + configType := ext[1:] + if !stringInSlice(configType, SupportedExts) { + return UnsupportedConfigError(configType) + } + if v.config == nil { + v.config = make(map[string]interface{}) + } + var flags int + if force == true { + flags = os.O_CREATE | os.O_TRUNC | os.O_WRONLY + } else { + if _, err := os.Stat(filename); os.IsNotExist(err) { + flags = os.O_WRONLY + } else { + return fmt.Errorf("File: %s exists. Use WriteConfig to overwrite.", filename) + } + } + f, err := v.fs.OpenFile(filename, flags, os.FileMode(0644)) + if err != nil { + return err + } + return v.marshalWriter(f, configType) +} + +// Unmarshal a Reader into a map. +// Should probably be an unexported function. +func unmarshalReader(in io.Reader, c map[string]interface{}) error { + return v.unmarshalReader(in, c) +} +func (v *Viper) unmarshalReader(in io.Reader, c map[string]interface{}) error { + buf := new(bytes.Buffer) + buf.ReadFrom(in) + + switch strings.ToLower(v.getConfigType()) { + case "yaml", "yml": + if err := yaml.Unmarshal(buf.Bytes(), &c); err != nil { + return ConfigParseError{err} + } + + case "json": + if err := json.Unmarshal(buf.Bytes(), &c); err != nil { + return ConfigParseError{err} + } + + case "hcl": + obj, err := hcl.Parse(string(buf.Bytes())) + if err != nil { + return ConfigParseError{err} + } + if err = hcl.DecodeObject(&c, obj); err != nil { + return ConfigParseError{err} + } + + case "toml": + tree, err := toml.LoadReader(buf) + if err != nil { + return ConfigParseError{err} + } + tmap := tree.ToMap() + for k, v := range tmap { + c[k] = v + } + + case "properties", "props", "prop": + v.properties = properties.NewProperties() + var err error + if v.properties, err = properties.Load(buf.Bytes(), properties.UTF8); err != nil { + return ConfigParseError{err} + } + for _, key := range v.properties.Keys() { + value, _ := v.properties.Get(key) + // recursively build nested maps + path := strings.Split(key, ".") + lastKey := strings.ToLower(path[len(path)-1]) + deepestMap := deepSearch(c, path[0:len(path)-1]) + // set innermost value + deepestMap[lastKey] = value + } + } + + insensitiviseMap(c) + return nil +} + +// Marshal a map into Writer. +func marshalWriter(f afero.File, configType string) error { + return v.marshalWriter(f, configType) +} +func (v *Viper) marshalWriter(f afero.File, configType string) error { + c := v.AllSettings() + switch configType { + case "json": + b, err := json.MarshalIndent(c, "", " ") + if err != nil { + return ConfigMarshalError{err} + } + _, err = f.WriteString(string(b)) + if err != nil { + return ConfigMarshalError{err} + } + + case "hcl": + b, err := json.Marshal(c) + ast, err := hcl.Parse(string(b)) + if err != nil { + return ConfigMarshalError{err} + } + err = printer.Fprint(f, ast.Node) + if err != nil { + return ConfigMarshalError{err} + } + + case "prop", "props", "properties": + if v.properties == nil { + v.properties = properties.NewProperties() + } + p := v.properties + for _, key := range v.AllKeys() { + _, _, err := p.Set(key, v.GetString(key)) + if err != nil { + return ConfigMarshalError{err} + } + } + _, err := p.WriteComment(f, "#", properties.UTF8) + if err != nil { + return ConfigMarshalError{err} + } + + case "toml": + t, err := toml.TreeFromMap(c) + if err != nil { + return ConfigMarshalError{err} + } + s := t.String() + if _, err := f.WriteString(s); err != nil { + return ConfigMarshalError{err} + } + + case "yaml", "yml": + b, err := yaml.Marshal(c) + if err != nil { + return ConfigMarshalError{err} + } + if _, err = f.WriteString(string(b)); err != nil { + return ConfigMarshalError{err} + } + } + return nil +} + func keyExists(k string, m map[string]interface{}) string { lk := strings.ToLower(k) for mk := range m { @@ -1249,14 +1503,8 @@ func (v *Viper) WatchRemoteConfig() error { return v.watchKeyValueConfig() } -// Unmarshall a Reader into a map. -// Should probably be an unexported function. -func unmarshalReader(in io.Reader, c map[string]interface{}) error { - return v.unmarshalReader(in, c) -} - -func (v *Viper) unmarshalReader(in io.Reader, c map[string]interface{}) error { - return unmarshallConfigReader(in, c, v.getConfigType()) +func (v *Viper) WatchRemoteConfigOnChannel() error { + return v.watchKeyValueConfigOnChannel() } func (v *Viper) insensitiviseMaps() { @@ -1292,6 +1540,23 @@ func (v *Viper) getRemoteConfig(provider RemoteProvider) (map[string]interface{} return v.kvstore, err } +// Retrieve the first found remote configuration. +func (v *Viper) watchKeyValueConfigOnChannel() error { + for _, rp := range v.remoteProviders { + respc, _ := RemoteConfig.WatchChannel(rp) + //Todo: Add quit channel + go func(rc <-chan *RemoteResponse) { + for { + b := <-rc + reader := bytes.NewReader(b.Value) + v.unmarshalReader(reader, v.kvstore) + } + }(respc) + return nil + } + return RemoteConfigError("No Files Found") +} + // Retrieve the first found remote configuration. func (v *Viper) watchKeyValueConfig() error { for _, rp := range v.remoteProviders { @@ -1461,25 +1726,21 @@ func (v *Viper) getConfigType() string { } func (v *Viper) getConfigFile() (string, error) { - // if explicitly set, then use it - if v.configFile != "" { - return v.configFile, nil + if v.configFile == "" { + cf, err := v.findConfigFile() + if err != nil { + return "", err + } + v.configFile = cf } - - cf, err := v.findConfigFile() - if err != nil { - return "", err - } - - v.configFile = cf - return v.getConfigFile() + return v.configFile, nil } func (v *Viper) searchInPath(in string) (filename string) { jww.DEBUG.Println("Searching for config in ", in) for _, ext := range SupportedExts { jww.DEBUG.Println("Checking for", filepath.Join(in, v.configName+"."+ext)) - if b, _ := exists(filepath.Join(in, v.configName+"."+ext)); b { + if b, _ := exists(v.fs, filepath.Join(in, v.configName+"."+ext)); b { jww.DEBUG.Println("Found: ", filepath.Join(in, v.configName+"."+ext)) return filepath.Join(in, v.configName+"."+ext) } @@ -1491,7 +1752,6 @@ func (v *Viper) searchInPath(in string) (filename string) { // Search all configPaths for any config file. // Returns the first path that exists (and is a config file). func (v *Viper) findConfigFile() (string, error) { - jww.INFO.Println("Searching for config in ", v.configPaths) for _, cp := range v.configPaths { diff --git a/viper_test.go b/viper_test.go index 6b9754a..43fd41f 100644 --- a/viper_test.go +++ b/viper_test.go @@ -19,6 +19,7 @@ import ( "testing" "time" + "github.com/spf13/afero" "github.com/spf13/cast" "github.com/spf13/pflag" @@ -269,7 +270,7 @@ func TestDefault(t *testing.T) { assert.Equal(t, "leather", Get("clothing.jacket")) } -func TestUnmarshalling(t *testing.T) { +func TestUnmarshaling(t *testing.T) { SetConfigType("yaml") r := bytes.NewReader(yamlExample) @@ -424,7 +425,7 @@ func TestAutoEnvWithPrefix(t *testing.T) { assert.Equal(t, "13", Get("bar")) } -func TestSetEnvReplacer(t *testing.T) { +func TestSetEnvKeyReplacer(t *testing.T) { Reset() AutomaticEnv() @@ -545,6 +546,44 @@ func TestBindPFlags(t *testing.T) { } +func TestBindPFlagsStringSlice(t *testing.T) { + for _, testValue := range []struct { + Expected []string + Value string + }{ + {[]string{}, ""}, + {[]string{"jeden"}, "jeden"}, + {[]string{"dwa", "trzy"}, "dwa,trzy"}, + {[]string{"cztery", "piec , szesc"}, "cztery,\"piec , szesc\""}} { + + for _, changed := range []bool{true, false} { + v := New() // create independent Viper object + flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError) + flagSet.StringSlice("stringslice", testValue.Expected, "test") + flagSet.Visit(func(f *pflag.Flag) { + if len(testValue.Value) > 0 { + f.Value.Set(testValue.Value) + f.Changed = changed + } + }) + + err := v.BindPFlags(flagSet) + if err != nil { + t.Fatalf("error binding flag set, %v", err) + } + + type TestStr struct { + StringSlice []string + } + val := &TestStr{} + if err := v.Unmarshal(val); err != nil { + t.Fatalf("%+#v cannot unmarshal: %s", testValue.Value, err) + } + assert.Equal(t, testValue.Expected, val.StringSlice) + } + } +} + func TestBindPFlag(t *testing.T) { var testString = "testing" var testValue = newStringValue(testString, &testString) @@ -816,6 +855,190 @@ func TestSub(t *testing.T) { assert.Equal(t, (*Viper)(nil), subv) } +var hclWriteExpected = []byte(`"foos" = { + "foo" = { + "key" = 1 + } + + "foo" = { + "key" = 2 + } + + "foo" = { + "key" = 3 + } + + "foo" = { + "key" = 4 + } +} + +"id" = "0001" + +"name" = "Cake" + +"ppu" = 0.55 + +"type" = "donut"`) + +func TestWriteConfigHCL(t *testing.T) { + v := New() + fs := afero.NewMemMapFs() + v.SetFs(fs) + v.SetConfigName("c") + v.SetConfigType("hcl") + err := v.ReadConfig(bytes.NewBuffer(hclExample)) + if err != nil { + t.Fatal(err) + } + if err := v.WriteConfigAs("c.hcl"); err != nil { + t.Fatal(err) + } + read, err := afero.ReadFile(fs, "c.hcl") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, hclWriteExpected, read) +} + +var jsonWriteExpected = []byte(`{ + "batters": { + "batter": [ + { + "type": "Regular" + }, + { + "type": "Chocolate" + }, + { + "type": "Blueberry" + }, + { + "type": "Devil's Food" + } + ] + }, + "id": "0001", + "name": "Cake", + "ppu": 0.55, + "type": "donut" +}`) + +func TestWriteConfigJson(t *testing.T) { + v := New() + fs := afero.NewMemMapFs() + v.SetFs(fs) + v.SetConfigName("c") + v.SetConfigType("json") + err := v.ReadConfig(bytes.NewBuffer(jsonExample)) + if err != nil { + t.Fatal(err) + } + if err := v.WriteConfigAs("c.json"); err != nil { + t.Fatal(err) + } + read, err := afero.ReadFile(fs, "c.json") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, jsonWriteExpected, read) +} + +var propertiesWriteExpected = []byte(`p_id = 0001 +p_type = donut +p_name = Cake +p_ppu = 0.55 +p_batters.batter.type = Regular +`) + +func TestWriteConfigProperties(t *testing.T) { + v := New() + fs := afero.NewMemMapFs() + v.SetFs(fs) + v.SetConfigName("c") + v.SetConfigType("properties") + err := v.ReadConfig(bytes.NewBuffer(propertiesExample)) + if err != nil { + t.Fatal(err) + } + if err := v.WriteConfigAs("c.properties"); err != nil { + t.Fatal(err) + } + read, err := afero.ReadFile(fs, "c.properties") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, propertiesWriteExpected, read) +} + +func TestWriteConfigTOML(t *testing.T) { + fs := afero.NewMemMapFs() + v := New() + v.SetFs(fs) + v.SetConfigName("c") + v.SetConfigType("toml") + err := v.ReadConfig(bytes.NewBuffer(tomlExample)) + if err != nil { + t.Fatal(err) + } + if err := v.WriteConfigAs("c.toml"); err != nil { + t.Fatal(err) + } + + // The TOML String method does not order the contents. + // Therefore, we must read the generated file and compare the data. + v2 := New() + v2.SetFs(fs) + v2.SetConfigName("c") + v2.SetConfigType("toml") + v2.SetConfigFile("c.toml") + err = v2.ReadInConfig() + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, v.GetString("title"), v2.GetString("title")) + assert.Equal(t, v.GetString("owner.bio"), v2.GetString("owner.bio")) + assert.Equal(t, v.GetString("owner.dob"), v2.GetString("owner.dob")) + assert.Equal(t, v.GetString("owner.organization"), v2.GetString("owner.organization")) +} + +var yamlWriteExpected = []byte(`age: 35 +beard: true +clothing: + jacket: leather + pants: + size: large + trousers: denim +eyes: brown +hacker: true +hobbies: +- skateboarding +- snowboarding +- go +name: steve +`) + +func TestWriteConfigYAML(t *testing.T) { + v := New() + fs := afero.NewMemMapFs() + v.SetFs(fs) + v.SetConfigName("c") + v.SetConfigType("yaml") + err := v.ReadConfig(bytes.NewBuffer(yamlExample)) + if err != nil { + t.Fatal(err) + } + if err := v.WriteConfigAs("c.yaml"); err != nil { + t.Fatal(err) + } + read, err := afero.ReadFile(fs, "c.yaml") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, yamlWriteExpected, read) +} + var yamlMergeExampleTgt = []byte(` hello: pop: 37890 @@ -852,6 +1075,10 @@ func TestMergeConfig(t *testing.T) { t.Fatalf("lagrenum != 765432101234567, = %d", pop) } + if pop := v.GetInt32("hello.pop"); pop != int32(37890) { + t.Fatalf("pop != 37890, = %d", pop) + } + if pop := v.GetInt64("hello.lagrenum"); pop != int64(765432101234567) { t.Fatalf("int64 lagrenum != 765432101234567, = %d", pop) } @@ -876,6 +1103,10 @@ func TestMergeConfig(t *testing.T) { t.Fatalf("lagrenum != 7654321001234567, = %d", pop) } + if pop := v.GetInt32("hello.pop"); pop != int32(45000) { + t.Fatalf("pop != 45000, = %d", pop) + } + if pop := v.GetInt64("hello.lagrenum"); pop != int64(7654321001234567) { t.Fatalf("int64 lagrenum != 7654321001234567, = %d", pop) } @@ -1109,6 +1340,35 @@ func TestCaseInsensitiveSet(t *testing.T) { } } +func TestParseNested(t *testing.T) { + type duration struct { + Delay time.Duration + } + + type item struct { + Name string + Delay time.Duration + Nested duration + } + + config := `[[parent]] + delay="100ms" + [parent.nested] + delay="200ms" +` + initConfig("toml", config) + + var items []item + err := v.UnmarshalKey("parent", &items) + if err != nil { + t.Fatalf("unable to decode into struct, %v", err) + } + + assert.Equal(t, 1, len(items)) + assert.Equal(t, 100*time.Millisecond, items[0].Delay) + assert.Equal(t, 200*time.Millisecond, items[0].Nested.Delay) +} + func doTestCaseInsensitive(t *testing.T, typ, config string) { initConfig(typ, config) Set("RfD", true)