Skip to content

Commit

Permalink
feat(464): add support for queueing async actions in background
Browse files Browse the repository at this point in the history
  • Loading branch information
Hamsajj committed May 27, 2024
1 parent d86f246 commit 96ee363
Show file tree
Hide file tree
Showing 11 changed files with 413 additions and 28 deletions.
3 changes: 3 additions & 0 deletions .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ linters-settings:
- "golang.org/x/text/cases"
- "golang.org/x/text/language"
- "gopkg.in/yaml.v2"
- "github.com/redis/go-redis/v9"
test:
files:
- $test
Expand All @@ -92,6 +93,8 @@ linters-settings:
- "github.com/knadh/koanf"
- "github.com/spf13/cast"
- "github.com/jackc/pgx/v5/pgproto3"
- "github.com/testcontainers/testcontainers-go"
- "github.com/redis/go-redis/v9"
tagalign:
align: false
sort: false
Expand Down
22 changes: 22 additions & 0 deletions act/act_helpers_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package act

import (
"context"
"testing"
"time"

sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act"
"github.com/stretchr/testify/assert"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/redis"
)

func createWaitActEntities(async bool) (
Expand Down Expand Up @@ -49,3 +54,20 @@ func createWaitActEntities(async bool) (

return name, actions, signals, policy
}

func createTestRedis(t *testing.T) string {
t.Helper()
ctx := context.Background()

redisContainer, err := redis.RunContainer(ctx, testcontainers.WithImage("redis:6"))

assert.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, redisContainer.Terminate(ctx))
})
host, err := redisContainer.Host(ctx)
assert.NoError(t, err)
port, err := redisContainer.MappedPort(ctx, "6379/tcp")
assert.NoError(t, err)
return host + ":" + port.Port()
}
40 changes: 40 additions & 0 deletions act/publisher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package act

import (
"context"
"fmt"

"github.com/redis/go-redis/v9"
"github.com/rs/zerolog"
)

type IPublisher interface {
Publish(ctx context.Context, payload []byte) error
}

var _ IPublisher = (*Publisher)(nil)

type Publisher struct {
Logger zerolog.Logger
RedisDB redis.Cmdable
ChannelName string
}

func NewPublisher(publisher Publisher) (*Publisher, error) {
if err := publisher.RedisDB.Ping(context.Background()).Err(); err != nil {
publisher.Logger.Err(err).Msg("failed to connect redis")
}
return &Publisher{
Logger: publisher.Logger,
RedisDB: publisher.RedisDB,
ChannelName: publisher.ChannelName,
}, nil
}

func (p *Publisher) Publish(ctx context.Context, payload []byte) error {
if err := p.RedisDB.Publish(ctx, p.ChannelName, payload).Err(); err != nil {
p.Logger.Err(err).Str("ChannelName", p.ChannelName).Msg("failed to publish task to redis")
return fmt.Errorf("failed to publish task to redis: %w", err)
}
return nil
}
113 changes: 111 additions & 2 deletions act/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package act

