Skip to content

Commit

Permalink
feat: use golang-queue to run async actions
Browse files Browse the repository at this point in the history
  • Loading branch information
Hamsajj committed Apr 28, 2024
1 parent 5147137 commit 1c16794
Show file tree
Hide file tree
Showing 13 changed files with 255 additions and 19 deletions.
2 changes: 2 additions & 0 deletions .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ linters-settings:
- "gopkg.in/natefinch/lumberjack.v2"
- "github.com/expr-lang/expr"
- "github.com/jackc/pgx/v5/pgproto3"
- "github.com/golang-queue/queue"
test:
files:
- $test
Expand All @@ -87,6 +88,7 @@ linters-settings:
- "github.com/spf13/cobra"
- "github.com/knadh/koanf"
- "github.com/spf13/cast"
- "github.com/golang-queue/queue"
tagalign:
align: false
sort: false
Expand Down
63 changes: 47 additions & 16 deletions act/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package act

import (
"context"
"encoding/json"
"errors"
"slices"
"time"

sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act"
"github.com/gatewayd-io/gatewayd/config"
gerr "github.com/gatewayd-io/gatewayd/errors"
"github.com/golang-queue/queue"
"github.com/rs/zerolog"
)

Expand All @@ -32,6 +34,20 @@ type Registry struct {
DefaultPolicyName string
DefaultPolicy *sdkAct.Policy
DefaultSignal *sdkAct.Signal
ActionQueue *queue.Queue
}

type asyncActionMessage struct {
Output *sdkAct.Output
Params []sdkAct.Parameter
}

func (j *asyncActionMessage) Bytes() []byte {
b, err := json.Marshal(j)
if err != nil {
panic(err)
}
return b
}

var _ IRegistry = (*Registry)(nil)
Expand Down Expand Up @@ -79,6 +95,11 @@ func NewActRegistry(

registry.Logger.Debug().Str("name", registry.DefaultPolicyName).Msg("Using default policy")

if registry.ActionQueue == nil {
registry.Logger.Warn().Msg("ActionQueue is nil, not creating registry")
return nil
}

return &Registry{
Logger: registry.Logger,
PolicyTimeout: registry.PolicyTimeout,
Expand All @@ -88,6 +109,7 @@ func NewActRegistry(
Actions: registry.Actions,
DefaultPolicy: registry.Policies[registry.DefaultPolicyName],
DefaultSignal: registry.Signals[registry.DefaultPolicyName],
ActionQueue: registry.ActionQueue,
}
}

Expand Down Expand Up @@ -225,32 +247,35 @@ func (r *Registry) Run(
return nil, gerr.ErrActionNotExist
}

// Prepend the logger to the parameters.
params = append([]sdkAct.Parameter{WithLogger(r.Logger)}, params...)

timeout := r.DefaultActionTimeout
if action.Timeout > 0 {
timeout = time.Duration(action.Timeout) * time.Second
}
var ctx context.Context
var cancel context.CancelFunc
// if timeout is zero, then the context should not have timeout
if timeout > 0 {
ctx, cancel = context.WithTimeout(context.Background(), timeout)
} else {
ctx, cancel = context.WithCancel(context.Background())
}

// If the action is synchronous, run it and return the result immediately.
if action.Sync {
// Prepend the logger to the parameters.
params = append([]sdkAct.Parameter{WithLogger(r.Logger)}, params...)

var ctx context.Context
var cancel context.CancelFunc
// if timeout is zero, then the context should not have timeout
if timeout > 0 {
ctx, cancel = context.WithTimeout(context.Background(), timeout)
} else {
ctx, cancel = context.WithCancel(context.Background())
}
defer cancel()
return runActionWithTimeout(ctx, action, output, params, r.Logger)
}

// Run the action asynchronously.
go func() {
defer cancel()
_, _ = runActionWithTimeout(ctx, action, output, params, r.Logger)
}()
if err := r.ActionQueue.Queue(&asyncActionMessage{
Output: output,
Params: params,
}); err != nil {
return nil, gerr.ErrAsyncQueueFailed
}

return nil, gerr.ErrAsyncAction
}

Expand All @@ -261,6 +286,12 @@ func runActionWithTimeout(
params []sdkAct.Parameter,
logger zerolog.Logger,
) (any, *gerr.GatewayDError) {
defer func() {
// recover from panic if one occurred. Set err to nil otherwise.
if recover() != nil {
logger.Error().Str("action", action.Name).Msg("Action panicked")
}
}()
execMode := "sync"
if !action.Sync {
execMode = "async"
Expand Down
42 changes: 40 additions & 2 deletions act/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act"
"github.com/gatewayd-io/gatewayd/config"
gerr "github.com/gatewayd-io/gatewayd/errors"
"github.com/golang-queue/queue"
"github.com/rs/zerolog"
"github.com/spf13/cast"
"github.com/stretchr/testify/assert"
Expand All @@ -27,6 +28,7 @@ func Test_NewRegistry(t *testing.T) {
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
ActionQueue: &queue.Queue{},
})
assert.NotNil(t, actRegistry)
assert.NotNil(t, actRegistry.Signals)
Expand Down Expand Up @@ -89,6 +91,7 @@ func Test_NewRegistry_NilPolicy(t *testing.T) {
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
ActionQueue: &queue.Queue{},
})
assert.Nil(t, actRegistry)
assert.Contains(t, buf.String(), "Policy is nil, not adding")
Expand All @@ -110,6 +113,7 @@ func Test_NewRegistry_NilAction(t *testing.T) {
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
ActionQueue: &queue.Queue{},
})
assert.Nil(t, actRegistry)
assert.Contains(t, buf.String(), "Action is nil, not adding")
Expand All @@ -126,6 +130,7 @@ func Test_Add(t *testing.T) {
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: zerolog.Logger{},
ActionQueue: &queue.Queue{},
})
assert.NotNil(t, actRegistry)

