feat: add a new API StopWatching() to stop the groutine created by WatchConfig() manually

This commit is contained in:
chowyi 2023-07-31 19:47:29 +08:00
parent e40c5633a5
commit 59e479f6b8
2 changed files with 61 additions and 9 deletions

View file

@ -217,6 +217,7 @@ type Viper struct {
typeByDefValue bool typeByDefValue bool
onConfigChange func(fsnotify.Event) onConfigChange func(fsnotify.Event)
stopWatchingFunc func()
logger Logger logger Logger
@ -432,13 +433,10 @@ func (v *Viper) OnConfigChange(run func(in fsnotify.Event)) {
} }
// WatchConfig starts watching a config file for changes. // WatchConfig starts watching a config file for changes.
// The function returned for stop watching manually. func WatchConfig() { v.WatchConfig() }
func WatchConfig() func() { return v.WatchConfig() }
// WatchConfig starts watching a config file for changes. // WatchConfig starts watching a config file for changes.
// The function returned for stop watching manually. func (v *Viper) WatchConfig() {
func (v *Viper) WatchConfig() func() {
ctx, cancel := context.WithCancel(context.Background())
initWG := sync.WaitGroup{} initWG := sync.WaitGroup{}
initWG.Add(1) initWG.Add(1)
go func() { go func() {
@ -460,6 +458,10 @@ func (v *Viper) WatchConfig() func() {
configDir, _ := filepath.Split(configFile) configDir, _ := filepath.Split(configFile)
realConfigFile, _ := filepath.EvalSymlinks(filename) realConfigFile, _ := filepath.EvalSymlinks(filename)
// init the stopWatchingFunc
watchingCtx, cancel := context.WithCancel(context.Background())
v.stopWatchingFunc = cancel
eventsWG := sync.WaitGroup{} eventsWG := sync.WaitGroup{}
eventsWG.Add(1) eventsWG.Add(1)
go func() { go func() {
@ -496,8 +498,9 @@ func (v *Viper) WatchConfig() func() {
} }
eventsWG.Done() eventsWG.Done()
return return
case <-ctx.Done(): // cancel function called case <-watchingCtx.Done(): // StopWatching function called
watcher.Close() eventsWG.Done()
return
} }
} }
}() }()
@ -506,7 +509,16 @@ func (v *Viper) WatchConfig() func() {
eventsWG.Wait() // now, wait for event loop to end in this go-routine... eventsWG.Wait() // now, wait for event loop to end in this go-routine...
}() }()
initWG.Wait() // make sure that the go routine above fully ended before returning initWG.Wait() // make sure that the go routine above fully ended before returning
return cancel }
// StopWatching stop watching a config file for changes.
func StopWatching() { v.StopWatching() }
// StopWatching stop watching a config file for changes.
func (v *Viper) StopWatching() {
if v.stopWatchingFunc != nil {
v.stopWatchingFunc()
}
} }
// SetConfigFile explicitly defines the path, name and extension of the config file. // SetConfigFile explicitly defines the path, name and extension of the config file.

View file

@ -2545,6 +2545,46 @@ func TestWatchFile(t *testing.T) {
}) })
} }
func TestStopWatching(t *testing.T) {
t.Run(
"file content changed after stop watching", func(t *testing.T) {
// given a `config.yaml` file being watched
v, configFile, cleanup := newViperWithConfigFile(t)
defer cleanup()
_, err := os.Stat(configFile)
require.NoError(t, err)
t.Logf("test config file: %s\n", configFile)
v.WatchConfig()
v.StopWatching()
// overwriting the file after StopWatching called
err = ioutil.WriteFile(configFile, []byte("foo: baz\n"), 0o640)
time.Sleep(time.Second) // wait for file changed event
// then the config value should not be changed
require.Nil(t, err)
assert.Equal(t, "bar", v.Get("foo"))
// watch again
wg := sync.WaitGroup{}
wg.Add(1)
var wgDoneOnce sync.Once // OnConfigChange is called twice on Windows
v.OnConfigChange(
func(in fsnotify.Event) {
t.Logf("config file changed again")
wgDoneOnce.Do(func() { wg.Done() })
},
)
v.WatchConfig()
// overwriting the file after StopWatching and Watch again
err = ioutil.WriteFile(configFile, []byte("foo: qux\n"), 0o640)
wg.Wait()
require.Nil(t, err)
assert.Equal(t, "qux", v.Get("foo"))
},
)
}
func TestUnmarshal_DotSeparatorBackwardCompatibility(t *testing.T) { func TestUnmarshal_DotSeparatorBackwardCompatibility(t *testing.T) {
flags := pflag.NewFlagSet("test", pflag.ContinueOnError) flags := pflag.NewFlagSet("test", pflag.ContinueOnError)
flags.String("foo.bar", "cobra_flag", "") flags.String("foo.bar", "cobra_flag", "")