Improve API to get flag completion function (#2063)

The new API is simpler and matches the `c.RegisterFlagCompletionFunc()`
API.  By removing the global function `GetFlagCompletion()` we are more
future proof if we ever move from a global map of flag completion
functions to something associated with the command.

The commit also makes this API work with persistent flags by using
`c.Flag(flagName)` instead of `c.Flags().Lookup(flagName)`.

The commit also adds unit tests.

Signed-off-by: Marc Khouzam <marc.khouzam@gmail.com>
This commit is contained in:
Marc Khouzam 2023-11-02 11:23:08 -04:00 committed by GitHub
parent 890302a35f
commit a0a6ae020b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 97 additions and 12 deletions

View file

@ -145,8 +145,13 @@ func (c *Command) RegisterFlagCompletionFunc(flagName string, f func(cmd *Comman
return nil return nil
} }
// GetFlagCompletion returns the completion function for the given flag, if available. // GetFlagCompletionFunc returns the completion function for the given flag of the command, if available.
func GetFlagCompletion(flag *pflag.Flag) (func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective), bool) { func (c *Command) GetFlagCompletionFunc(flagName string) (func(*Command, []string, string) ([]string, ShellCompDirective), bool) {
flag := c.Flag(flagName)
if flag == nil {
return nil, false
}
flagCompletionMutex.RLock() flagCompletionMutex.RLock()
defer flagCompletionMutex.RUnlock() defer flagCompletionMutex.RUnlock()
@ -154,16 +159,6 @@ func GetFlagCompletion(flag *pflag.Flag) (func(cmd *Command, args []string, toCo
return completionFunc, exists return completionFunc, exists
} }
// GetFlagCompletionByName returns the completion function for the given flag in the command by name, if available.
func (c *Command) GetFlagCompletionByName(flagName string) (func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective), bool) {
flag := c.Flags().Lookup(flagName)
if flag == nil {
return nil, false
}
return GetFlagCompletion(flag)
}
// Returns a string listing the different directive enabled in the specified parameter // Returns a string listing the different directive enabled in the specified parameter
func (d ShellCompDirective) string() string { func (d ShellCompDirective) string() string {
var directives []string var directives []string

View file

@ -3427,3 +3427,93 @@ Completion ended with directive: ShellCompDirectiveNoFileComp
}) })
} }
} }
func TestGetFlagCompletion(t *testing.T) {
rootCmd := &Command{Use: "root", Run: emptyRun}
rootCmd.Flags().String("rootflag", "", "root flag")
_ = rootCmd.RegisterFlagCompletionFunc("rootflag", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) {
return []string{"rootvalue"}, ShellCompDirectiveKeepOrder
})
rootCmd.PersistentFlags().String("persistentflag", "", "persistent flag")
_ = rootCmd.RegisterFlagCompletionFunc("persistentflag", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) {
return []string{"persistentvalue"}, ShellCompDirectiveDefault
})
childCmd := &Command{Use: "child", Run: emptyRun}
childCmd.Flags().String("childflag", "", "child flag")
_ = childCmd.RegisterFlagCompletionFunc("childflag", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) {
return []string{"childvalue"}, ShellCompDirectiveNoFileComp | ShellCompDirectiveNoSpace
})
rootCmd.AddCommand(childCmd)
testcases := []struct {
desc string
cmd *Command
flagName string
exists bool
comps []string
directive ShellCompDirective
}{
{
desc: "get flag completion function for command",
cmd: rootCmd,
flagName: "rootflag",
exists: true,
comps: []string{"rootvalue"},
directive: ShellCompDirectiveKeepOrder,
},
{
desc: "get persistent flag completion function for command",
cmd: rootCmd,
flagName: "persistentflag",
exists: true,
comps: []string{"persistentvalue"},
directive: ShellCompDirectiveDefault,
},
{
desc: "get flag completion function for child command",
cmd: childCmd,
flagName: "childflag",
exists: true,
comps: []string{"childvalue"},
directive: ShellCompDirectiveNoFileComp | ShellCompDirectiveNoSpace,
},
{
desc: "get persistent flag completion function for child command",
cmd: childCmd,
flagName: "persistentflag",
exists: true,
comps: []string{"persistentvalue"},
directive: ShellCompDirectiveDefault,
},
{
desc: "cannot get flag completion function for local parent flag",
cmd: childCmd,
flagName: "rootflag",
exists: false,
},
}
for _, tc := range testcases {
t.Run(tc.desc, func(t *testing.T) {
compFunc, exists := tc.cmd.GetFlagCompletionFunc(tc.flagName)
if tc.exists != exists {
t.Errorf("Unexpected result looking for flag completion function")
}
if exists {
comps, directive := compFunc(tc.cmd, []string{}, "")
if strings.Join(tc.comps, " ") != strings.Join(comps, " ") {
t.Errorf("Unexpected completions %q", comps)
}
if tc.directive != directive {
t.Errorf("Unexpected directive %q", directive)
}
}
})
}
}