Expand All @@ -151,6 +156,7 @@ func Test_Add_NilPolicy(t *testing.T) {
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
ActionQueue: &queue.Queue{},
})
assert.NotNil(t, actRegistry)

Expand All @@ -173,6 +179,7 @@ func Test_Add_ExistentPolicy(t *testing.T) {
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
ActionQueue: &queue.Queue{},
})
assert.NotNil(t, actRegistry)

Expand All @@ -193,6 +200,7 @@ func Test_Apply(t *testing.T) {
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: zerolog.Logger{},
ActionQueue: &queue.Queue{},
})
assert.NotNil(t, actRegistry)

Expand Down Expand Up @@ -222,6 +230,7 @@ func Test_Apply_NoSignals(t *testing.T) {
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
ActionQueue: &queue.Queue{},
})
assert.NotNil(t, actRegistry)

Expand Down Expand Up @@ -268,6 +277,7 @@ func Test_Apply_ContradictorySignals(t *testing.T) {
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
ActionQueue: &queue.Queue{},
})
assert.NotNil(t, actRegistry)

Expand Down Expand Up @@ -313,6 +323,7 @@ func Test_Apply_ActionNotMatched(t *testing.T) {
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
ActionQueue: &queue.Queue{},
})
assert.NotNil(t, actRegistry)

Expand Down Expand Up @@ -346,6 +357,7 @@ func Test_Apply_PolicyNotMatched(t *testing.T) {
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
ActionQueue: &queue.Queue{},
})
assert.NotNil(t, actRegistry)

Expand Down Expand Up @@ -394,6 +406,7 @@ func Test_Apply_NonBoolPolicy(t *testing.T) {
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
ActionQueue: &queue.Queue{},
})
assert.NotNil(t, actRegistry)

Expand Down Expand Up @@ -442,6 +455,7 @@ func Test_Apply_BadPolicy(t *testing.T) {
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
ActionQueue: &queue.Queue{},
})
assert.Nil(t, actRegistry)
}
Expand All @@ -459,6 +473,7 @@ func Test_Run(t *testing.T) {
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
ActionQueue: &queue.Queue{},
})
assert.NotNil(t, actRegistry)

Expand All @@ -484,6 +499,7 @@ func Test_Run_Terminate(t *testing.T) {
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
ActionQueue: &queue.Queue{},
})
assert.NotNil(t, actRegistry)

Expand All @@ -508,6 +524,15 @@ func Test_Run_Terminate(t *testing.T) {
func Test_Run_Async(t *testing.T) {
out := bytes.Buffer{}
logger := zerolog.New(&out)
worker := NewActWorker(
Worker{
Logger: logger,
Actions: BuiltinActions(),
DefaultActionTimeout: config.DefaultActionTimeout,
},
)

workerQueue := queue.NewPool(1, queue.WithFn(worker.RunFunc()))
actRegistry := NewActRegistry(
Registry{
Signals: BuiltinSignals(),
Expand All @@ -517,6 +542,7 @@ func Test_Run_Async(t *testing.T) {
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
ActionQueue: workerQueue,
})
assert.NotNil(t, actRegistry)

Expand All @@ -542,7 +568,7 @@ func Test_Run_Async(t *testing.T) {
assert.Equal(t, err, gerr.ErrAsyncAction, "expected async action sentinel error")
assert.Nil(t, result, "expected nil result")

time.Sleep(time.Millisecond) // wait for async action to complete
time.Sleep(time.Millisecond * 2000) // wait for async action to complete

// The following is the expected log output from running the async action.
assert.Contains(t, out.String(), "{\"level\":\"debug\",\"action\":\"log\",\"executionMode\":\"async\",\"message\":\"Running action\"}") //nolint:lll
Expand All @@ -563,6 +589,7 @@ func Test_Run_NilOutput(t *testing.T) {
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
ActionQueue: &queue.Queue{},
})
assert.NotNil(t, actRegistry)

Expand All @@ -585,6 +612,7 @@ func Test_Run_ActionNotExist(t *testing.T) {
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
ActionQueue: &queue.Queue{},
})
assert.NotNil(t, actRegistry)

Expand Down Expand Up @@ -635,6 +663,15 @@ func Test_Run_Timeout(t *testing.T) {
name, actions, signals, policies := createWaitActEntities(isAsync)
out := bytes.Buffer{}
logger := zerolog.New(&out)
worker := NewActWorker(
Worker{
Logger: logger,
Actions: actions,
DefaultActionTimeout: test.timeout,
},
)

workerQueue := queue.NewPool(1, queue.WithFn(worker.RunFunc()))
actRegistry := NewActRegistry(
Registry{
Signals: signals,
Expand All @@ -644,6 +681,7 @@ func Test_Run_Timeout(t *testing.T) {
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: test.timeout,
Logger: logger,
ActionQueue: workerQueue,
})
assert.NotNil(t, actRegistry)

Expand Down Expand Up @@ -673,7 +711,7 @@ func Test_Run_Timeout(t *testing.T) {
assert.Equal(t, test.expectedResult, result)

if isAsync {
time.Sleep(3 * time.Millisecond)
time.Sleep(2000 * time.Millisecond)
}
if test.expectedLog != "" {
assert.Contains(t, out.String(), test.expectedLog)
Expand Down
Loading

0 comments on commit 1c16794

Please sign in to comment.