mirror of
https://github.com/spf13/viper
synced 2024-12-22 11:37:02 +00:00
WatchConfig and Kubernetes (#284)
Support override of symlink to config file Include tests for WatchConfig of regular files, as well as config file which links to a folder which is itself a link to another folder in the same "watch dir" (the way Kubernetes exposes config files from ConfigMaps mounted on a volume in a Pod) Also: - Add synchronization with WaitGroup to ensure that the WatchConfig is properly started before returning - Remove the watcher when the Config file is removed. Fixes #284 Signed-off-by: Xavier Coulon <xcoulon@redhat.com>
This commit is contained in:
parent
aafc9e6bc7
commit
e0f7631cf3
3 changed files with 137 additions and 12 deletions
5
.gitignore
vendored
5
.gitignore
vendored
|
@ -22,3 +22,8 @@ _testmain.go
|
|||
*.exe
|
||||
*.test
|
||||
*.bench
|
||||
|
||||
.vscode
|
||||
|
||||
# exclude dependencies in the `/vendor` folder
|
||||
vendor
|
||||
|
|
41
viper.go
41
viper.go
|
@ -30,6 +30,7 @@ import (
|
|||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
|
@ -260,13 +261,14 @@ func (v *Viper) OnConfigChange(run func(in fsnotify.Event)) {
|
|||
|
||||
func WatchConfig() { v.WatchConfig() }
|
||||
func (v *Viper) WatchConfig() {
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer watcher.Close()
|
||||
|
||||
// we have to watch the entire directory to pick up renames/atomic saves in a cross-platform way
|
||||
filename, err := v.getConfigFile()
|
||||
if err != nil {
|
||||
|
@ -276,31 +278,46 @@ func (v *Viper) WatchConfig() {
|
|||
|
||||
configFile := filepath.Clean(filename)
|
||||
configDir, _ := filepath.Split(configFile)
|
||||
realConfigFile, _ := filepath.EvalSymlinks(filename)
|
||||
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case event := <-watcher.Events:
|
||||
// we only care about the config file
|
||||
if filepath.Clean(event.Name) == configFile {
|
||||
if event.Op&fsnotify.Write == fsnotify.Write || event.Op&fsnotify.Create == fsnotify.Create {
|
||||
currentConfigFile, _ := filepath.EvalSymlinks(filename)
|
||||
// we only care about the config file with the following cases:
|
||||
// 1 - if the config file was modified or created
|
||||
// 2 - if the real path to the config file changed (eg: k8s ConfigMap replacement)
|
||||
if (filepath.Clean(event.Name) == configFile &&
|
||||
(event.Op&fsnotify.Write == fsnotify.Write || event.Op&fsnotify.Create == fsnotify.Create)) ||
|
||||
(currentConfigFile != "" && currentConfigFile != realConfigFile) {
|
||||
realConfigFile = currentConfigFile
|
||||
err := v.ReadInConfig()
|
||||
if err != nil {
|
||||
log.Println("error:", err)
|
||||
log.Println("error reading file:", err.Error())
|
||||
}
|
||||
if v.onConfigChange != nil {
|
||||
v.onConfigChange(event)
|
||||
}
|
||||
} else if filepath.Clean(event.Name) == configFile &&
|
||||
event.Op&fsnotify.Remove == fsnotify.Remove {
|
||||
done <- true
|
||||
break loop
|
||||
}
|
||||
case err := <-watcher.Errors:
|
||||
log.Println("error:", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
watcher.Add(configDir)
|
||||
<-done
|
||||
case err := <-watcher.Errors:
|
||||
log.Printf("watcher error: %v\n", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
watcher.Add(configDir)
|
||||
wg.Done() // done initalizing the watch in this go routine, so the parent routine can move on...
|
||||
<-done // block until the watched file is removed...
|
||||
}()
|
||||
// make sure that the go routine above fully started before returning
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// SetConfigFile explicitly defines the path, name and extension of the config file.
|
||||
|
|
103
viper_test.go
103
viper_test.go
|
@ -11,18 +11,23 @@ import (
|
|||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cast"
|
||||
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var yamlExample = []byte(`Hacker: true
|
||||
|
@ -1368,6 +1373,104 @@ func doTestCaseInsensitive(t *testing.T, typ, config string) {
|
|||
|
||||
}
|
||||
|
||||
func newViperWithConfigFile(t *testing.T) (*Viper, string, func()) {
|
||||
watchDir, err := ioutil.TempDir("", "")
|
||||
require.Nil(t, err)
|
||||
configFile := path.Join(watchDir, "config.yaml")
|
||||
err = ioutil.WriteFile(configFile, []byte("foo: bar\n"), 0640)
|
||||
require.Nil(t, err)
|
||||
cleanup := func() {
|
||||
os.RemoveAll(watchDir)
|
||||
}
|
||||
v := New()
|
||||
v.SetConfigFile(configFile)
|
||||
err = v.ReadInConfig()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "bar", v.Get("foo"))
|
||||
return v, configFile, cleanup
|
||||
}
|
||||
|
||||
func newViperWithSymlinkedConfigFile(t *testing.T) (*Viper, string, string, func()) {
|
||||
watchDir, err := ioutil.TempDir("", "")
|
||||
require.Nil(t, err)
|
||||
dataDir1 := path.Join(watchDir, "data1")
|
||||
err = os.Mkdir(dataDir1, 0777)
|
||||
require.Nil(t, err)
|
||||
realConfigFile := path.Join(dataDir1, "config.yaml")
|
||||
t.Logf("Real config file location: %s\n", realConfigFile)
|
||||
err = ioutil.WriteFile(realConfigFile, []byte("foo: bar\n"), 0640)
|
||||
require.Nil(t, err)
|
||||
cleanup := func() {
|
||||
os.RemoveAll(watchDir)
|
||||
}
|
||||
// now, symlink the tm `data1` dir to `data` in the baseDir
|
||||
os.Symlink(dataDir1, path.Join(watchDir, "data"))
|
||||
// and link the `<watchdir>/datadir1/config.yaml` to `<watchdir>/config.yaml`
|
||||
configFile := path.Join(watchDir, "config.yaml")
|
||||
os.Symlink(path.Join(watchDir, "data", "config.yaml"), configFile)
|
||||
fmt.Printf("Config file location: %s\n", path.Join(watchDir, "config.yaml"))
|
||||
// init Viper
|
||||
v := New()
|
||||
v.SetConfigFile(configFile)
|
||||
err = v.ReadInConfig()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "bar", v.Get("foo"))
|
||||
return v, watchDir, configFile, cleanup
|
||||
}
|
||||
|
||||
func TestWatchFile(t *testing.T) {
|
||||
t.Run("file content changed", func(t *testing.T) {
|
||||
// given a `config.yaml` file being watched
|
||||
v, configFile, cleanup := newViperWithConfigFile(t)
|
||||
defer cleanup()
|
||||
wg := sync.WaitGroup{}
|
||||
v.WatchConfig()
|
||||
v.OnConfigChange(func(in fsnotify.Event) {
|
||||
t.Logf("config file changed")
|
||||
wg.Done()
|
||||
})
|
||||
wg.Add(1)
|
||||
// when overwriting the file and waiting for the custom change notification handler to be triggered
|
||||
err := ioutil.WriteFile(configFile, []byte("foo: baz\n"), 0640)
|
||||
wg.Wait()
|
||||
// then the config value should have changed
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, "baz", v.Get("foo"))
|
||||
})
|
||||
|
||||
t.Run("link to real file changed (à la Kubernetes)", func(t *testing.T) {
|
||||
// skip if not executed on Linux
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("Skipping test as symlink replacements don't work on non-linux environment...")
|
||||
}
|
||||
v, watchDir, _, _ := newViperWithSymlinkedConfigFile(t)
|
||||
// defer cleanup()
|
||||
wg := sync.WaitGroup{}
|
||||
v.WatchConfig()
|
||||
v.OnConfigChange(func(in fsnotify.Event) {
|
||||
t.Logf("config file changed")
|
||||
wg.Done()
|
||||
})
|
||||
wg.Add(1)
|
||||
// when link to another `config.yaml` file
|
||||
dataDir2 := path.Join(watchDir, "data2")
|
||||
err := os.Mkdir(dataDir2, 0777)
|
||||
require.Nil(t, err)
|
||||
configFile2 := path.Join(dataDir2, "config.yaml")
|
||||
err = ioutil.WriteFile(configFile2, []byte("foo: baz\n"), 0640)
|
||||
require.Nil(t, err)
|
||||
// change the symlink using the `ln -sfn` command
|
||||
err = exec.Command("ln", "-sfn", dataDir2, path.Join(watchDir, "data")).Run()
|
||||
require.Nil(t, err)
|
||||
wg.Wait()
|
||||
// then
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, "baz", v.Get("foo"))
|
||||
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func BenchmarkGetBool(b *testing.B) {
|
||||
key := "BenchmarkGetBool"
|
||||
v = New()
|
||||
|
|
Loading…
Reference in a new issue