From 5aadb0d2bdef099b034637bdb9e2685327fc2416 Mon Sep 17 00:00:00 2001 From: maxlandon Date: Thu, 2 Nov 2023 17:43:35 +0100 Subject: [PATCH] Fix a logic lookup bug that was kind enough to surface at the good moment. --- completions.go | 22 ++++++++++++++++++---- completions_test.go | 15 ++++++++++++--- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/completions.go b/completions.go index 52e42662..ba00ad6c 100644 --- a/completions.go +++ b/completions.go @@ -167,13 +167,27 @@ func (c *Command) GetFlagCompletionFunc(flag *pflag.Flag) (func(cmd *Command, ar } // GetFlagCompletionByName returns the completion function for the given flag in the command by name, if available. +// If the flag is not found in the command's local flags, it looks into the persistent flags, which might belong to one of the command's parents. func (c *Command) GetFlagCompletionFuncByName(flagName string) (func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective), bool) { - flag := c.Flags().Lookup(flagName) - if flag == nil { - return nil, false + // Attempt to find it in the local flags. + if flag := c.Flags().Lookup(flagName); flag != nil { + return c.GetFlagCompletionFunc(flag) } - return c.GetFlagCompletionFunc(flag) + // Or try to find it in the "command-specific" persistent flags. + if flag := c.PersistentFlags().Lookup(flagName); flag != nil { + return c.GetFlagCompletionFunc(flag) + } + + // Else, check all persistent flags belonging to one of the parents. + // This ensures that we won't return the completion function of a + // parent's LOCAL flag. + if flag := c.InheritedFlags().Lookup(flagName); flag != nil { + return c.GetFlagCompletionFunc(flag) + } + + // No flag exists either locally, or as one of the parent persistent flags. + return nil, false } // initializeCompletionStorage is (and should be) called in all diff --git a/completions_test.go b/completions_test.go index 69836d46..451a831d 100644 --- a/completions_test.go +++ b/completions_test.go @@ -3581,21 +3581,30 @@ func TestGetFlagCompletion(t *testing.T) { rootCmd := &Command{Use: "root", Run: emptyRun} rootCmd.Flags().String("rootflag", "", "root flag") - _ = rootCmd.RegisterFlagCompletionFunc("rootflag", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) { + err := rootCmd.RegisterFlagCompletionFunc("rootflag", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) { return []string{"rootvalue"}, ShellCompDirectiveKeepOrder }) + if err != nil { + t.Error(err) + } rootCmd.PersistentFlags().String("persistentflag", "", "persistent flag") - _ = rootCmd.RegisterFlagCompletionFunc("persistentflag", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) { + err = rootCmd.RegisterFlagCompletionFunc("persistentflag", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) { return []string{"persistentvalue"}, ShellCompDirectiveDefault }) + if err != nil { + t.Error(err) + } childCmd := &Command{Use: "child", Run: emptyRun} childCmd.Flags().String("childflag", "", "child flag") - _ = childCmd.RegisterFlagCompletionFunc("childflag", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) { + err = childCmd.RegisterFlagCompletionFunc("childflag", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) { return []string{"childvalue"}, ShellCompDirectiveNoFileComp | ShellCompDirectiveNoSpace }) + if err != nil { + t.Error(err) + } rootCmd.AddCommand(childCmd)