Skip to content

Commit

Permalink
Refactor proxy.shouldTerminate function and move the functionality …
Browse files Browse the repository at this point in the history
…to `Act.Registry` (#615)

* Refactor proxy.shouldTerminate function and move the functionality to Act.Registry
* Simplify syntax and avoid unnecessary loops (using maps.Keys)
* Create RunAll and ShouldTerminate functions in Act.Registry
* Add logger to RunAll
* Add test for RunAll and ShouldTerminate
* Separate the check for existence of 'outputs' key from type check
* Add test case for failures
* Address comments by @sinadarbouy
  • Loading branch information
mostafa authored Oct 7, 2024
1 parent e8596c5 commit 8effd0f
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 35 deletions.
67 changes: 67 additions & 0 deletions act/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@ import (
"github.com/gatewayd-io/gatewayd/config"
gerr "github.com/gatewayd-io/gatewayd/errors"
"github.com/rs/zerolog"
"github.com/spf13/cast"
)

type IRegistry interface {
Add(policy *sdkAct.Policy)
Apply(signals []sdkAct.Signal, hook sdkAct.Hook) []*sdkAct.Output
Run(output *sdkAct.Output, params ...sdkAct.Parameter) (any, *gerr.GatewayDError)
RunAll(result map[string]any) map[string]any
ShouldTerminate(result map[string]any) bool
}

// Registry keeps track of all policies and actions.
Expand Down Expand Up @@ -402,6 +405,70 @@ func runActionWithTimeout(
}
}

// RunAll run all the actions in the outputs and returns the end result.
func (r *Registry) RunAll(result map[string]any) map[string]any {
if _, exists := result[sdkAct.Outputs]; !exists {
r.Logger.Debug().Msg("Outputs key is not present, returning the result as-is")
return result
}

var (
outputs []*sdkAct.Output
ok bool
)
if outputs, ok = result[sdkAct.Outputs].([]*sdkAct.Output); !ok || len(outputs) == 0 {
r.Logger.Debug().Msg("Outputs are nil or empty, returning the result as-is")
// If the outputs are nil or empty, we should delete the key from the result.
delete(result, sdkAct.Outputs)
return result
}

endResult := make(map[string]any)
for _, output := range outputs {
if !cast.ToBool(output.Verdict) {
r.Logger.Debug().Msg(
"Skipping the action, because the verdict of the policy execution is false")
continue
}
runResult, err := r.Run(output, WithResult(result), WithLogger(r.Logger))
// If the action is async and we received a sentinel error, don't log the error.
if err != nil && !errors.Is(err, gerr.ErrAsyncAction) {
r.Logger.Error().Err(err).Msg("Error running policy")
}
// Each action should return a map.
if v, ok := runResult.(map[string]any); ok {
endResult = v
} else {
r.Logger.Debug().Msg("Run result is not a map, skipping merging into end result.")
}
}
return endResult
}

// ShouldTerminate checks if any of the actions are terminal, indicating that the request
// should be terminated.
// This is an optimization to avoid executing the actions' functions unnecessarily.
// The __terminal__ field is only set when an action intends to terminate the request.
func (r *Registry) ShouldTerminate(result map[string]any) bool {
terminalVal, exists := result[sdkAct.Terminal]
if !exists {
r.Logger.Debug().Msg("Terminal key not found, request will continue.")
return false
}

shouldTerminate, ok := terminalVal.(bool)
if !ok {
r.Logger.Debug().Msg("Terminal key exists but cannot be cast to a boolean.")
return false
}

if shouldTerminate {
r.Logger.Debug().Msg("Request is marked as terminal. Terminating.")
}

return shouldTerminate
}

// WithLogger returns a parameter with the Logger to be used by the action.
// This is automatically prepended to the parameters when running an action.
func WithLogger(logger zerolog.Logger) sdkAct.Parameter {
Expand Down
84 changes: 84 additions & 0 deletions act/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -930,3 +930,87 @@ func Test_Run_Timeout(t *testing.T) {
})
}
}

// Test_RunAll_And_ShouldTerminate tests the RunAll function of the act registry
// with a terminal action (and signal).
func Test_RunAll_And_ShouldTerminate(t *testing.T) {
out := bytes.Buffer{}
logger := zerolog.New(&out)
actRegistry := NewActRegistry(
Registry{
Signals: BuiltinSignals(),
Policies: BuiltinPolicies(),
Actions: BuiltinActions(),
DefaultPolicyName: config.DefaultPolicy,
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
})
assert.NotNil(t, actRegistry)

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Terminate(),
*sdkAct.Log("info", "testing log via Act", map[string]any{"test": true}),
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)

