From 3f35e13c2268307ddbb296fc0332250e1de2ffca Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Sun, 6 Oct 2024 23:27:05 +0200 Subject: [PATCH] Refactor proxy.shouldTerminate function and move the functionality to Act.Registry Simplify syntax and avoid unnecessary loops (using maps.Keys) Create RunAll and ShouldTerminate functions in Act.Registry --- act/registry.go | 40 ++++++++++++++++++++++++++++++++++++++++ network/proxy.go | 37 ++----------------------------------- 2 files changed, 42 insertions(+), 35 deletions(-) diff --git a/act/registry.go b/act/registry.go index c4cab57a..2e2e9cf1 100644 --- a/act/registry.go +++ b/act/registry.go @@ -12,12 +12,15 @@ import ( "github.com/gatewayd-io/gatewayd/config" gerr "github.com/gatewayd-io/gatewayd/errors" "github.com/rs/zerolog" + "github.com/spf13/cast" ) type IRegistry interface { Add(policy *sdkAct.Policy) Apply(signals []sdkAct.Signal, hook sdkAct.Hook) []*sdkAct.Output Run(output *sdkAct.Output, params ...sdkAct.Parameter) (any, *gerr.GatewayDError) + RunAll(result map[string]any) map[string]any + ShouldTerminate(result map[string]any) bool } // Registry keeps track of all policies and actions. @@ -402,6 +405,43 @@ func runActionWithTimeout( } } +// RunAll run all the actions in the outputs and returns the end result. +func (r *Registry) RunAll(result map[string]any) map[string]any { + outputs, ok := result[sdkAct.Outputs].([]*sdkAct.Output) + if !ok { + r.Logger.Error().Msg("Failed to cast the outputs to the []*act.Output type") + return nil + } + + endResult := make(map[string]any) + for _, output := range outputs { + if !cast.ToBool(output.Verdict) { + r.Logger.Debug().Msg( + "Skipping the action, because the verdict of the policy execution is false") + continue + } + runResult, err := r.Run(output, WithResult(result)) + // If the action is async and we received a sentinel error, don't log the error. + if err != nil && !errors.Is(err, gerr.ErrAsyncAction) { + r.Logger.Error().Err(err).Msg("Error running policy") + } + // Each action should return a map. + if v, ok := runResult.(map[string]any); ok { + endResult = v + } + } + return endResult +} + +// ShouldTerminate checks if any of the actions are terminal, indicating that the request +// should be terminated. +// This is an optimization to avoid executing the actions' functions unnecessarily. +// The __terminal__ field is only set when an action intends to terminate the request. +func (r *Registry) ShouldTerminate(result map[string]any) bool { + _, ok := result[sdkAct.Terminal] + return ok && cast.ToBool(result[sdkAct.Terminal]) +} + // WithLogger returns a parameter with the Logger to be used by the action. // This is automatically prepended to the parameters when running an action. func WithLogger(logger zerolog.Logger) sdkAct.Parameter { diff --git a/network/proxy.go b/network/proxy.go index 50ad47b4..cdf9b71f 100644 --- a/network/proxy.go +++ b/network/proxy.go @@ -6,13 +6,10 @@ import ( "errors" "io" "net" - "slices" "time" - sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act" "github.com/gatewayd-io/gatewayd-plugin-sdk/databases/postgres" v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1" - "github.com/gatewayd-io/gatewayd/act" "github.com/gatewayd-io/gatewayd/config" gerr "github.com/gatewayd-io/gatewayd/errors" "github.com/gatewayd-io/gatewayd/metrics" @@ -21,9 +18,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/go-co-op/gocron" "github.com/rs/zerolog" - "github.com/spf13/cast" "go.opentelemetry.io/otel" - "golang.org/x/exp/maps" ) //nolint:interfacebloat @@ -873,36 +868,8 @@ func (pr *Proxy) shouldTerminate(result map[string]any) (bool, map[string]any) { return false, result } - outputs, ok := result[sdkAct.Outputs].([]*sdkAct.Output) - if !ok { - pr.Logger.Error().Msg("Failed to cast the outputs to the []*act.Output type") - return false, result - } - - // This is a shortcut to avoid running the actions' functions. - // The Terminal field is only present if the action wants to terminate the request, - // that is the `__terminal__` field is set in one of the outputs. - keys := maps.Keys(result) - terminate := slices.Contains(keys, sdkAct.Terminal) && cast.ToBool(result[sdkAct.Terminal]) - actionResult := make(map[string]any) - for _, output := range outputs { - if !cast.ToBool(output.Verdict) { - pr.Logger.Debug().Msg( - "Skipping the action, because the verdict of the policy execution is false") - continue - } - actRes, err := pr.PluginRegistry.ActRegistry.Run( - output, act.WithResult(result)) - // If the action is async and we received a sentinel error, - // don't log the error. - if err != nil && !errors.Is(err, gerr.ErrAsyncAction) { - pr.Logger.Error().Err(err).Msg("Error running policy") - } - // The terminate action should return a map. - if v, ok := actRes.(map[string]any); ok { - actionResult = v - } - } + terminate := pr.PluginRegistry.ActRegistry.ShouldTerminate(result) + actionResult := pr.PluginRegistry.ActRegistry.RunAll(result) if terminate { pr.Logger.Debug().Fields( map[string]any{