From 8effd0f2f4fdd3d0387ae281400c5cfc564fd493 Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Mon, 7 Oct 2024 19:00:19 +0200 Subject: [PATCH] Refactor `proxy.shouldTerminate` function and move the functionality to `Act.Registry` (#615) * Refactor proxy.shouldTerminate function and move the functionality to Act.Registry * Simplify syntax and avoid unnecessary loops (using maps.Keys) * Create RunAll and ShouldTerminate functions in Act.Registry * Add logger to RunAll * Add test for RunAll and ShouldTerminate * Separate the check for existence of 'outputs' key from type check * Add test case for failures * Address comments by @sinadarbouy --- act/registry.go | 67 +++++++++++++++++++++++++++++++++++ act/registry_test.go | 84 ++++++++++++++++++++++++++++++++++++++++++++ network/proxy.go | 37 ++----------------- 3 files changed, 153 insertions(+), 35 deletions(-) diff --git a/act/registry.go b/act/registry.go index c4cab57a..68fe97ab 100644 --- a/act/registry.go +++ b/act/registry.go @@ -12,12 +12,15 @@ import ( "github.com/gatewayd-io/gatewayd/config" gerr "github.com/gatewayd-io/gatewayd/errors" "github.com/rs/zerolog" + "github.com/spf13/cast" ) type IRegistry interface { Add(policy *sdkAct.Policy) Apply(signals []sdkAct.Signal, hook sdkAct.Hook) []*sdkAct.Output Run(output *sdkAct.Output, params ...sdkAct.Parameter) (any, *gerr.GatewayDError) + RunAll(result map[string]any) map[string]any + ShouldTerminate(result map[string]any) bool } // Registry keeps track of all policies and actions. @@ -402,6 +405,70 @@ func runActionWithTimeout( } } +// RunAll run all the actions in the outputs and returns the end result. +func (r *Registry) RunAll(result map[string]any) map[string]any { + if _, exists := result[sdkAct.Outputs]; !exists { + r.Logger.Debug().Msg("Outputs key is not present, returning the result as-is") + return result + } + + var ( + outputs []*sdkAct.Output + ok bool + ) + if outputs, ok = result[sdkAct.Outputs].([]*sdkAct.Output); !ok || len(outputs) == 0 { + r.Logger.Debug().Msg("Outputs are nil or empty, returning the result as-is") + // If the outputs are nil or empty, we should delete the key from the result. + delete(result, sdkAct.Outputs) + return result + } + + endResult := make(map[string]any) + for _, output := range outputs { + if !cast.ToBool(output.Verdict) { + r.Logger.Debug().Msg( + "Skipping the action, because the verdict of the policy execution is false") + continue + } + runResult, err := r.Run(output, WithResult(result), WithLogger(r.Logger)) + // If the action is async and we received a sentinel error, don't log the error. + if err != nil && !errors.Is(err, gerr.ErrAsyncAction) { + r.Logger.Error().Err(err).Msg("Error running policy") + } + // Each action should return a map. + if v, ok := runResult.(map[string]any); ok { + endResult = v + } else { + r.Logger.Debug().Msg("Run result is not a map, skipping merging into end result.") + } + } + return endResult +} + +// ShouldTerminate checks if any of the actions are terminal, indicating that the request +// should be terminated. +// This is an optimization to avoid executing the actions' functions unnecessarily. +// The __terminal__ field is only set when an action intends to terminate the request. +func (r *Registry) ShouldTerminate(result map[string]any) bool { + terminalVal, exists := result[sdkAct.Terminal] + if !exists { + r.Logger.Debug().Msg("Terminal key not found, request will continue.") + return false + } + + shouldTerminate, ok := terminalVal.(bool) + if !ok { + r.Logger.Debug().Msg("Terminal key exists but cannot be cast to a boolean.") + return false + } + + if shouldTerminate { + r.Logger.Debug().Msg("Request is marked as terminal. Terminating.") + } + + return shouldTerminate +} + // WithLogger returns a parameter with the Logger to be used by the action. // This is automatically prepended to the parameters when running an action. func WithLogger(logger zerolog.Logger) sdkAct.Parameter { diff --git a/act/registry_test.go b/act/registry_test.go index d2f54bdb..53665764 100644 --- a/act/registry_test.go +++ b/act/registry_test.go @@ -930,3 +930,87 @@ func Test_Run_Timeout(t *testing.T) { }) } } + +// Test_RunAll_And_ShouldTerminate tests the RunAll function of the act registry +// with a terminal action (and signal). +func Test_RunAll_And_ShouldTerminate(t *testing.T) { + out := bytes.Buffer{} + logger := zerolog.New(&out) + actRegistry := NewActRegistry( + Registry{ + Signals: BuiltinSignals(), + Policies: BuiltinPolicies(), + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) + assert.NotNil(t, actRegistry) + + outputs := actRegistry.Apply([]sdkAct.Signal{ + *sdkAct.Terminate(), + *sdkAct.Log("info", "testing log via Act", map[string]any{"test": true}), + }, sdkAct.Hook{ + Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT", + Priority: 1000, + Params: map[string]any{}, + Result: map[string]any{}, + }) + assert.NotNil(t, outputs) + + // This is what the hook returns along with "request", "response" and other fields. + // These two keys and values should exist in the result after policy execution. + result := map[string]any{ + sdkAct.Outputs: outputs, + sdkAct.Terminal: true, + } + + assert.True(t, actRegistry.ShouldTerminate(result)) + + result = actRegistry.RunAll(result) + + time.Sleep(time.Millisecond) // wait for async action to complete + + assert.NotEmpty(t, result) + // Terminate action does nothing when run. It is just a signal to terminate. + assert.Contains(t, out.String(), + `{"level":"debug","action":"terminate","executionMode":"sync","message":"Running action"}`) + assert.Contains(t, out.String(), + `{"level":"debug","action":"log","executionMode":"async","message":"Running action"}`) + assert.Contains(t, out.String(), `{"level":"info","test":true,"message":"testing log via Act"}`) +} + +// Test_RunAll_Empty_Result tests the RunAll function of the act registry with an empty result. +func Test_RunAll_Empty_Result(t *testing.T) { + out := bytes.Buffer{} + logger := zerolog.New(&out) + actRegistry := NewActRegistry( + Registry{ + Signals: BuiltinSignals(), + Policies: BuiltinPolicies(), + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) + assert.NotNil(t, actRegistry) + + results := []map[string]any{ + {}, + { + sdkAct.Outputs: false, // This is invalid, hence it will be removed. + }, + } + + for _, result := range results { + assert.False(t, actRegistry.ShouldTerminate(result)) + + result = actRegistry.RunAll(result) + + time.Sleep(time.Millisecond) // wait for async action to complete + + assert.Empty(t, result) + } +} diff --git a/network/proxy.go b/network/proxy.go index 50ad47b4..cdf9b71f 100644 --- a/network/proxy.go +++ b/network/proxy.go @@ -6,13 +6,10 @@ import ( "errors" "io" "net" - "slices" "time" - sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act" "github.com/gatewayd-io/gatewayd-plugin-sdk/databases/postgres" v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1" - "github.com/gatewayd-io/gatewayd/act" "github.com/gatewayd-io/gatewayd/config" gerr "github.com/gatewayd-io/gatewayd/errors" "github.com/gatewayd-io/gatewayd/metrics" @@ -21,9 +18,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/go-co-op/gocron" "github.com/rs/zerolog" - "github.com/spf13/cast" "go.opentelemetry.io/otel" - "golang.org/x/exp/maps" ) //nolint:interfacebloat @@ -873,36 +868,8 @@ func (pr *Proxy) shouldTerminate(result map[string]any) (bool, map[string]any) { return false, result } - outputs, ok := result[sdkAct.Outputs].([]*sdkAct.Output) - if !ok { - pr.Logger.Error().Msg("Failed to cast the outputs to the []*act.Output type") - return false, result - } - - // This is a shortcut to avoid running the actions' functions. - // The Terminal field is only present if the action wants to terminate the request, - // that is the `__terminal__` field is set in one of the outputs. - keys := maps.Keys(result) - terminate := slices.Contains(keys, sdkAct.Terminal) && cast.ToBool(result[sdkAct.Terminal]) - actionResult := make(map[string]any) - for _, output := range outputs { - if !cast.ToBool(output.Verdict) { - pr.Logger.Debug().Msg( - "Skipping the action, because the verdict of the policy execution is false") - continue - } - actRes, err := pr.PluginRegistry.ActRegistry.Run( - output, act.WithResult(result)) - // If the action is async and we received a sentinel error, - // don't log the error. - if err != nil && !errors.Is(err, gerr.ErrAsyncAction) { - pr.Logger.Error().Err(err).Msg("Error running policy") - } - // The terminate action should return a map. - if v, ok := actRes.(map[string]any); ok { - actionResult = v - } - } + terminate := pr.PluginRegistry.ActRegistry.ShouldTerminate(result) + actionResult := pr.PluginRegistry.ActRegistry.RunAll(result) if terminate { pr.Logger.Debug().Fields( map[string]any{