// This is what the hook returns along with "request", "response" and other fields.
// These two keys and values should exist in the result after policy execution.
result := map[string]any{
sdkAct.Outputs: outputs,
sdkAct.Terminal: true,
}

assert.True(t, actRegistry.ShouldTerminate(result))

result = actRegistry.RunAll(result)

time.Sleep(time.Millisecond) // wait for async action to complete

assert.NotEmpty(t, result)
// Terminate action does nothing when run. It is just a signal to terminate.
assert.Contains(t, out.String(),
`{"level":"debug","action":"terminate","executionMode":"sync","message":"Running action"}`)
assert.Contains(t, out.String(),
`{"level":"debug","action":"log","executionMode":"async","message":"Running action"}`)
assert.Contains(t, out.String(), `{"level":"info","test":true,"message":"testing log via Act"}`)
}

// Test_RunAll_Empty_Result tests the RunAll function of the act registry with an empty result.
func Test_RunAll_Empty_Result(t *testing.T) {
out := bytes.Buffer{}
logger := zerolog.New(&out)
actRegistry := NewActRegistry(
Registry{
Signals: BuiltinSignals(),
Policies: BuiltinPolicies(),
Actions: BuiltinActions(),
DefaultPolicyName: config.DefaultPolicy,
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
})
assert.NotNil(t, actRegistry)

results := []map[string]any{
{},
{
sdkAct.Outputs: false, // This is invalid, hence it will be removed.
},
}

for _, result := range results {
assert.False(t, actRegistry.ShouldTerminate(result))

result = actRegistry.RunAll(result)

time.Sleep(time.Millisecond) // wait for async action to complete

assert.Empty(t, result)
}
}
37 changes: 2 additions & 35 deletions network/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,10 @@ import (
"errors"
"io"
"net"
"slices"
"time"

sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act"
"github.com/gatewayd-io/gatewayd-plugin-sdk/databases/postgres"
v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1"
"github.com/gatewayd-io/gatewayd/act"
"github.com/gatewayd-io/gatewayd/config"
gerr "github.com/gatewayd-io/gatewayd/errors"
"github.com/gatewayd-io/gatewayd/metrics"
Expand All @@ -21,9 +18,7 @@ import (
"github.com/getsentry/sentry-go"
"github.com/go-co-op/gocron"
"github.com/rs/zerolog"
"github.com/spf13/cast"
"go.opentelemetry.io/otel"
"golang.org/x/exp/maps"
)

//nolint:interfacebloat
Expand Down Expand Up @@ -873,36 +868,8 @@ func (pr *Proxy) shouldTerminate(result map[string]any) (bool, map[string]any) {
return false, result
}

outputs, ok := result[sdkAct.Outputs].([]*sdkAct.Output)
if !ok {
pr.Logger.Error().Msg("Failed to cast the outputs to the []*act.Output type")
return false, result
}

// This is a shortcut to avoid running the actions' functions.
// The Terminal field is only present if the action wants to terminate the request,
// that is the `__terminal__` field is set in one of the outputs.
keys := maps.Keys(result)
terminate := slices.Contains(keys, sdkAct.Terminal) && cast.ToBool(result[sdkAct.Terminal])
actionResult := make(map[string]any)
for _, output := range outputs {
if !cast.ToBool(output.Verdict) {
pr.Logger.Debug().Msg(
"Skipping the action, because the verdict of the policy execution is false")
continue
}
actRes, err := pr.PluginRegistry.ActRegistry.Run(
output, act.WithResult(result))
// If the action is async and we received a sentinel error,
// don't log the error.
if err != nil && !errors.Is(err, gerr.ErrAsyncAction) {
pr.Logger.Error().Err(err).Msg("Error running policy")
}
// The terminate action should return a map.
if v, ok := actRes.(map[string]any); ok {
actionResult = v
}
}
terminate := pr.PluginRegistry.ActRegistry.ShouldTerminate(result)
actionResult := pr.PluginRegistry.ActRegistry.RunAll(result)
if terminate {
pr.Logger.Debug().Fields(
map[string]any{
Expand Down

0 comments on commit 8effd0f

Please sign in to comment.