import (
"context"
"encoding/json"
"errors"
"fmt"
"slices"
"time"

Expand All @@ -26,6 +28,10 @@ type Registry struct {
// Default timeout for running actions
DefaultActionTimeout time.Duration

// TaskPublisher is the publisher for async actions.
// if not given, will invoke simple goroutine to run async actions
TaskPublisher IPublisher

Signals map[string]*sdkAct.Signal
Policies map[string]*sdkAct.Policy
Actions map[string]*sdkAct.Action
Expand All @@ -34,6 +40,27 @@ type Registry struct {
DefaultSignal *sdkAct.Signal
}

type AsyncActionMessage struct {
Output *sdkAct.Output
Params []sdkAct.Parameter
}

// Encode marshals the AsyncActionMessage struct to JSON bytes.
func (msg *AsyncActionMessage) Encode() ([]byte, error) {
marshaled, err := json.Marshal(msg)
if err != nil {
return nil, fmt.Errorf("error encoding JSON: %w", err)
}
return marshaled, nil
}

func (msg *AsyncActionMessage) Decode(data []byte) error {
if err := json.Unmarshal(data, msg); err != nil {
return fmt.Errorf("error decoding JSON: %w", err)
}
return nil
}

var _ IRegistry = (*Registry)(nil)

// NewActRegistry creates a new act registry with the specified default policy and timeout
Expand Down Expand Up @@ -88,6 +115,7 @@ func NewActRegistry(
Actions: registry.Actions,
DefaultPolicy: registry.Policies[registry.DefaultPolicyName],
DefaultSignal: registry.Signals[registry.DefaultPolicyName],
TaskPublisher: registry.TaskPublisher,
}
}

Expand Down Expand Up @@ -234,6 +262,18 @@ func (r *Registry) Run(
if action.Timeout > 0 {
timeout = time.Duration(action.Timeout) * time.Second
}

// if task is async and publisher is configured, publish it and do not run it
if r.TaskPublisher != nil && !action.Sync {
err := r.publishTask(output, params)
if err != nil {
r.Logger.Error().Err(err).Msg("Error publishing async action")
return nil, gerr.ErrPublishingAsyncAction
}
return nil, gerr.ErrAsyncAction
}

// no publisher, or sync action. run the action
var ctx context.Context
var cancel context.CancelFunc
// if timeout is zero, then the context should not have timeout
Expand All @@ -248,14 +288,83 @@ func (r *Registry) Run(
return runActionWithTimeout(ctx, action, output, params, r.Logger)
}

// Run the action asynchronously.
// If the action is asynchronous, run it in a goroutine and return the sentinel error.
go func() {
defer cancel()
_, _ = runActionWithTimeout(ctx, action, output, params, r.Logger)
}()

return nil, gerr.ErrAsyncAction
}

func (r *Registry) publishTask(output *sdkAct.Output, params []sdkAct.Parameter) error {
r.Logger.Debug().Msg("Publishing async action")
task := AsyncActionMessage{
Output: output,
Params: params,
}
payload, err := task.Encode()
if err != nil {
return err
}
if err := r.TaskPublisher.Publish(context.Background(), payload); err != nil {
return fmt.Errorf("error publishing task: %w", err)
}
return nil
}

func (r *Registry) runAsyncActionFn(ctx context.Context, message []byte) error {
msg := &AsyncActionMessage{}
if err := msg.Decode(message); err != nil {
r.Logger.Error().Err(err).Msg("Error decoding message")
return err
}
output := msg.Output
params := msg.Params

// In certain cases, the output may be nil, for example, if the policy
// evaluation fails. In this case, the run is aborted.
if output == nil {
// This should never happen, since the output is always set by the registry
// to be the default policy if no signals are provided.
r.Logger.Debug().Msg("Output is nil, run aborted")
return gerr.ErrNilPointer
}

action, ok := r.Actions[output.MatchedPolicy]
if !ok {
r.Logger.Warn().Str("matchedPolicy", output.MatchedPolicy).Msg(
"Action does not exist, run aborted")
return gerr.ErrActionNotExist
}

// Prepend the logger to the parameters if needed.
if len(params) == 0 || params[0].Key != LoggerKey {
params = append([]sdkAct.Parameter{WithLogger(r.Logger)}, params...)
} else {
params[0] = WithLogger(r.Logger)
}

timeout := r.DefaultActionTimeout
if action.Timeout > 0 {
timeout = time.Duration(action.Timeout) * time.Second
}
var ctxWithTimeout context.Context
var cancel context.CancelFunc
// if timeout is zero, then the context should not have timeout
if timeout > 0 {
ctxWithTimeout, cancel = context.WithTimeout(ctx, timeout)
} else {
ctxWithTimeout, cancel = context.WithCancel(ctx)
}
// If the action is synchronous, run it and return the result immediately.
defer cancel()
if _, err := runActionWithTimeout(ctxWithTimeout, action, output, params, r.Logger); err != nil {
return err
}
return nil
}

func runActionWithTimeout(
ctx context.Context,
action *sdkAct.Action,
Expand Down Expand Up @@ -293,7 +402,7 @@ func runActionWithTimeout(
}
}

// WithLogger returns a parameter with the logger to be used by the action.
// WithLogger returns a parameter with the Logger to be used by the action.
// This is automatically prepended to the parameters when running an action.
func WithLogger(logger zerolog.Logger) sdkAct.Parameter {
return sdkAct.Parameter{
Expand Down
80 changes: 80 additions & 0 deletions act/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ package act

import (
"bytes"
"context"
"sync"
"testing"
"time"

sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act"
"github.com/gatewayd-io/gatewayd/config"
gerr "github.com/gatewayd-io/gatewayd/errors"
"github.com/redis/go-redis/v9"
"github.com/rs/zerolog"
"github.com/spf13/cast"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -705,6 +708,83 @@ func Test_Run_Async(t *testing.T) {
assert.Contains(t, out.String(), "{\"level\":\"info\",\"async\":true,\"message\":\"test\"}")
}

// Test_Run_Async tests the Run function of the act registry with an asynchronous action.
func Test_Run_Async_Redis(t *testing.T) {
out := bytes.Buffer{}
logger := zerolog.New(&out)

rdbAddr := createTestRedis(t)
rdb := redis.NewClient(&redis.Options{
Addr: rdbAddr,
})
publisher, err := NewPublisher(Publisher{
Logger: logger,
RedisDB: rdb,
ChannelName: "test-async-chan",
})
require.NoError(t, err)

var waitGroup sync.WaitGroup
actRegistry := NewActRegistry(
Registry{
Signals: BuiltinSignals(),
Policies: BuiltinPolicies(),
Actions: BuiltinActions(),
DefaultPolicyName: config.DefaultPolicy,
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
TaskPublisher: publisher,
})
assert.NotNil(t, actRegistry)

consumer, err := sdkAct.NewConsumer(&logger, rdb, 5, "test-async-chan")
require.NoError(t, err)

require.NoError(t, consumer.Subscribe(context.Background(), func(ctx context.Context, task []byte) error {
err := actRegistry.runAsyncActionFn(ctx, task)
waitGroup.Done()
return err
}))

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Log("info", "test", map[string]any{"async": true}),
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)
assert.Equal(t, "log", outputs[0].MatchedPolicy)
assert.Equal(t,
map[string]interface{}{
"async": true,
"level": "info",
"log": true,
"message": "test",
},
outputs[0].Metadata,
)
assert.False(t, outputs[0].Sync)
assert.True(t, cast.ToBool(outputs[0].Verdict))
assert.False(t, outputs[0].Terminal)
waitGroup.Add(1)
result, err := actRegistry.Run(outputs[0], WithResult(map[string]any{"key": "value"}))
waitGroup.Wait()
assert.Equal(t, err, gerr.ErrAsyncAction, "expected async action sentinel error")
assert.Nil(t, result, "expected nil result")

time.Sleep(time.Millisecond) // wait for async action to complete

// The following is the expected log output from running the async action.
assert.Contains(t, out.String(), "{\"level\":\"debug\",\"action\":\"log\",\"executionMode\":\"async\",\"message\":\"Running action\"}") //nolint:lll
// The following is the expected log output from the run function of the async action.
assert.Contains(t, out.String(), "{\"level\":\"info\",\"async\":true,\"message\":\"test\"}")
// The following is expected log from consumer
assert.Contains(t, out.String(), "\"message\":\"async redis task processed successfully\"")
}

// Test_Run_NilRegistry tests the Run function of the action with a nil output object.
func Test_Run_NilOutput(t *testing.T) {
buf := bytes.Buffer{}
Expand Down
Loading

0 comments on commit 96ee363

Please sign in to comment.