diff --git a/viper.go b/viper.go index b533306..088cd2b 100644 --- a/viper.go +++ b/viper.go @@ -216,7 +216,8 @@ type Viper struct { aliases map[string]string typeByDefValue bool - onConfigChange func(fsnotify.Event) + onConfigChange func(fsnotify.Event) + stopWatchingFunc func() logger Logger @@ -432,13 +433,10 @@ func (v *Viper) OnConfigChange(run func(in fsnotify.Event)) { } // WatchConfig starts watching a config file for changes. -// The function returned for stop watching manually. -func WatchConfig() func() { return v.WatchConfig() } +func WatchConfig() { v.WatchConfig() } // WatchConfig starts watching a config file for changes. -// The function returned for stop watching manually. -func (v *Viper) WatchConfig() func() { - ctx, cancel := context.WithCancel(context.Background()) +func (v *Viper) WatchConfig() { initWG := sync.WaitGroup{} initWG.Add(1) go func() { @@ -460,6 +458,10 @@ func (v *Viper) WatchConfig() func() { configDir, _ := filepath.Split(configFile) realConfigFile, _ := filepath.EvalSymlinks(filename) + // init the stopWatchingFunc + watchingCtx, cancel := context.WithCancel(context.Background()) + v.stopWatchingFunc = cancel + eventsWG := sync.WaitGroup{} eventsWG.Add(1) go func() { @@ -496,8 +498,9 @@ func (v *Viper) WatchConfig() func() { } eventsWG.Done() return - case <-ctx.Done(): // cancel function called - watcher.Close() + case <-watchingCtx.Done(): // StopWatching function called + eventsWG.Done() + return } } }() @@ -506,7 +509,16 @@ func (v *Viper) WatchConfig() func() { eventsWG.Wait() // now, wait for event loop to end in this go-routine... }() initWG.Wait() // make sure that the go routine above fully ended before returning - return cancel +} + +// StopWatching stop watching a config file for changes. +func StopWatching() { v.StopWatching() } + +// StopWatching stop watching a config file for changes. +func (v *Viper) StopWatching() { + if v.stopWatchingFunc != nil { + v.stopWatchingFunc() + } } // SetConfigFile explicitly defines the path, name and extension of the config file. diff --git a/viper_test.go b/viper_test.go index e0bfc57..06a1671 100644 --- a/viper_test.go +++ b/viper_test.go @@ -2545,6 +2545,46 @@ func TestWatchFile(t *testing.T) { }) } +func TestStopWatching(t *testing.T) { + t.Run( + "file content changed after stop watching", func(t *testing.T) { + // given a `config.yaml` file being watched + v, configFile, cleanup := newViperWithConfigFile(t) + defer cleanup() + _, err := os.Stat(configFile) + require.NoError(t, err) + t.Logf("test config file: %s\n", configFile) + + v.WatchConfig() + v.StopWatching() + + // overwriting the file after StopWatching called + err = ioutil.WriteFile(configFile, []byte("foo: baz\n"), 0o640) + time.Sleep(time.Second) // wait for file changed event + // then the config value should not be changed + require.Nil(t, err) + assert.Equal(t, "bar", v.Get("foo")) + + // watch again + wg := sync.WaitGroup{} + wg.Add(1) + var wgDoneOnce sync.Once // OnConfigChange is called twice on Windows + v.OnConfigChange( + func(in fsnotify.Event) { + t.Logf("config file changed again") + wgDoneOnce.Do(func() { wg.Done() }) + }, + ) + v.WatchConfig() + // overwriting the file after StopWatching and Watch again + err = ioutil.WriteFile(configFile, []byte("foo: qux\n"), 0o640) + wg.Wait() + require.Nil(t, err) + assert.Equal(t, "qux", v.Get("foo")) + }, + ) +} + func TestUnmarshal_DotSeparatorBackwardCompatibility(t *testing.T) { flags := pflag.NewFlagSet("test", pflag.ContinueOnError) flags.String("foo.bar", "cobra_flag", "")