diff --git a/completions.go b/completions.go index 8fccdaf..0862d3f 100644 --- a/completions.go +++ b/completions.go @@ -270,6 +270,14 @@ func (c *Command) initCompleteCmd(args []string) { } } +// SliceValue is a reduced version of [pflag.SliceValue]. It is used to detect +// flags that accept multiple values and therefore can provide completion +// multiple times. +type SliceValue interface { + // GetSlice returns the flag value list as an array of strings. + GetSlice() []string +} + func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDirective, error) { // The last argument, which is not completely typed by the user, // should not be part of the list of arguments @@ -399,10 +407,13 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi // If we have not found any required flags, only then can we show regular flags if len(completions) == 0 { doCompleteFlags := func(flag *pflag.Flag) { - if !flag.Changed || + _, acceptsMultiple := flag.Value.(SliceValue) + acceptsMultiple = acceptsMultiple || strings.Contains(flag.Value.Type(), "Slice") || strings.Contains(flag.Value.Type(), "Array") || - strings.HasPrefix(flag.Value.Type(), "stringTo") { + strings.HasPrefix(flag.Value.Type(), "stringTo") + + if !flag.Changed || acceptsMultiple { // If the flag is not already present, or if it can be specified multiple times (Array, Slice, or stringTo) // we suggest it as a completion completions = append(completions, getFlagNameCompletions(flag, toComplete)...) diff --git a/completions_test.go b/completions_test.go index df153fc..a8f378e 100644 --- a/completions_test.go +++ b/completions_test.go @@ -671,6 +671,29 @@ func TestFlagNameCompletionInGoWithDesc(t *testing.T) { } } +// customMultiString is a custom Value type that accepts multiple values, +// but does not include "Slice" or "Array" in its "Type" string. +type customMultiString []string + +var _ SliceValue = (*customMultiString)(nil) + +func (s *customMultiString) String() string { + return fmt.Sprintf("%v", *s) +} + +func (s *customMultiString) Set(v string) error { + *s = append(*s, v) + return nil +} + +func (s *customMultiString) Type() string { + return "multi string" +} + +func (s *customMultiString) GetSlice() []string { + return *s +} + func TestFlagNameCompletionRepeat(t *testing.T) { rootCmd := &Command{ Use: "root", @@ -693,6 +716,8 @@ func TestFlagNameCompletionRepeat(t *testing.T) { sliceFlag := rootCmd.Flags().Lookup("slice") rootCmd.Flags().BoolSliceP("bslice", "b", nil, "bool slice flag") bsliceFlag := rootCmd.Flags().Lookup("bslice") + rootCmd.Flags().VarP(&customMultiString{}, "multi", "m", "multi string flag") + multiFlag := rootCmd.Flags().Lookup("multi") // Test that flag names are not repeated unless they are an array or slice output, err := executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--first", "1", "--") @@ -706,6 +731,7 @@ func TestFlagNameCompletionRepeat(t *testing.T) { "--array", "--bslice", "--help", + "--multi", "--second", "--slice", ":4", @@ -728,6 +754,7 @@ func TestFlagNameCompletionRepeat(t *testing.T) { "--array", "--bslice", "--help", + "--multi", "--slice", ":4", "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n") @@ -737,7 +764,7 @@ func TestFlagNameCompletionRepeat(t *testing.T) { } // Test that flag names are not repeated unless they are an array or slice - output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--slice", "1", "--slice=2", "--array", "val", "--bslice", "true", "--") + output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--slice", "1", "--slice=2", "--array", "val", "--bslice", "true", "--multi", "val", "--") if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -745,12 +772,14 @@ func TestFlagNameCompletionRepeat(t *testing.T) { sliceFlag.Changed = false arrayFlag.Changed = false bsliceFlag.Changed = false + multiFlag.Changed = false expected = strings.Join([]string{ "--array", "--bslice", "--first", "--help", + "--multi", "--second", "--slice", ":4", @@ -768,6 +797,7 @@ func TestFlagNameCompletionRepeat(t *testing.T) { // Reset the flag for the next command sliceFlag.Changed = false arrayFlag.Changed = false + multiFlag.Changed = false expected = strings.Join([]string{ "--array", @@ -778,6 +808,8 @@ func TestFlagNameCompletionRepeat(t *testing.T) { "-f", "--help", "-h", + "--multi", + "-m", "--second", "-s", "--slice", @@ -797,6 +829,7 @@ func TestFlagNameCompletionRepeat(t *testing.T) { // Reset the flag for the next command sliceFlag.Changed = false arrayFlag.Changed = false + multiFlag.Changed = false expected = strings.Join([]string{ "-a",