diff --git a/bash_completions.go b/bash_completions.go index ab428ccb..846636d7 100644 --- a/bash_completions.go +++ b/bash_completions.go @@ -495,12 +495,14 @@ func writeFlag(buf *bytes.Buffer, flag *pflag.Flag, cmd *Command) { func writeLocalNonPersistentFlag(buf *bytes.Buffer, flag *pflag.Flag) { name := flag.Name - format := " local_nonpersistent_flags+=(\"--%s" + format := " local_nonpersistent_flags+=(\"--%[1]s\")\n" if len(flag.NoOptDefVal) == 0 { - format += "=" + format += " local_nonpersistent_flags+=(\"--%[1]s=\")\n" } - format += "\")\n" buf.WriteString(fmt.Sprintf(format, name)) + if len(flag.Shorthand) > 0 { + buf.WriteString(fmt.Sprintf(" local_nonpersistent_flags+=(\"-%s\")\n", flag.Shorthand)) + } } // Setup annotations for go completions for registered flags @@ -535,7 +537,9 @@ func writeFlags(buf *bytes.Buffer, cmd *Command) { if len(flag.Shorthand) > 0 { writeShortFlag(buf, flag, cmd) } - if localNonPersistentFlags.Lookup(flag.Name) != nil { + // localNonPersistentFlags are used to stop the completion of subcommands when one is set + // if TraverseChildren is true we should allow to complete subcommands + if localNonPersistentFlags.Lookup(flag.Name) != nil && !cmd.Root().TraverseChildren { writeLocalNonPersistentFlag(buf, flag) } }) diff --git a/bash_completions_test.go b/bash_completions_test.go index eefa3de0..2c182ba7 100644 --- a/bash_completions_test.go +++ b/bash_completions_test.go @@ -193,6 +193,13 @@ func TestBashCompletions(t *testing.T) { checkOmit(t, output, `two_word_flags+=("--two-w-default")`) checkOmit(t, output, `two_word_flags+=("-T")`) + // check local nonpersistent flag + check(t, output, `local_nonpersistent_flags+=("--two")`) + check(t, output, `local_nonpersistent_flags+=("--two=")`) + check(t, output, `local_nonpersistent_flags+=("-t")`) + check(t, output, `local_nonpersistent_flags+=("--two-w-default")`) + check(t, output, `local_nonpersistent_flags+=("-T")`) + checkOmit(t, output, deprecatedCmd.Name()) // If available, run shellcheck against the script. @@ -235,3 +242,21 @@ func TestBashCompletionDeprecatedFlag(t *testing.T) { t.Errorf("expected completion to not include %q flag: Got %v", flagName, output) } } + +func TestBashCompletionTraverseChildren(t *testing.T) { + c := &Command{Use: "c", Run: emptyRun, TraverseChildren: true} + + c.Flags().StringP("string-flag", "s", "", "string flag") + c.Flags().BoolP("bool-flag", "b", false, "bool flag") + + buf := new(bytes.Buffer) + c.GenBashCompletion(buf) + output := buf.String() + + // check that local nonpersistent flag are not set since we have TraverseChildren set to true + checkOmit(t, output, `local_nonpersistent_flags+=("--string-flag")`) + checkOmit(t, output, `local_nonpersistent_flags+=("--string-flag=")`) + checkOmit(t, output, `local_nonpersistent_flags+=("-s")`) + checkOmit(t, output, `local_nonpersistent_flags+=("--bool-flag")`) + checkOmit(t, output, `local_nonpersistent_flags+=("-b")`) +} diff --git a/custom_completions.go b/custom_completions.go index c25c03e4..fc852908 100644 --- a/custom_completions.go +++ b/custom_completions.go @@ -175,8 +175,16 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi toComplete := args[len(args)-1] trimmedArgs := args[:len(args)-1] + var finalCmd *Command + var finalArgs []string + var err error // Find the real command for which completion must be performed - finalCmd, finalArgs, err := c.Root().Find(trimmedArgs) + // check if we need to traverse here to parse local flags on parent commands + if c.Root().TraverseChildren { + finalCmd, finalArgs, err = c.Root().Traverse(trimmedArgs) + } else { + finalCmd, finalArgs, err = c.Root().Find(trimmedArgs) + } if err != nil { // Unable to find the real command. E.g., someInvalidCmd return c, []string{}, ShellCompDirectiveDefault, fmt.Errorf("Unable to find a command for arguments: %v", trimmedArgs) @@ -271,20 +279,24 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi var completions []string directive := ShellCompDirectiveDefault if flag == nil { - // Check if there are any local, non-persistent flags on the command-line foundLocalNonPersistentFlag := false - localNonPersistentFlags := finalCmd.LocalNonPersistentFlags() - finalCmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) { - if localNonPersistentFlags.Lookup(flag.Name) != nil && flag.Changed { - foundLocalNonPersistentFlag = true - } - }) + // If TraverseChildren is true on the root command we don't check for + // local flags because we can use a local flag on a parent command + if !finalCmd.Root().TraverseChildren { + // Check if there are any local, non-persistent flags on the command-line + localNonPersistentFlags := finalCmd.LocalNonPersistentFlags() + finalCmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) { + if localNonPersistentFlags.Lookup(flag.Name) != nil && flag.Changed { + foundLocalNonPersistentFlag = true + } + }) + } // Complete subcommand names, including the help command if len(finalArgs) == 0 && !foundLocalNonPersistentFlag { // We only complete sub-commands if: // - there are no arguments on the command-line and - // - there are no local, non-peristent flag on the command-line + // - there are no local, non-peristent flag on the command-line or TraverseChildren is true for _, subCmd := range finalCmd.Commands() { if subCmd.IsAvailableCommand() || subCmd == finalCmd.helpCommand { if strings.HasPrefix(subCmd.Name(), toComplete) { diff --git a/custom_completions_test.go b/custom_completions_test.go index 276b8a77..fbc6a536 100644 --- a/custom_completions_test.go +++ b/custom_completions_test.go @@ -139,6 +139,8 @@ func TestNoCmdNameCompletionInGo(t *testing.T) { Use: "root", Run: emptyRun, } + rootCmd.Flags().String("localroot", "", "local root flag") + childCmd1 := &Command{ Use: "childCmd1", Short: "First command", @@ -187,6 +189,64 @@ func TestNoCmdNameCompletionInGo(t *testing.T) { t.Errorf("expected: %q, got: %q", expected, output) } + // Test that sub-command names are completed if a local non-persistent flag is present and TraverseChildren is set to true + // set TraverseChildren to true on the root cmd + rootCmd.TraverseChildren = true + + output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--localroot", "value", "") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + // Reset TraverseChildren for next command + rootCmd.TraverseChildren = false + + expected = strings.Join([]string{ + "childCmd1", + "help", + ":4", + "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n") + + if output != expected { + t.Errorf("expected: %q, got: %q", expected, output) + } + + // Test that sub-command names from a child cmd are completed if a local non-persistent flag is present + // and TraverseChildren is set to true on the root cmd + rootCmd.TraverseChildren = true + + output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--localroot", "value", "childCmd1", "--nonPersistent", "value", "") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + // Reset TraverseChildren for next command + rootCmd.TraverseChildren = false + // Reset the flag for the next command + nonPersistentFlag.Changed = false + + expected = strings.Join([]string{ + "childCmd2", + ":4", + "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n") + + if output != expected { + t.Errorf("expected: %q, got: %q", expected, output) + } + + // Test that we don't use Traverse when we shouldn't. + // This command should not return a completion since the command line is invalid without TraverseChildren. + output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--localroot", "value", "childCmd1", "") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + expected = strings.Join([]string{ + ":0", + "Completion ended with directive: ShellCompDirectiveDefault", ""}, "\n") + + if output != expected { + t.Errorf("expected: %q, got: %q", expected, output) + } + // Test that sub-command names are not completed if a local non-persistent short flag is present output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "childCmd1", "-n", "value", "") if err != nil {