Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(464): add support for queueing async actions in background #544

Merged
merged 4 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ linters-settings:
- "golang.org/x/text/cases"
- "golang.org/x/text/language"
- "gopkg.in/yaml.v2"
- "github.com/redis/go-redis/v9"
test:
files:
- $test
Expand All @@ -92,6 +93,8 @@ linters-settings:
- "github.com/knadh/koanf"
- "github.com/spf13/cast"
- "github.com/jackc/pgx/v5/pgproto3"
- "github.com/testcontainers/testcontainers-go"
- "github.com/redis/go-redis/v9"
tagalign:
align: false
sort: false
Expand Down
22 changes: 22 additions & 0 deletions act/act_helpers_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package act

import (
"context"
"testing"
"time"

sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act"
"github.com/stretchr/testify/assert"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/redis"
)

func createWaitActEntities(async bool) (
Expand Down Expand Up @@ -49,3 +54,20 @@ func createWaitActEntities(async bool) (

return name, actions, signals, policy
}

func createTestRedis(t *testing.T) string {
t.Helper()
ctx := context.Background()

redisContainer, err := redis.RunContainer(ctx, testcontainers.WithImage("redis:6"))

assert.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, redisContainer.Terminate(ctx))
})
host, err := redisContainer.Host(ctx)
assert.NoError(t, err)
port, err := redisContainer.MappedPort(ctx, "6379/tcp")
assert.NoError(t, err)
return host + ":" + port.Port()
}
40 changes: 40 additions & 0 deletions act/publisher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package act

import (
"context"
"fmt"

"github.com/redis/go-redis/v9"
"github.com/rs/zerolog"
)

type IPublisher interface {
Publish(ctx context.Context, payload []byte) error
}

var _ IPublisher = (*Publisher)(nil)

type Publisher struct {
Logger zerolog.Logger
RedisDB redis.Cmdable
ChannelName string
}

func NewPublisher(publisher Publisher) (*Publisher, error) {
if err := publisher.RedisDB.Ping(context.Background()).Err(); err != nil {
publisher.Logger.Err(err).Msg("failed to connect redis")
}
return &Publisher{
Logger: publisher.Logger,
RedisDB: publisher.RedisDB,
ChannelName: publisher.ChannelName,
}, nil
}

func (p *Publisher) Publish(ctx context.Context, payload []byte) error {
if err := p.RedisDB.Publish(ctx, p.ChannelName, payload).Err(); err != nil {
p.Logger.Err(err).Str("ChannelName", p.ChannelName).Msg("failed to publish task to redis")
return fmt.Errorf("failed to publish task to redis: %w", err)
}
return nil
}
113 changes: 111 additions & 2 deletions act/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package act

