diff --git a/bash_completions.go b/bash_completions.go index 091aaff1..db350912 100644 --- a/bash_completions.go +++ b/bash_completions.go @@ -534,6 +534,7 @@ func writeLocalNonPersistentFlag(buf io.StringWriter, flag *pflag.Flag) { // prepareCustomAnnotationsForFlags setup annotations for go completions for registered flags func prepareCustomAnnotationsForFlags(cmd *Command) { + cmd.initializeCompletionStorage() cmd.flagCompletionMutex.RLock() defer cmd.flagCompletionMutex.RUnlock() for flag := range cmd.flagCompletionFunctions { diff --git a/completions.go b/completions.go index 27110e07..c8b6094d 100644 --- a/completions.go +++ b/completions.go @@ -18,8 +18,10 @@ import ( "fmt" "os" "strings" + "sync" "github.com/spf13/pflag" + flag "github.com/spf13/pflag" ) const ( @@ -134,9 +136,13 @@ func (c *Command) RegisterFlagCompletionFunc(flagName string, f func(cmd *Comman if flag == nil { return fmt.Errorf("RegisterFlagCompletionFunc: flag '%s' does not exist", flagName) } + // Ensure none of our relevant fields are nil. + c.initializeCompletionStorage() + c.flagCompletionMutex.Lock() defer c.flagCompletionMutex.Unlock() + // And attempt to bind the completion. if _, exists := c.flagCompletionFunctions[flag]; exists { return fmt.Errorf("RegisterFlagCompletionFunc: flag '%s' already registered", flagName) } @@ -144,6 +150,20 @@ func (c *Command) RegisterFlagCompletionFunc(flagName string, f func(cmd *Comman return nil } +// initializeCompletionStorage is (and should be) called in all +// functions that make use of the command's flag completion functions. +func (c *Command) initializeCompletionStorage() { + if c.flagCompletionMutex == nil { + c.flagCompletionMutex = new(sync.RWMutex) + } + + var completionFn func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) + + if c.flagCompletionFunctions == nil { + c.flagCompletionFunctions = make(map[*flag.Flag]completionFn, 0) + } +} + // Returns a string listing the different directive enabled in the specified parameter func (d ShellCompDirective) string() string { var directives []string @@ -478,6 +498,8 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi // Find the completion function for the flag or command var completionFn func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) if flag != nil && flagCompletion { + c.initializeCompletionStorage() + finalCmd.flagCompletionMutex.RLock() completionFn = finalCmd.flagCompletionFunctions[flag] finalCmd.flagCompletionMutex.RUnlock()