Refactor with WaitGroup and check channel is open

Signed-off-by: Xavier Coulon <xcoulon@redhat.com>
This commit is contained in:
Xavier Coulon 2018-05-24 10:09:29 +02:00
parent e0f7631cf3
commit 242f4890f5
2 changed files with 24 additions and 16 deletions

View file

@ -261,8 +261,8 @@ func (v *Viper) OnConfigChange(run func(in fsnotify.Event)) {
func WatchConfig() { v.WatchConfig() } func WatchConfig() { v.WatchConfig() }
func (v *Viper) WatchConfig() { func (v *Viper) WatchConfig() {
wg := sync.WaitGroup{} initWG := sync.WaitGroup{}
wg.Add(1) initWG.Add(1)
go func() { go func() {
watcher, err := fsnotify.NewWatcher() watcher, err := fsnotify.NewWatcher()
if err != nil { if err != nil {
@ -272,7 +272,7 @@ func (v *Viper) WatchConfig() {
// we have to watch the entire directory to pick up renames/atomic saves in a cross-platform way // we have to watch the entire directory to pick up renames/atomic saves in a cross-platform way
filename, err := v.getConfigFile() filename, err := v.getConfigFile()
if err != nil { if err != nil {
log.Println("error:", err) log.Printf("error: %v\n", err)
return return
} }
@ -280,12 +280,16 @@ func (v *Viper) WatchConfig() {
configDir, _ := filepath.Split(configFile) configDir, _ := filepath.Split(configFile)
realConfigFile, _ := filepath.EvalSymlinks(filename) realConfigFile, _ := filepath.EvalSymlinks(filename)
done := make(chan bool) eventsWG := sync.WaitGroup{}
eventsWG.Add(1)
go func() { go func() {
loop:
for { for {
select { select {
case event := <-watcher.Events: case event, ok := <-watcher.Events:
if !ok { // 'Events' channel is closed
eventsWG.Done()
return
}
currentConfigFile, _ := filepath.EvalSymlinks(filename) currentConfigFile, _ := filepath.EvalSymlinks(filename)
// we only care about the config file with the following cases: // we only care about the config file with the following cases:
// 1 - if the config file was modified or created // 1 - if the config file was modified or created
@ -296,28 +300,31 @@ func (v *Viper) WatchConfig() {
realConfigFile = currentConfigFile realConfigFile = currentConfigFile
err := v.ReadInConfig() err := v.ReadInConfig()
if err != nil { if err != nil {
log.Println("error reading file:", err.Error()) log.Printf("error reading file: %v\n", err)
} }
if v.onConfigChange != nil { if v.onConfigChange != nil {
v.onConfigChange(event) v.onConfigChange(event)
} }
} else if filepath.Clean(event.Name) == configFile && } else if filepath.Clean(event.Name) == configFile &&
event.Op&fsnotify.Remove == fsnotify.Remove { event.Op&fsnotify.Remove == fsnotify.Remove {
done <- true eventsWG.Done()
break loop return
} }
case err := <-watcher.Errors: case err, ok := <-watcher.Errors:
log.Printf("watcher error: %v\n", err) if ok { // 'Errors' channel is not closed
log.Printf("watcher error: %v\n", err)
}
eventsWG.Done()
return
} }
} }
}() }()
watcher.Add(configDir) watcher.Add(configDir)
wg.Done() // done initalizing the watch in this go routine, so the parent routine can move on... initWG.Done() // done initalizing the watch in this go routine, so the parent routine can move on...
<-done // block until the watched file is removed... eventsWG.Wait() // now, wait for event loop to end in this go-routine...
}() }()
// make sure that the go routine above fully started before returning initWG.Wait() // make sure that the go routine above fully ended before returning
wg.Wait()
} }
// SetConfigFile explicitly defines the path, name and extension of the config file. // SetConfigFile explicitly defines the path, name and extension of the config file.

View file

@ -1419,9 +1419,11 @@ func newViperWithSymlinkedConfigFile(t *testing.T) (*Viper, string, string, func
} }
func TestWatchFile(t *testing.T) { func TestWatchFile(t *testing.T) {
t.Run("file content changed", func(t *testing.T) { t.Run("file content changed", func(t *testing.T) {
// given a `config.yaml` file being watched // given a `config.yaml` file being watched
v, configFile, cleanup := newViperWithConfigFile(t) v, configFile, cleanup := newViperWithConfigFile(t)
fmt.Printf("test config file: %s\n", configFile)
defer cleanup() defer cleanup()
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
v.WatchConfig() v.WatchConfig()
@ -1466,7 +1468,6 @@ func TestWatchFile(t *testing.T) {
// then // then
require.Nil(t, err) require.Nil(t, err)
assert.Equal(t, "baz", v.Get("foo")) assert.Equal(t, "baz", v.Get("foo"))
}) })
} }