import (
"context"
"encoding/json"
"errors"
"fmt"
"slices"
"time"

Expand All @@ -26,6 +28,10 @@ type Registry struct {
// Default timeout for running actions
DefaultActionTimeout time.Duration

// TaskPublisher is the publisher for async actions.
// if not given, will invoke simple goroutine to run async actions
TaskPublisher *Publisher
mostafa marked this conversation as resolved.
Show resolved Hide resolved

Signals map[string]*sdkAct.Signal
Policies map[string]*sdkAct.Policy
Actions map[string]*sdkAct.Action
Expand All @@ -34,6 +40,27 @@ type Registry struct {
DefaultSignal *sdkAct.Signal
}

type AsyncActionMessage struct {
Output *sdkAct.Output
Params []sdkAct.Parameter
}

// Encode marshals the AsyncActionMessage struct to JSON bytes.
func (msg *AsyncActionMessage) Encode() ([]byte, error) {
marshaled, err := json.Marshal(msg)
if err != nil {
return nil, fmt.Errorf("error encoding JSON: %w", err)
}
return marshaled, nil
}

func (msg *AsyncActionMessage) Decode(data []byte) error {
if err := json.Unmarshal(data, msg); err != nil {
return fmt.Errorf("error decoding JSON: %w", err)
}
return nil
}

var _ IRegistry = (*Registry)(nil)

// NewActRegistry creates a new act registry with the specified default policy and timeout
Expand Down Expand Up @@ -88,6 +115,7 @@ func NewActRegistry(
Actions: registry.Actions,
DefaultPolicy: registry.Policies[registry.DefaultPolicyName],
DefaultSignal: registry.Signals[registry.DefaultPolicyName],
TaskPublisher: registry.TaskPublisher,
}
}

Expand Down Expand Up @@ -234,6 +262,18 @@ func (r *Registry) Run(
if action.Timeout > 0 {
timeout = time.Duration(action.Timeout) * time.Second
}

// if task is async and publisher is configured, publish it and do not run it
if r.TaskPublisher != nil && !action.Sync {
err := r.publishTask(output, params)
if err != nil {
r.Logger.Error().Err(err).Msg("Error publishing async action")
return nil, gerr.ErrPublishingAsyncAction
}
return nil, gerr.ErrAsyncAction
}

// no publisher, or sync action. run the action
var ctx context.Context
var cancel context.CancelFunc
// if timeout is zero, then the context should not have timeout
Expand All @@ -248,14 +288,83 @@ func (r *Registry) Run(
return runActionWithTimeout(ctx, action, output, params, r.Logger)
}

// Run the action asynchronously.
// If the action is asynchronous, run it in a goroutine and return the sentinel error.
go func() {
defer cancel()
_, _ = runActionWithTimeout(ctx, action, output, params, r.Logger)
}()

return nil, gerr.ErrAsyncAction
}

func (r *Registry) publishTask(output *sdkAct.Output, params []sdkAct.Parameter) error {
r.Logger.Debug().Msg("Publishing async action")
task := AsyncActionMessage{
Output: output,
Params: params,
}
payload, err := task.Encode()
if err != nil {
return err
}
if err := r.TaskPublisher.Publish(context.Background(), payload); err != nil {
return fmt.Errorf("error publishing task: %w", err)
}
return nil
}

func (r *Registry) runAsyncActionFn(ctx context.Context, message []byte) error {
msg := &AsyncActionMessage{}
if err := msg.Decode(message); err != nil {
r.Logger.Error().Err(err).Msg("Error decoding message")
return err
}
output := msg.Output
params := msg.Params

// In certain cases, the output may be nil, for example, if the policy
// evaluation fails. In this case, the run is aborted.
if output == nil {
// This should never happen, since the output is always set by the registry
// to be the default policy if no signals are provided.
r.Logger.Debug().Msg("Output is nil, run aborted")
return gerr.ErrNilPointer
}

action, ok := r.Actions[output.MatchedPolicy]
if !ok {
r.Logger.Warn().Str("matchedPolicy", output.MatchedPolicy).Msg(
"Action does not exist, run aborted")
return gerr.ErrActionNotExist
}

// Prepend the logger to the parameters if needed.
if len(params) == 0 || params[0].Key != LoggerKey {
params = append([]sdkAct.Parameter{WithLogger(r.Logger)}, params...)
} else {
params[0] = WithLogger(r.Logger)
}

timeout := r.DefaultActionTimeout
if action.Timeout > 0 {
timeout = time.Duration(action.Timeout) * time.Second
}
var ctxWithTimeout context.Context
var cancel context.CancelFunc
// if timeout is zero, then the context should not have timeout
if timeout > 0 {
ctxWithTimeout, cancel = context.WithTimeout(ctx, timeout)
} else {
ctxWithTimeout, cancel = context.WithCancel(ctx)
}
// If the action is synchronous, run it and return the result immediately.
defer cancel()
if _, err := runActionWithTimeout(ctxWithTimeout, action, output, params, r.Logger); err != nil {
return err
}
return nil
}

func runActionWithTimeout(
ctx context.Context,
action *sdkAct.Action,
Expand Down Expand Up @@ -293,7 +402,7 @@ func runActionWithTimeout(
}
}

// WithLogger returns a parameter with the logger to be used by the action.
// WithLogger returns a parameter with the Logger to be used by the action.
// This is automatically prepended to the parameters when running an action.
func WithLogger(logger zerolog.Logger) sdkAct.Parameter {
return sdkAct.Parameter{
Expand Down
86 changes: 86 additions & 0 deletions act/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@ package act

import (
"bytes"
"context"
"sync"
"testing"
"time"

sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act"
"github.com/gatewayd-io/gatewayd/config"
gerr "github.com/gatewayd-io/gatewayd/errors"
"github.com/hashicorp/go-hclog"
"github.com/redis/go-redis/v9"
"github.com/rs/zerolog"
"github.com/spf13/cast"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -705,6 +709,88 @@ func Test_Run_Async(t *testing.T) {
assert.Contains(t, out.String(), "{\"level\":\"info\",\"async\":true,\"message\":\"test\"}")
}

// Test_Run_Async tests the Run function of the act registry with an asynchronous action.
func Test_Run_Async_Redis(t *testing.T) {
out := bytes.Buffer{}
logger := zerolog.New(&out)
hclogger := hclog.New(&hclog.LoggerOptions{
Output: &out,
Level: hclog.Debug,
JSONFormat: true,
})

rdbAddr := createTestRedis(t)
rdb := redis.NewClient(&redis.Options{
Addr: rdbAddr,
})
publisher, err := NewPublisher(Publisher{
Logger: logger,
RedisDB: rdb,
ChannelName: "test-async-chan",
})
require.NoError(t, err)

var waitGroup sync.WaitGroup
actRegistry := NewActRegistry(
Registry{
Signals: BuiltinSignals(),
Policies: BuiltinPolicies(),
Actions: BuiltinActions(),
DefaultPolicyName: config.DefaultPolicy,
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
TaskPublisher: publisher,
})
assert.NotNil(t, actRegistry)

consumer, err := sdkAct.NewConsumer(hclogger, rdb, 5, "test-async-chan")
require.NoError(t, err)

require.NoError(t, consumer.Subscribe(context.Background(), func(ctx context.Context, task []byte) error {
err := actRegistry.runAsyncActionFn(ctx, task)
waitGroup.Done()
return err
}))

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Log("info", "test", map[string]any{"async": true}),
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)
assert.Equal(t, "log", outputs[0].MatchedPolicy)
assert.Equal(t,
map[string]interface{}{
"async": true,
"level": "info",
"log": true,
"message": "test",
},
outputs[0].Metadata,
)
assert.False(t, outputs[0].Sync)
assert.True(t, cast.ToBool(outputs[0].Verdict))
assert.False(t, outputs[0].Terminal)
waitGroup.Add(1)
result, err := actRegistry.Run(outputs[0], WithResult(map[string]any{"key": "value"}))
waitGroup.Wait()
assert.Equal(t, err, gerr.ErrAsyncAction, "expected async action sentinel error")
assert.Nil(t, result, "expected nil result")

time.Sleep(time.Millisecond) // wait for async action to complete

// The following is the expected log output from running the async action.
assert.Contains(t, out.String(), "{\"level\":\"debug\",\"action\":\"log\",\"executionMode\":\"async\",\"message\":\"Running action\"}") //nolint:lll
// The following is the expected log output from the run function of the async action.
assert.Contains(t, out.String(), "{\"level\":\"info\",\"async\":true,\"message\":\"test\"}")
// The following is expected log from consumer in hclog format
assert.Contains(t, out.String(), "\"@level\":\"debug\",\"@message\":\"async redis task processed successfully\"")
}

// Test_Run_NilRegistry tests the Run function of the action with a nil output object.
func Test_Run_NilOutput(t *testing.T) {
buf := bytes.Buffer{}
Expand Down
Loading