diff --git a/README.md b/README.md index 87de501..a3d8511 100644 --- a/README.md +++ b/README.md @@ -307,8 +307,8 @@ func main() { ## Hooks: BeforeReset(), BeforeResolve(), BeforeApply(), AfterApply() and the Bind() option -If a node in the grammar has a `BeforeReset(...)`, `BeforeResolve -(...)`, `BeforeApply(...) error` and/or `AfterApply(...) error` method, those +If a node in the CLI, or any of its embedded fields, has a `BeforeReset(...) error`, `BeforeResolve +(...) error`, `BeforeApply(...) error` and/or `AfterApply(...) error` method, those methods will be called before values are reset, before validation/assignment, and after validation/assignment, respectively. @@ -341,40 +341,6 @@ func main() { } ``` -Another example of using hooks is load the env-file: - -```go -package main - -import ( - "fmt" - "github.com/alecthomas/kong" - "github.com/joho/godotenv" -) - -type EnvFlag string - -// BeforeResolve loads env file. -func (c EnvFlag) BeforeReset(ctx *kong.Context, trace *kong.Path) error { - path := string(ctx.FlagValue(trace.Flag).(EnvFlag)) // nolint - path = kong.ExpandPath(path) - if err := godotenv.Load(path); err != nil { - return err - } - return nil -} - -var CLI struct { - EnvFile EnvFlag - Flag `env:"FLAG"` -} - -func main() { - _ = kong.Parse(&CLI) - fmt.Println(CLI.Flag) -} -``` - ## Flags Any [mapped](#mapper---customising-how-the-command-line-is-mapped-to-go-values) field in the command structure _not_ tagged with `cmd` or `arg` will be a flag. Flags are optional by default. diff --git a/callbacks.go b/callbacks.go index 1df975d..9733e91 100644 --- a/callbacks.go +++ b/callbacks.go @@ -68,6 +68,33 @@ func getMethod(value reflect.Value, name string) reflect.Value { return method } +// Get methods from the given value and any embedded fields. +func getMethods(value reflect.Value, name string) []reflect.Value { + // Collect all possible receivers + receivers := []reflect.Value{value} + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if value.Kind() == reflect.Struct { + t := value.Type() + for i := 0; i < value.NumField(); i++ { + field := value.Field(i) + fieldType := t.Field(i) + if fieldType.IsExported() && fieldType.Anonymous { + receivers = append(receivers, field) + } + } + } + // Search all receivers for methods + var methods []reflect.Value + for _, receiver := range receivers { + if method := getMethod(receiver, name); method.IsValid() { + methods = append(methods, method) + } + } + return methods +} + func callFunction(f reflect.Value, bindings bindings) error { if f.Kind() != reflect.Func { return fmt.Errorf("expected function, got %s", f.Type()) diff --git a/kong.go b/kong.go index 764a994..4f3bea2 100644 --- a/kong.go +++ b/kong.go @@ -361,16 +361,14 @@ func (k *Kong) applyHook(ctx *Context, name string) error { default: panic("unsupported Path") } - method := getMethod(value, name) - if !method.IsValid() { - continue - } - binds := k.bindings.clone() - binds.add(ctx, trace) - binds.add(trace.Node().Vars().CloneWith(k.vars)) - binds.merge(ctx.bindings) - if err := callFunction(method, binds); err != nil { - return err + for _, method := range getMethods(value, name) { + binds := k.bindings.clone() + binds.add(ctx, trace) + binds.add(trace.Node().Vars().CloneWith(k.vars)) + binds.merge(ctx.bindings) + if err := callFunction(method, binds); err != nil { + return err + } } } // Path[0] will always be the app root. @@ -392,13 +390,11 @@ func (k *Kong) applyHookToDefaultFlags(ctx *Context, node *Node, name string) er if !flag.HasDefault || ctx.values[flag.Value].IsValid() || !flag.Target.IsValid() { continue } - method := getMethod(flag.Target, name) - if !method.IsValid() { - continue - } - path := &Path{Flag: flag} - if err := callFunction(method, binds.clone().add(path)); err != nil { - return next(err) + for _, method := range getMethods(flag.Target, name) { + path := &Path{Flag: flag} + if err := callFunction(method, binds.clone().add(path)); err != nil { + return next(err) + } } } return next(nil) diff --git a/kong_test.go b/kong_test.go index 2b52758..cd3fd66 100644 --- a/kong_test.go +++ b/kong_test.go @@ -2406,3 +2406,36 @@ func TestProviderMethods(t *testing.T) { err = kctx.Run(t) assert.NoError(t, err) } + +type EmbeddedCallback struct { + Embedded bool +} + +func (e *EmbeddedCallback) AfterApply() error { + e.Embedded = true + return nil +} + +type EmbeddedRoot struct { + EmbeddedCallback + Root bool +} + +func (e *EmbeddedRoot) AfterApply() error { + e.Root = true + return nil +} + +func TestEmbeddedCallbacks(t *testing.T) { + actual := &EmbeddedRoot{} + k := mustNew(t, actual) + _, err := k.Parse(nil) + assert.NoError(t, err) + expected := &EmbeddedRoot{ + EmbeddedCallback: EmbeddedCallback{ + Embedded: true, + }, + Root: true, + } + assert.Equal(t, expected, actual) +}