diff --git a/args.go b/args.go index ed1e70c..92e1dde 100644 --- a/args.go +++ b/args.go @@ -33,7 +33,7 @@ func legacyArgs(cmd *Command, args []string) error { // root command with subcommands, do subcommand checking. if !cmd.HasParent() && len(args) > 0 { - return fmt.Errorf("unknown command %q for %q%s", args[0], cmd.CommandPath(), cmd.findSuggestions(args[0])) + return fmt.Errorf("unknown command %q for %q%s", args[0], cmd.CommandPath(), cmd.SuggestFunc()(args[0])) } return nil } @@ -58,7 +58,7 @@ func OnlyValidArgs(cmd *Command, args []string) error { } for _, v := range args { if !stringInSlice(v, validArgs) { - return fmt.Errorf("invalid argument %q for %q%s", v, cmd.CommandPath(), cmd.findSuggestions(args[0])) + return fmt.Errorf("invalid argument %q for %q%s", v, cmd.CommandPath(), cmd.SuggestFunc()(args[0])) } } } diff --git a/command.go b/command.go index 1960294..abf3f06 100644 --- a/command.go +++ b/command.go @@ -180,6 +180,8 @@ type Command struct { helpCommand *Command // helpCommandGroupID is the group id for the helpCommand helpCommandGroupID string + // suggestFunc is suggest func defined by the user. + suggestFunc func(string) string // completionCommandGroupID is the group id for the completion command completionCommandGroupID string @@ -340,6 +342,10 @@ func (c *Command) SetHelpCommandGroupID(groupID string) { c.helpCommandGroupID = groupID } +func (c *Command) SetSuggestFunc(f func(string) string) { + c.suggestFunc = f +} + // SetCompletionCommandGroupID sets the group id of the completion command. func (c *Command) SetCompletionCommandGroupID(groupID string) { // completionCommandGroupID is used if no completion command is defined by the user @@ -477,6 +483,18 @@ func (c *Command) Help() error { return nil } +// SuggestFunc returns either the function set by SetSuggestFunc for this command +// or a parent, or it returns a function with default suggestion behavior. +func (c *Command) SuggestFunc() func(string) string { + if c.suggestFunc != nil && !c.DisableSuggestions { + return c.suggestFunc + } + if c.HasParent() { + return c.Parent().SuggestFunc() + } + return c.findSuggestions +} + // UsageString returns usage string. func (c *Command) UsageString() string { // Storing normal writers diff --git a/command_test.go b/command_test.go index 837b6b3..16df3a5 100644 --- a/command_test.go +++ b/command_test.go @@ -1393,6 +1393,48 @@ func TestSuggestions(t *testing.T) { } } +func TestCustomSuggestions(t *testing.T) { + templateWithCustomSuggestions := "Error: unknown command \"%s\" for \"root\"\nSome custom suggestion.\n\nRun 'root --help' for usage.\n" + templateWithDefaultSuggestions := "Error: unknown command \"%s\" for \"root\"\n\nDid you mean this?\n\t%s\n\nRun 'root --help' for usage.\n" + templateWithoutSuggestions := "Error: unknown command \"%s\" for \"root\"\nRun 'root --help' for usage.\n" + + for typo, suggestion := range map[string]string{"time": "times", "timse": "times"} { + for _, suggestionsDisabled := range []bool{true, false} { + for _, setCustomSuggest := range []bool{true, false} { + rootCmd := &Command{Use: "root", Run: emptyRun} + timesCmd := &Command{ + Use: "times", + Run: emptyRun, + } + rootCmd.AddCommand(timesCmd) + + rootCmd.DisableSuggestions = suggestionsDisabled + + if setCustomSuggest { + rootCmd.SetSuggestFunc(func(a string) string { + return "\nSome custom suggestion.\n" + }) + } + + var expected string + if suggestionsDisabled { + expected = fmt.Sprintf(templateWithoutSuggestions, typo) + } else if setCustomSuggest { + expected = fmt.Sprintf(templateWithCustomSuggestions, typo) + } else { + expected = fmt.Sprintf(templateWithDefaultSuggestions, typo, suggestion) + } + + output, _ := executeCommand(rootCmd, typo) + + if output != expected { + t.Errorf("Unexpected response.\nExpected:\n %q\nGot:\n %q\n", expected, output) + } + } + } + } +} + func TestCaseInsensitive(t *testing.T) { rootCmd := &Command{Use: "root", Run: emptyRun} childCmd := &Command{Use: "child", Run: emptyRun, Aliases: []string{"alternative"}}