diff --git a/viper.go b/viper.go index 8bc4438..0d0018b 100644 --- a/viper.go +++ b/viper.go @@ -21,6 +21,7 @@ package viper import ( "bytes" + "context" "encoding/csv" "errors" "fmt" @@ -217,7 +218,8 @@ type Viper struct { aliases map[string]string typeByDefValue bool - onConfigChange func(fsnotify.Event) + onConfigChange func(fsnotify.Event) + stopWatchingFunc func() logger *slog.Logger @@ -458,6 +460,10 @@ func (v *Viper) WatchConfig() { configDir, _ := filepath.Split(configFile) realConfigFile, _ := filepath.EvalSymlinks(filename) + // init the stopWatchingFunc + watchingCtx, cancel := context.WithCancel(context.Background()) + v.stopWatchingFunc = cancel + eventsWG := sync.WaitGroup{} eventsWG.Add(1) go func() { @@ -494,6 +500,9 @@ func (v *Viper) WatchConfig() { } eventsWG.Done() return + case <-watchingCtx.Done(): // StopWatching function called + eventsWG.Done() + return } } }() @@ -504,6 +513,16 @@ func (v *Viper) WatchConfig() { initWG.Wait() // make sure that the go routine above fully ended before returning } +// StopWatching stop watching a config file for changes. +func StopWatching() { v.StopWatching() } + +// StopWatching stop watching a config file for changes. +func (v *Viper) StopWatching() { + if v.stopWatchingFunc != nil { + v.stopWatchingFunc() + } +} + // SetConfigFile explicitly defines the path, name and extension of the config file. // Viper will use this and not check any of the config paths. func SetConfigFile(in string) { v.SetConfigFile(in) } diff --git a/viper_test.go b/viper_test.go index f5ac1be..8bd5eb1 100644 --- a/viper_test.go +++ b/viper_test.go @@ -2489,6 +2489,46 @@ func TestWatchFile(t *testing.T) { }) } +func TestStopWatching(t *testing.T) { + t.Run( + "file content changed after stop watching", func(t *testing.T) { + // given a `config.yaml` file being watched + v, configFile, cleanup := newViperWithConfigFile(t) + defer cleanup() + _, err := os.Stat(configFile) + require.NoError(t, err) + t.Logf("test config file: %s\n", configFile) + + v.WatchConfig() + v.StopWatching() + + // overwriting the file after StopWatching called + err = ioutil.WriteFile(configFile, []byte("foo: baz\n"), 0o640) + time.Sleep(time.Second) // wait for file changed event + // then the config value should not be changed + require.Nil(t, err) + assert.Equal(t, "bar", v.Get("foo")) + + // watch again + wg := sync.WaitGroup{} + wg.Add(1) + var wgDoneOnce sync.Once // OnConfigChange is called twice on Windows + v.OnConfigChange( + func(in fsnotify.Event) { + t.Logf("config file changed again") + wgDoneOnce.Do(func() { wg.Done() }) + }, + ) + v.WatchConfig() + // overwriting the file after StopWatching and Watch again + err = ioutil.WriteFile(configFile, []byte("foo: qux\n"), 0o640) + wg.Wait() + require.Nil(t, err) + assert.Equal(t, "qux", v.Get("foo")) + }, + ) +} + func TestUnmarshal_DotSeparatorBackwardCompatibility(t *testing.T) { flags := pflag.NewFlagSet("test", pflag.ContinueOnError) flags.String("foo.bar", "cobra_flag", "")