From dbc168cce19f2aa8bec628049965a61999e0c312 Mon Sep 17 00:00:00 2001 From: Marc Khouzam Date: Sun, 26 Jan 2025 06:33:50 -0500 Subject: [PATCH] Fix missing recursion Signed-off-by: Marc Khouzam --- cobra_test.go | 1 + command.go | 58 ++++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/cobra_test.go b/cobra_test.go index 2ba6b3e..f1c5b0a 100644 --- a/cobra_test.go +++ b/cobra_test.go @@ -18,6 +18,7 @@ import ( "os" "os/exec" "path/filepath" + "runtime" "strings" "testing" "text/template" diff --git a/command.go b/command.go index f753f37..6904bfb 100644 --- a/command.go +++ b/command.go @@ -437,10 +437,7 @@ func (c *Command) UsageFunc() (f func(*Command) error) { } return func(c *Command) error { c.mergePersistentFlags() - fn := defaultUsageFunc - if c.usageTemplate != nil { - fn = c.usageTemplate.fn - } + fn := c.getUsageTemplateFunc() err := fn(c.OutOrStderr(), c) if err != nil { c.PrintErrln(err) @@ -449,6 +446,19 @@ func (c *Command) UsageFunc() (f func(*Command) error) { } } +// getUsageTemplateFunc returns the usage template function for the command +// going up the command tree if necessary. +func (c *Command) getUsageTemplateFunc() func(w io.Writer, data interface{}) error { + if c.usageTemplate != nil { + return c.usageTemplate.fn + } + + if c.HasParent() { + return c.parent.getUsageTemplateFunc() + } + return defaultUsageFunc +} + // Usage puts out the usage for the command. // Used when a user provides invalid input. // Can be defined by user by overriding UsageFunc. @@ -467,10 +477,7 @@ func (c *Command) HelpFunc() func(*Command, []string) { } return func(c *Command, a []string) { c.mergePersistentFlags() - fn := defaultHelpFunc - if c.helpTemplate != nil { - fn = c.helpTemplate.fn - } + fn := c.getHelpTemplateFunc() // The help should be sent to stdout // See https://github.com/spf13/cobra/issues/1002 err := fn(c.OutOrStdout(), c) @@ -480,6 +487,20 @@ func (c *Command) HelpFunc() func(*Command, []string) { } } +// getHelpTemplateFunc returns the help template function for the command +// going up the command tree if necessary. +func (c *Command) getHelpTemplateFunc() func(w io.Writer, data interface{}) error { + if c.helpTemplate != nil { + return c.helpTemplate.fn + } + + if c.HasParent() { + return c.parent.getHelpTemplateFunc() + } + + return defaultHelpFunc +} + // Help puts out the help for the command. // Used when a user calls help [command]. // Can be defined by user by overriding HelpFunc. @@ -554,6 +575,7 @@ func (c *Command) NamePadding() int { } // UsageTemplate returns usage template for the command. +// This function is kept for backwards-compatibility reasons. func (c *Command) UsageTemplate() string { if c.usageTemplate != nil { return c.usageTemplate.tmpl @@ -566,6 +588,7 @@ func (c *Command) UsageTemplate() string { } // HelpTemplate return help template for the command. +// This function is kept for backwards-compatibility reasons. func (c *Command) HelpTemplate() string { if c.helpTemplate != nil { return c.helpTemplate.tmpl @@ -578,6 +601,7 @@ func (c *Command) HelpTemplate() string { } // VersionTemplate return version template for the command. +// This function is kept for backwards-compatibility reasons. func (c *Command) VersionTemplate() string { if c.versionTemplate != nil { return c.versionTemplate.tmpl @@ -589,6 +613,19 @@ func (c *Command) VersionTemplate() string { return defaultVersionTemplate } +// getVersionTemplateFunc returns the version template function for the command +// going up the command tree if necessary. +func (c *Command) getVersionTemplateFunc() func(w io.Writer, data interface{}) error { + if c.versionTemplate != nil { + return c.versionTemplate.fn + } + + if c.HasParent() { + return c.parent.getVersionTemplateFunc() + } + return defaultVersionFunc +} + // ErrPrefix return error message prefix for the command func (c *Command) ErrPrefix() string { if c.errPrefix != "" { @@ -893,10 +930,7 @@ func (c *Command) execute(a []string) (err error) { return err } if versionVal { - fn := defaultVersionFunc - if c.versionTemplate != nil { - fn = c.versionTemplate.fn - } + fn := c.getVersionTemplateFunc() err := fn(c.OutOrStdout(), c) if err != nil { c.Println(err)