Skip to content

Commit

Permalink
Refactor proxy.shouldTerminate function and move the functionality to…
Browse files Browse the repository at this point in the history
… Act.Registry

Simplify syntax and avoid unnecessary loops (using maps.Keys)
Create RunAll and ShouldTerminate functions in Act.Registry
  • Loading branch information
mostafa committed Oct 6, 2024
1 parent e8596c5 commit 3f35e13
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 35 deletions.
40 changes: 40 additions & 0 deletions act/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@ import (
"github.com/gatewayd-io/gatewayd/config"
gerr "github.com/gatewayd-io/gatewayd/errors"
"github.com/rs/zerolog"
"github.com/spf13/cast"
)

type IRegistry interface {
Add(policy *sdkAct.Policy)
Apply(signals []sdkAct.Signal, hook sdkAct.Hook) []*sdkAct.Output
Run(output *sdkAct.Output, params ...sdkAct.Parameter) (any, *gerr.GatewayDError)
RunAll(result map[string]any) map[string]any
ShouldTerminate(result map[string]any) bool
}

// Registry keeps track of all policies and actions.
Expand Down Expand Up @@ -402,6 +405,43 @@ func runActionWithTimeout(
}
}

// RunAll run all the actions in the outputs and returns the end result.
func (r *Registry) RunAll(result map[string]any) map[string]any {
outputs, ok := result[sdkAct.Outputs].([]*sdkAct.Output)
if !ok {
r.Logger.Error().Msg("Failed to cast the outputs to the []*act.Output type")
return nil
}

endResult := make(map[string]any)
for _, output := range outputs {
if !cast.ToBool(output.Verdict) {
r.Logger.Debug().Msg(
"Skipping the action, because the verdict of the policy execution is false")
continue
}
runResult, err := r.Run(output, WithResult(result))
// If the action is async and we received a sentinel error, don't log the error.
if err != nil && !errors.Is(err, gerr.ErrAsyncAction) {
r.Logger.Error().Err(err).Msg("Error running policy")
}
// Each action should return a map.
if v, ok := runResult.(map[string]any); ok {
endResult = v
}
}
return endResult
}

// ShouldTerminate checks if any of the actions are terminal, indicating that the request
// should be terminated.
// This is an optimization to avoid executing the actions' functions unnecessarily.
// The __terminal__ field is only set when an action intends to terminate the request.
func (r *Registry) ShouldTerminate(result map[string]any) bool {
_, ok := result[sdkAct.Terminal]
return ok && cast.ToBool(result[sdkAct.Terminal])
}

// WithLogger returns a parameter with the Logger to be used by the action.
// This is automatically prepended to the parameters when running an action.
func WithLogger(logger zerolog.Logger) sdkAct.Parameter {
Expand Down
37 changes: 2 additions & 35 deletions network/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,10 @@ import (
"errors"
"io"
"net"
"slices"
"time"

sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act"
"github.com/gatewayd-io/gatewayd-plugin-sdk/databases/postgres"
v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1"
"github.com/gatewayd-io/gatewayd/act"
"github.com/gatewayd-io/gatewayd/config"
gerr "github.com/gatewayd-io/gatewayd/errors"
"github.com/gatewayd-io/gatewayd/metrics"
Expand All @@ -21,9 +18,7 @@ import (
"github.com/getsentry/sentry-go"
"github.com/go-co-op/gocron"
"github.com/rs/zerolog"
"github.com/spf13/cast"
"go.opentelemetry.io/otel"
"golang.org/x/exp/maps"
)

//nolint:interfacebloat
Expand Down Expand Up @@ -873,36 +868,8 @@ func (pr *Proxy) shouldTerminate(result map[string]any) (bool, map[string]any) {
return false, result
}

outputs, ok := result[sdkAct.Outputs].([]*sdkAct.Output)
if !ok {
pr.Logger.Error().Msg("Failed to cast the outputs to the []*act.Output type")
return false, result
}

// This is a shortcut to avoid running the actions' functions.
// The Terminal field is only present if the action wants to terminate the request,
// that is the `__terminal__` field is set in one of the outputs.
keys := maps.Keys(result)
terminate := slices.Contains(keys, sdkAct.Terminal) && cast.ToBool(result[sdkAct.Terminal])
actionResult := make(map[string]any)
for _, output := range outputs {
if !cast.ToBool(output.Verdict) {
pr.Logger.Debug().Msg(
"Skipping the action, because the verdict of the policy execution is false")
continue
}
actRes, err := pr.PluginRegistry.ActRegistry.Run(
output, act.WithResult(result))
// If the action is async and we received a sentinel error,
// don't log the error.
if err != nil && !errors.Is(err, gerr.ErrAsyncAction) {
pr.Logger.Error().Err(err).Msg("Error running policy")
}
// The terminate action should return a map.
if v, ok := actRes.(map[string]any); ok {
actionResult = v
}
}
terminate := pr.PluginRegistry.ActRegistry.ShouldTerminate(result)
actionResult := pr.PluginRegistry.ActRegistry.RunAll(result)
if terminate {
pr.Logger.Debug().Fields(
map[string]any{
Expand Down

0 comments on commit 3f35e13

Please sign in to comment.