From cf2b43dc4c78d83da3e73bdd875a051a66b67a74 Mon Sep 17 00:00:00 2001 From: Mohammad Ali Mehdizadeh Date: Tue, 9 Apr 2024 23:37:53 +0330 Subject: [PATCH] Passing struct to new functions (#503) * GatewayDError initialized directly, and NewGatewayDError deleted * Pass config.Config to NewConfig function * Pass Regsitry struct to NewActRegistry function * Pass GRPCServer struct to NewGRPCServer function * Pass Registry struct to NewRegistry function * Pass Merger struct to NewMerger function * Pass Retry struct to NewRetry function * Pass Proxy struct to NewProxy function * Pass ConnWrapper struct to NewConnWrapper function * Pass Server struct to NewServer function * solving some warnings of golang-ci-lint * Update dependencies and fix issues (#505) * Update deps * Remove unnecessary conversion * Use grpc.NewClient instead of the deprecated grpc.Dial * Remove copied loop variable * Use integer range * Remove unused nolint:wrapchecks * Fix variable usage * Use latest version of golangci-lint * Fix nil pointer error when creating a new gRPC server * Fix other minor issues * Improve Makefile by combining all build targets for different platforms into a single target (#507) * Fix issues found during code inspection (#508) * Fix issues found during code inspection * Fix issues reported by linters * Address comments --------- Signed-off-by: Mostafa Moradian Co-authored-by: mam Co-authored-by: Mostafa Moradian --- act/registry.go | 97 ++++++------- act/registry_test.go | 253 +++++++++++++++++++++++++-------- api/api.go | 4 +- api/api_test.go | 134 ++++++++++------- api/grpc_server.go | 18 +-- cmd/configs.go | 4 +- cmd/plugin_list.go | 2 +- cmd/run.go | 116 ++++++++------- config/config.go | 20 +-- config/config_test.go | 12 +- config/getters_test.go | 3 +- errors/errors.go | 185 ++++++++++++++---------- errors/gatewayd_error.go | 9 -- errors/gatewayd_error_test.go | 24 ---- metrics/merger.go | 6 +- metrics/merger_test.go | 4 +- network/conn_wrapper.go | 34 ++--- network/proxy.go | 201 +++++++++++++------------- network/proxy_test.go | 242 +++++++++++++++++++------------ network/retry.go | 24 ++-- network/retry_test.go | 14 +- network/server.go | 199 +++++++++++++------------- network/server_test.go | 75 +++++----- plugin/plugin_registry.go | 29 ++-- plugin/plugin_registry_test.go | 43 ++++-- plugin/utils_test.go | 12 +- 26 files changed, 1016 insertions(+), 748 deletions(-) delete mode 100644 errors/gatewayd_error_test.go diff --git a/act/registry.go b/act/registry.go index bff67d7d..1a45a85d 100644 --- a/act/registry.go +++ b/act/registry.go @@ -20,17 +20,18 @@ type IRegistry interface { // Registry keeps track of all policies and actions. type Registry struct { - logger zerolog.Logger + Logger zerolog.Logger // Timeout for policy evaluation. - policyTimeout time.Duration + PolicyTimeout time.Duration // Default timeout for running actions - defaultActionTimeout time.Duration - - Signals map[string]*sdkAct.Signal - Policies map[string]*sdkAct.Policy - Actions map[string]*sdkAct.Action - DefaultPolicy *sdkAct.Policy - DefaultSignal *sdkAct.Signal + DefaultActionTimeout time.Duration + + Signals map[string]*sdkAct.Signal + Policies map[string]*sdkAct.Policy + Actions map[string]*sdkAct.Action + DefaultPolicyName string + DefaultPolicy *sdkAct.Policy + DefaultSignal *sdkAct.Signal } var _ IRegistry = (*Registry)(nil) @@ -38,73 +39,67 @@ var _ IRegistry = (*Registry)(nil) // NewActRegistry creates a new act registry with the specified default policy and timeout // and the builtin signals, policies, and actions. func NewActRegistry( - builtinSignals map[string]*sdkAct.Signal, - builtinsPolicies map[string]*sdkAct.Policy, - builtinActions map[string]*sdkAct.Action, - defaultPolicy string, - policyTimeout time.Duration, - defaultActionTimeout time.Duration, - logger zerolog.Logger, + registry Registry, ) *Registry { - if builtinSignals == nil || builtinsPolicies == nil || builtinActions == nil { - logger.Warn().Msg("Builtin signals, policies, or actions are nil, not adding") + if registry.Signals == nil || registry.Policies == nil || registry.Actions == nil { + registry.Logger.Warn().Msg("Builtin signals, policies, or actions are nil, not adding") return nil } - for _, signal := range builtinSignals { + for _, signal := range registry.Signals { if signal == nil { - logger.Warn().Msg("Signal is nil, not adding") + registry.Logger.Warn().Msg("Signal is nil, not adding") return nil } - logger.Debug().Str("name", signal.Name).Msg("Registered builtin signal") + registry.Logger.Debug().Str("name", signal.Name).Msg("Registered builtin signal") } - for _, policy := range builtinsPolicies { + for _, policy := range registry.Policies { if policy == nil { - logger.Warn().Msg("Policy is nil, not adding") + registry.Logger.Warn().Msg("Policy is nil, not adding") return nil } - logger.Debug().Str("name", policy.Name).Msg("Registered builtin policy") + registry.Logger.Debug().Str("name", policy.Name).Msg("Registered builtin policy") } - for _, action := range builtinActions { + for _, action := range registry.Actions { if action == nil { - logger.Warn().Msg("Action is nil, not adding") + registry.Logger.Warn().Msg("Action is nil, not adding") return nil } - logger.Debug().Str("name", action.Name).Msg("Registered builtin action") + registry.Logger.Debug().Str("name", action.Name).Msg("Registered builtin action") } // The default policy must exist, otherwise use passthrough. - if _, exists := builtinsPolicies[defaultPolicy]; !exists || defaultPolicy == "" { - logger.Warn().Str("name", defaultPolicy).Msgf( + if _, exists := registry.Policies[registry.DefaultPolicyName]; !exists || registry.DefaultPolicyName == "" { + registry.Logger.Warn().Str("name", registry.DefaultPolicyName).Msgf( "The specified default policy does not exist, using %s", config.DefaultPolicy) - defaultPolicy = config.DefaultPolicy + registry.DefaultPolicyName = config.DefaultPolicy } - logger.Debug().Str("name", defaultPolicy).Msg("Using default policy") + registry.Logger.Debug().Str("name", registry.DefaultPolicyName).Msg("Using default policy") return &Registry{ - logger: logger, - policyTimeout: policyTimeout, - defaultActionTimeout: defaultActionTimeout, - Signals: builtinSignals, - Policies: builtinsPolicies, - Actions: builtinActions, - DefaultPolicy: builtinsPolicies[defaultPolicy], - DefaultSignal: builtinSignals[defaultPolicy], + Logger: registry.Logger, + PolicyTimeout: registry.PolicyTimeout, + DefaultActionTimeout: registry.DefaultActionTimeout, + Signals: registry.Signals, + Policies: registry.Policies, + Actions: registry.Actions, + DefaultPolicy: registry.Policies[registry.DefaultPolicyName], + DefaultSignal: registry.Signals[registry.DefaultPolicyName], } } // Add adds a policy to the registry. func (r *Registry) Add(policy *sdkAct.Policy) { if policy == nil { - r.logger.Warn().Msg("Policy is nil, not adding") + r.Logger.Warn().Msg("Policy is nil, not adding") return } if _, exists := r.Policies[policy.Name]; exists { - r.logger.Warn().Str("name", policy.Name).Msg("Policy already exists, overwriting") + r.Logger.Warn().Str("name", policy.Name).Msg("Policy already exists, overwriting") } // Builtin policies are can be overwritten by user-defined policies. @@ -115,7 +110,7 @@ func (r *Registry) Add(policy *sdkAct.Policy) { func (r *Registry) Apply(signals []sdkAct.Signal) []*sdkAct.Output { // If there are no signals, apply the default policy. if len(signals) == 0 { - r.logger.Debug().Msg("No signals provided, applying default signal") + r.Logger.Debug().Msg("No signals provided, applying default signal") return r.Apply([]sdkAct.Signal{*r.DefaultSignal}) } @@ -138,7 +133,7 @@ func (r *Registry) Apply(signals []sdkAct.Signal) []*sdkAct.Output { // If the signal is terminal, all non-terminal signals are ignored. Also, it only // makes sense to have a terminal signal if the action is synchronous and terminal. if len(terminal) > 0 && slices.Contains(nonTerminal, signal.Name) { - r.logger.Warn().Str("name", signal.Name).Msg( + r.Logger.Warn().Str("name", signal.Name).Msg( "Terminal signal takes precedence, ignoring non-terminal signals") continue } @@ -146,7 +141,7 @@ func (r *Registry) Apply(signals []sdkAct.Signal) []*sdkAct.Output { // Apply the signal and append the output to the list of outputs. output, err := r.apply(signal) if err != nil { - r.logger.Error().Err(err).Str("name", signal.Name).Msg("Error applying signal") + r.Logger.Error().Err(err).Str("name", signal.Name).Msg("Error applying signal") // If there is an error evaluating the policy, continue to the next signal. // This also prevents stack overflows from infinite loops of the external // if condition below. @@ -179,7 +174,7 @@ func (r *Registry) apply(signal sdkAct.Signal) (*sdkAct.Output, *gerr.GatewayDEr } // Create a context with a timeout for policy evaluation. - ctx, cancel := context.WithTimeout(context.Background(), r.policyTimeout) + ctx, cancel := context.WithTimeout(context.Background(), r.PolicyTimeout) defer cancel() // Evaluate the policy. @@ -219,21 +214,21 @@ func (r *Registry) Run( if output == nil { // This should never happen, since the output is always set by the registry // to be the default policy if no signals are provided. - r.logger.Debug().Msg("Output is nil, run aborted") + r.Logger.Debug().Msg("Output is nil, run aborted") return nil, gerr.ErrNilPointer } action, ok := r.Actions[output.MatchedPolicy] if !ok { - r.logger.Warn().Str("matchedPolicy", output.MatchedPolicy).Msg( + r.Logger.Warn().Str("matchedPolicy", output.MatchedPolicy).Msg( "Action does not exist, run aborted") return nil, gerr.ErrActionNotExist } // Prepend the logger to the parameters. - params = append([]sdkAct.Parameter{WithLogger(r.logger)}, params...) + params = append([]sdkAct.Parameter{WithLogger(r.Logger)}, params...) - timeout := r.defaultActionTimeout + timeout := r.DefaultActionTimeout if action.Timeout > 0 { timeout = time.Duration(action.Timeout) * time.Second } @@ -248,13 +243,13 @@ func (r *Registry) Run( // If the action is synchronous, run it and return the result immediately. if action.Sync { defer cancel() - return runActionWithTimeout(ctx, action, output, params, r.logger) + return runActionWithTimeout(ctx, action, output, params, r.Logger) } // Run the action asynchronously. go func() { defer cancel() - _, _ = runActionWithTimeout(ctx, action, output, params, r.logger) + _, _ = runActionWithTimeout(ctx, action, output, params, r.Logger) }() return nil, gerr.ErrAsyncAction } diff --git a/act/registry_test.go b/act/registry_test.go index b4442593..c5cc3e7c 100644 --- a/act/registry_test.go +++ b/act/registry_test.go @@ -19,8 +19,15 @@ func Test_NewRegistry(t *testing.T) { logger := zerolog.Logger{} actRegistry := NewActRegistry( - BuiltinSignals(), BuiltinPolicies(), BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + Registry{ + Signals: BuiltinSignals(), + Policies: BuiltinPolicies(), + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) assert.NotNil(t, actRegistry) assert.NotNil(t, actRegistry.Signals) assert.NotNil(t, actRegistry.Policies) @@ -35,7 +42,12 @@ func Test_NewRegistry_NilBuiltins(t *testing.T) { buf := bytes.Buffer{} logger := zerolog.New(&buf) actRegistry := NewActRegistry( - nil, nil, nil, config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + Registry{ + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) assert.Nil(t, actRegistry) assert.Contains(t, buf.String(), "Builtin signals, policies, or actions are nil, not adding") } @@ -46,11 +58,17 @@ func Test_NewRegistry_NilSignal(t *testing.T) { buf := bytes.Buffer{} logger := zerolog.New(&buf) actRegistry := NewActRegistry( - map[string]*sdkAct.Signal{ - "bad": nil, - }, - BuiltinPolicies(), BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + Registry{ + Signals: map[string]*sdkAct.Signal{ + "bad": nil, + }, + Policies: BuiltinPolicies(), + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) assert.Nil(t, actRegistry) assert.Contains(t, buf.String(), "Signal is nil, not adding") } @@ -61,12 +79,17 @@ func Test_NewRegistry_NilPolicy(t *testing.T) { buf := bytes.Buffer{} logger := zerolog.New(&buf) actRegistry := NewActRegistry( - BuiltinSignals(), - map[string]*sdkAct.Policy{ - "bad": nil, - }, - BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + Registry{ + Signals: BuiltinSignals(), + Policies: map[string]*sdkAct.Policy{ + "bad": nil, + }, + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) assert.Nil(t, actRegistry) assert.Contains(t, buf.String(), "Policy is nil, not adding") } @@ -77,11 +100,17 @@ func Test_NewRegistry_NilAction(t *testing.T) { buf := bytes.Buffer{} logger := zerolog.New(&buf) actRegistry := NewActRegistry( - BuiltinSignals(), BuiltinPolicies(), - map[string]*sdkAct.Action{ - "bad": nil, - }, - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + Registry{ + Signals: BuiltinSignals(), + Policies: BuiltinPolicies(), + Actions: map[string]*sdkAct.Action{ + "bad": nil, + }, + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) assert.Nil(t, actRegistry) assert.Contains(t, buf.String(), "Action is nil, not adding") } @@ -89,8 +118,15 @@ func Test_NewRegistry_NilAction(t *testing.T) { // Test_Add tests the Add function of the act registry. func Test_Add(t *testing.T) { actRegistry := NewActRegistry( - BuiltinSignals(), BuiltinPolicies(), BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, zerolog.Logger{}) + Registry{ + Signals: BuiltinSignals(), + Policies: BuiltinPolicies(), + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: zerolog.Logger{}, + }) assert.NotNil(t, actRegistry) assert.Len(t, actRegistry.Policies, len(BuiltinPolicies())) @@ -107,8 +143,15 @@ func Test_Add_NilPolicy(t *testing.T) { buf := bytes.Buffer{} logger := zerolog.New(&buf) actRegistry := NewActRegistry( - BuiltinSignals(), map[string]*sdkAct.Policy{}, BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + Registry{ + Signals: BuiltinSignals(), + Policies: map[string]*sdkAct.Policy{}, + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) assert.NotNil(t, actRegistry) assert.Len(t, actRegistry.Policies, 0) @@ -122,8 +165,15 @@ func Test_Add_ExistentPolicy(t *testing.T) { buf := bytes.Buffer{} logger := zerolog.New(&buf) actRegistry := NewActRegistry( - BuiltinSignals(), BuiltinPolicies(), BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + Registry{ + Signals: BuiltinSignals(), + Policies: BuiltinPolicies(), + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) assert.NotNil(t, actRegistry) assert.Len(t, actRegistry.Policies, len(BuiltinPolicies())) @@ -135,8 +185,15 @@ func Test_Add_ExistentPolicy(t *testing.T) { // Test_Apply tests the Apply function of the act registry. func Test_Apply(t *testing.T) { actRegistry := NewActRegistry( - BuiltinSignals(), BuiltinPolicies(), BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, zerolog.Logger{}) + Registry{ + Signals: BuiltinSignals(), + Policies: BuiltinPolicies(), + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: zerolog.Logger{}, + }) assert.NotNil(t, actRegistry) outputs := actRegistry.Apply([]sdkAct.Signal{ @@ -157,8 +214,15 @@ func Test_Apply_NoSignals(t *testing.T) { buf := bytes.Buffer{} logger := zerolog.New(&buf) actRegistry := NewActRegistry( - BuiltinSignals(), BuiltinPolicies(), BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + Registry{ + Signals: BuiltinSignals(), + Policies: BuiltinPolicies(), + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) assert.NotNil(t, actRegistry) outputs := actRegistry.Apply([]sdkAct.Signal{}) @@ -196,8 +260,15 @@ func Test_Apply_ContradictorySignals(t *testing.T) { buf := bytes.Buffer{} logger := zerolog.New(&buf) actRegistry := NewActRegistry( - BuiltinSignals(), BuiltinPolicies(), BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + Registry{ + Signals: BuiltinSignals(), + Policies: BuiltinPolicies(), + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) assert.NotNil(t, actRegistry) for _, s := range signals { @@ -234,8 +305,15 @@ func Test_Apply_ActionNotMatched(t *testing.T) { buf := bytes.Buffer{} logger := zerolog.New(&buf) actRegistry := NewActRegistry( - BuiltinSignals(), BuiltinPolicies(), BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + Registry{ + Signals: BuiltinSignals(), + Policies: BuiltinPolicies(), + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) assert.NotNil(t, actRegistry) outputs := actRegistry.Apply([]sdkAct.Signal{ @@ -258,12 +336,17 @@ func Test_Apply_PolicyNotMatched(t *testing.T) { buf := bytes.Buffer{} logger := zerolog.New(&buf) actRegistry := NewActRegistry( - BuiltinSignals(), - map[string]*sdkAct.Policy{ - "passthrough": BuiltinPolicies()["passthrough"], - }, - BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + Registry{ + Signals: BuiltinSignals(), + Policies: map[string]*sdkAct.Policy{ + "passthrough": BuiltinPolicies()["passthrough"], + }, + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) assert.NotNil(t, actRegistry) outputs := actRegistry.Apply([]sdkAct.Signal{ @@ -303,10 +386,15 @@ func Test_Apply_NonBoolPolicy(t *testing.T) { buf := bytes.Buffer{} logger := zerolog.New(&buf) actRegistry := NewActRegistry( - BuiltinSignals(), - policies, - BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + Registry{ + Signals: BuiltinSignals(), + Policies: policies, + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) assert.NotNil(t, actRegistry) outputs := actRegistry.Apply([]sdkAct.Signal{ @@ -346,10 +434,15 @@ func Test_Apply_BadPolicy(t *testing.T) { buf := bytes.Buffer{} logger := zerolog.New(&buf) actRegistry := NewActRegistry( - BuiltinSignals(), - policies, - BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + Registry{ + Signals: BuiltinSignals(), + Policies: policies, + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) assert.Nil(t, actRegistry) } } @@ -358,8 +451,15 @@ func Test_Apply_BadPolicy(t *testing.T) { func Test_Run(t *testing.T) { logger := zerolog.Logger{} actRegistry := NewActRegistry( - BuiltinSignals(), BuiltinPolicies(), BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + Registry{ + Signals: BuiltinSignals(), + Policies: BuiltinPolicies(), + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) assert.NotNil(t, actRegistry) outputs := actRegistry.Apply([]sdkAct.Signal{ @@ -376,8 +476,15 @@ func Test_Run(t *testing.T) { func Test_Run_Terminate(t *testing.T) { logger := zerolog.Logger{} actRegistry := NewActRegistry( - BuiltinSignals(), BuiltinPolicies(), BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + Registry{ + Signals: BuiltinSignals(), + Policies: BuiltinPolicies(), + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) assert.NotNil(t, actRegistry) outputs := actRegistry.Apply([]sdkAct.Signal{ @@ -402,8 +509,15 @@ func Test_Run_Async(t *testing.T) { out := bytes.Buffer{} logger := zerolog.New(&out) actRegistry := NewActRegistry( - BuiltinSignals(), BuiltinPolicies(), BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + Registry{ + Signals: BuiltinSignals(), + Policies: BuiltinPolicies(), + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) assert.NotNil(t, actRegistry) outputs := actRegistry.Apply([]sdkAct.Signal{ @@ -441,8 +555,15 @@ func Test_Run_NilOutput(t *testing.T) { buf := bytes.Buffer{} logger := zerolog.New(&buf) actRegistry := NewActRegistry( - BuiltinSignals(), BuiltinPolicies(), BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + Registry{ + Signals: BuiltinSignals(), + Policies: BuiltinPolicies(), + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) assert.NotNil(t, actRegistry) result, err := actRegistry.Run(nil, WithLogger(logger)) @@ -456,8 +577,15 @@ func Test_Run_ActionNotExist(t *testing.T) { buf := bytes.Buffer{} logger := zerolog.New(&buf) actRegistry := NewActRegistry( - BuiltinSignals(), BuiltinPolicies(), BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + Registry{ + Signals: BuiltinSignals(), + Policies: BuiltinPolicies(), + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) assert.NotNil(t, actRegistry) result, err := actRegistry.Run(&sdkAct.Output{}, WithLogger(logger)) @@ -508,8 +636,15 @@ func Test_Run_Timeout(t *testing.T) { out := bytes.Buffer{} logger := zerolog.New(&out) actRegistry := NewActRegistry( - signals, policies, actions, - config.DefaultPolicy, config.DefaultPolicyTimeout, test.timeout, logger) + Registry{ + Signals: signals, + Policies: policies, + Actions: actions, + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: test.timeout, + Logger: logger, + }) assert.NotNil(t, actRegistry) outputs := actRegistry.Apply([]sdkAct.Signal{*signals[name]}) diff --git a/api/api.go b/api/api.go index dd8fa84e..96c1c6af 100644 --- a/api/api.go +++ b/api/api.go @@ -200,12 +200,12 @@ func (a *API) GetProxies(context.Context, *emptypb.Empty) (*structpb.Struct, err proxies := make(map[string]interface{}) for name, proxy := range a.Proxies { available := make([]interface{}, 0) - for _, c := range proxy.AvailableConnections() { + for _, c := range proxy.AvailableConnectionsString() { available = append(available, c) } busy := make([]interface{}, 0) - for _, conn := range proxy.BusyConnections() { + for _, conn := range proxy.BusyConnectionsString() { busy = append(busy, conn) } diff --git a/api/api_test.go b/api/api_test.go index 5ca35371..cc885164 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -28,7 +28,8 @@ func TestGetVersion(t *testing.T) { func TestGetGlobalConfig(t *testing.T) { // Load config from the default config file. - conf := config.NewConfig(context.TODO(), "../gatewayd.yaml", "../gatewayd_plugins.yaml") + conf := config.NewConfig(context.TODO(), + config.Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"}) conf.InitConfig(context.TODO()) assert.NotEmpty(t, conf.Global) @@ -50,7 +51,8 @@ func TestGetGlobalConfig(t *testing.T) { func TestGetGlobalConfigWithGroupName(t *testing.T) { // Load config from the default config file. - conf := config.NewConfig(context.TODO(), "../gatewayd.yaml", "../gatewayd_plugins.yaml") + conf := config.NewConfig(context.TODO(), + config.Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"}) conf.InitConfig(context.TODO()) assert.NotEmpty(t, conf.Global) @@ -76,7 +78,8 @@ func TestGetGlobalConfigWithGroupName(t *testing.T) { func TestGetGlobalConfigWithNonExistingGroupName(t *testing.T) { // Load config from the default config file. - conf := config.NewConfig(context.TODO(), "../gatewayd.yaml", "../gatewayd_plugins.yaml") + conf := config.NewConfig(context.TODO(), + config.Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"}) conf.InitConfig(context.TODO()) assert.NotEmpty(t, conf.Global) @@ -91,7 +94,8 @@ func TestGetGlobalConfigWithNonExistingGroupName(t *testing.T) { func TestGetPluginConfig(t *testing.T) { // Load config from the default config file. - conf := config.NewConfig(context.TODO(), "../gatewayd.yaml", "../gatewayd_plugins.yaml") + conf := config.NewConfig(context.TODO(), + config.Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"}) conf.InitConfig(context.TODO()) assert.NotEmpty(t, conf.Global) @@ -107,14 +111,23 @@ func TestGetPluginConfig(t *testing.T) { func TestGetPlugins(t *testing.T) { actRegistry := act.NewActRegistry( - act.BuiltinSignals(), act.BuiltinPolicies(), act.BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, zerolog.Logger{}) + act.Registry{ + Signals: act.BuiltinSignals(), + Policies: act.BuiltinPolicies(), + Actions: act.BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: zerolog.Logger{}, + }) pluginRegistry := plugin.NewRegistry( context.TODO(), - actRegistry, - config.Loose, - zerolog.Logger{}, - true, + plugin.Registry{ + ActRegistry: actRegistry, + Compatibility: config.Loose, + Logger: zerolog.Logger{}, + DevMode: true, + }, ) pluginRegistry.Add(&plugin.Plugin{ ID: sdkPlugin.Identifier{ @@ -136,14 +149,23 @@ func TestGetPlugins(t *testing.T) { func TestGetPluginsWithEmptyPluginRegistry(t *testing.T) { actRegistry := act.NewActRegistry( - act.BuiltinSignals(), act.BuiltinPolicies(), act.BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, zerolog.Logger{}) + act.Registry{ + Signals: act.BuiltinSignals(), + Policies: act.BuiltinPolicies(), + Actions: act.BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: zerolog.Logger{}, + }) pluginRegistry := plugin.NewRegistry( context.TODO(), - actRegistry, - config.Loose, - zerolog.Logger{}, - true, + plugin.Registry{ + ActRegistry: actRegistry, + Compatibility: config.Loose, + Logger: zerolog.Logger{}, + DevMode: true, + }, ) api := API{ @@ -190,15 +212,16 @@ func TestGetProxies(t *testing.T) { proxy := network.NewProxy( context.TODO(), - newPool, - nil, - config.DefaultHealthCheckPeriod, - &config.Client{ - Network: config.DefaultNetwork, - Address: config.DefaultAddress, + network.Proxy{ + AvailableConnections: newPool, + HealthCheckPeriod: config.DefaultHealthCheckPeriod, + ClientConfig: &config.Client{ + Network: config.DefaultNetwork, + Address: config.DefaultAddress, + }, + Logger: zerolog.Logger{}, + PluginTimeout: config.DefaultPluginTimeout, }, - zerolog.Logger{}, - config.DefaultPluginTimeout, ) api := API{ @@ -234,45 +257,54 @@ func TestGetServers(t *testing.T) { proxy := network.NewProxy( context.TODO(), - newPool, - nil, - config.DefaultHealthCheckPeriod, - &config.Client{ - Network: config.DefaultNetwork, - Address: config.DefaultAddress, + network.Proxy{ + AvailableConnections: newPool, + HealthCheckPeriod: config.DefaultHealthCheckPeriod, + ClientConfig: &config.Client{ + Network: config.DefaultNetwork, + Address: config.DefaultAddress, + }, + Logger: zerolog.Logger{}, + PluginTimeout: config.DefaultPluginTimeout, }, - zerolog.Logger{}, - config.DefaultPluginTimeout, ) actRegistry := act.NewActRegistry( - act.BuiltinSignals(), act.BuiltinPolicies(), act.BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, zerolog.Logger{}) + act.Registry{ + Signals: act.BuiltinSignals(), + Policies: act.BuiltinPolicies(), + Actions: act.BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: zerolog.Logger{}, + }) pluginRegistry := plugin.NewRegistry( context.TODO(), - actRegistry, - config.Loose, - zerolog.Logger{}, - true, + plugin.Registry{ + ActRegistry: actRegistry, + Compatibility: config.Loose, + Logger: zerolog.Logger{}, + DevMode: true, + }, ) server := network.NewServer( context.TODO(), - config.DefaultNetwork, - config.DefaultAddress, - config.DefaultTickInterval, - network.Option{ - EnableTicker: false, + network.Server{ + Network: config.DefaultNetwork, + Address: config.DefaultAddress, + TickInterval: config.DefaultTickInterval, + Options: network.Option{ + EnableTicker: false, + }, + Proxy: proxy, + Logger: zerolog.Logger{}, + PluginRegistry: pluginRegistry, + PluginTimeout: config.DefaultPluginTimeout, + HandshakeTimeout: config.DefaultHandshakeTimeout, }, - proxy, - zerolog.Logger{}, - pluginRegistry, - config.DefaultPluginTimeout, - false, - "", - "", - config.DefaultHandshakeTimeout, ) api := API{ diff --git a/api/grpc_server.go b/api/grpc_server.go index 5f641f17..ffab79b4 100644 --- a/api/grpc_server.go +++ b/api/grpc_server.go @@ -11,29 +11,31 @@ import ( ) type GRPCServer struct { - api *API + API *API grpcServer *grpc.Server listener net.Listener + *HealthChecker } // NewGRPCServer creates a new gRPC server. -func NewGRPCServer(api *API, healthchecker *HealthChecker) *GRPCServer { - grpcServer, listener := createGRPCAPI(api, healthchecker) +func NewGRPCServer(server GRPCServer) *GRPCServer { + grpcServer, listener := createGRPCAPI(server.API, server.HealthChecker) if grpcServer == nil || listener == nil { - api.Options.Logger.Error().Msg("Failed to create gRPC API server and listener") + server.API.Options.Logger.Error().Msg("Failed to create gRPC API server and listener") return nil } return &GRPCServer{ - api: api, - grpcServer: grpcServer, - listener: listener, + API: server.API, + grpcServer: grpcServer, + listener: listener, + HealthChecker: server.HealthChecker, } } // Start starts the gRPC server. func (s *GRPCServer) Start() { - s.start(s.api, s.grpcServer, s.listener) + s.start(s.API, s.grpcServer, s.listener) } // Shutdown shuts down the gRPC server. diff --git a/cmd/configs.go b/cmd/configs.go index 93f9405d..ef867725 100644 --- a/cmd/configs.go +++ b/cmd/configs.go @@ -71,12 +71,12 @@ func lintConfig(fileType configFileType, configFile string) error { var conf *config.Config switch fileType { case Global: - conf = config.NewConfig(context.TODO(), configFile, "") + conf = config.NewConfig(context.TODO(), config.Config{GlobalConfigFile: configFile}) conf.LoadDefaults(context.TODO()) conf.LoadGlobalConfigFile(context.TODO()) conf.UnmarshalGlobalConfig(context.TODO()) case Plugins: - conf = config.NewConfig(context.TODO(), "", configFile) + conf = config.NewConfig(context.TODO(), config.Config{PluginConfigFile: configFile}) conf.LoadDefaults(context.TODO()) conf.LoadPluginConfigFile(context.TODO()) conf.UnmarshalPluginConfig(context.TODO()) diff --git a/cmd/plugin_list.go b/cmd/plugin_list.go index 60e202b6..0384de91 100644 --- a/cmd/plugin_list.go +++ b/cmd/plugin_list.go @@ -56,7 +56,7 @@ func init() { func listPlugins(cmd *cobra.Command, pluginConfigFile string, onlyEnabled bool) { // Load the plugin config file. - conf := config.NewConfig(context.TODO(), "", pluginConfigFile) + conf := config.NewConfig(context.TODO(), config.Config{PluginConfigFile: pluginConfigFile}) conf.LoadDefaults(context.TODO()) conf.LoadPluginConfigFile(context.TODO()) conf.UnmarshalPluginConfig(context.TODO()) diff --git a/cmd/run.go b/cmd/run.go index 9d14be74..c788c0e9 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -234,7 +234,7 @@ var runCmd = &cobra.Command{ } // Load global and plugin configuration. - conf = config.NewConfig(runCtx, globalConfigFile, pluginConfigFile) + conf = config.NewConfig(runCtx, config.Config{GlobalConfigFile: globalConfigFile, PluginConfigFile: pluginConfigFile}) conf.InitConfig(runCtx) // Create and initialize loggers from the config. @@ -284,9 +284,15 @@ var runCmd = &cobra.Command{ // Create a new act registry given the built-in signals, policies, and actions. actRegistry = act.NewActRegistry( - act.BuiltinSignals(), act.BuiltinPolicies(), act.BuiltinActions(), - conf.Plugin.DefaultPolicy, conf.Plugin.PolicyTimeout, conf.Plugin.ActionTimeout, logger, - ) + act.Registry{ + Signals: act.BuiltinSignals(), + Policies: act.BuiltinPolicies(), + Actions: act.BuiltinActions(), + DefaultPolicyName: conf.Plugin.DefaultPolicy, + PolicyTimeout: conf.Plugin.PolicyTimeout, + DefaultActionTimeout: conf.Plugin.ActionTimeout, + Logger: logger, + }) if actRegistry == nil { logger.Error().Msg("Failed to create act registry") @@ -312,15 +318,17 @@ var runCmd = &cobra.Command{ // The plugins are loaded and hooks registered before the configuration is loaded. pluginRegistry = plugin.NewRegistry( runCtx, - actRegistry, - config.If( - config.Exists( - config.CompatibilityPolicies, conf.Plugin.CompatibilityPolicy, - ), - config.CompatibilityPolicies[conf.Plugin.CompatibilityPolicy], - config.DefaultCompatibilityPolicy), - logger, - devMode, + plugin.Registry{ + ActRegistry: actRegistry, + Compatibility: config.If( + config.Exists( + config.CompatibilityPolicies, conf.Plugin.CompatibilityPolicy, + ), + config.CompatibilityPolicies[conf.Plugin.CompatibilityPolicy], + config.DefaultCompatibilityPolicy), + Logger: logger, + DevMode: devMode, + }, ) // Load plugins and register their hooks. @@ -329,7 +337,10 @@ var runCmd = &cobra.Command{ // Start the metrics merger if enabled. var metricsMerger *metrics.Merger if conf.Plugin.EnableMetricsMerger { - metricsMerger = metrics.NewMerger(runCtx, conf.Plugin.MetricsMergerPeriod, logger) + metricsMerger = metrics.NewMerger(runCtx, metrics.Merger{ + MetricsMergerPeriod: conf.Plugin.MetricsMergerPeriod, + Logger: logger, + }) pluginRegistry.ForEach(func(_ sdkPlugin.Identifier, plugin *plugin.Plugin) { if metricsEnabled, err := strconv.ParseBool(plugin.Config["metricsEnabled"]); err == nil && metricsEnabled { metricsMerger.Add(plugin.ID.Name, plugin.Config["metricsUnixDomainSocket"]) @@ -655,15 +666,17 @@ var runCmd = &cobra.Command{ client := network.NewClient( runCtx, clientConfig, logger, network.NewRetry( - clientConfig.Retries, - config.If( - clientConfig.Backoff > 0, - clientConfig.Backoff, - config.DefaultBackoff, - ), - clientConfig.BackoffMultiplier, - clientConfig.DisableBackoffCaps, - loggers[name], + network.Retry{ + Retries: clientConfig.Retries, + Backoff: config.If( + clientConfig.Backoff > 0, + clientConfig.Backoff, + config.DefaultBackoff, + ), + BackoffMultiplier: clientConfig.BackoffMultiplier, + DisableBackoffCaps: clientConfig.DisableBackoffCaps, + Logger: loggers[name], + }, ), ) @@ -797,12 +810,14 @@ var runCmd = &cobra.Command{ proxies[name] = network.NewProxy( runCtx, - pools[name], - pluginRegistry, - cfg.HealthCheckPeriod, - clientConfig, - logger, - conf.Plugin.Timeout, + network.Proxy{ + AvailableConnections: pools[name], + PluginRegistry: pluginRegistry, + HealthCheckPeriod: cfg.HealthCheckPeriod, + ClientConfig: clientConfig, + Logger: logger, + PluginTimeout: conf.Plugin.Timeout, + }, ) span.AddEvent("Create proxy", trace.WithAttributes( @@ -834,25 +849,27 @@ var runCmd = &cobra.Command{ logger := loggers[name] servers[name] = network.NewServer( runCtx, - cfg.Network, - cfg.Address, - config.If( - cfg.TickInterval > 0, - cfg.TickInterval, - config.DefaultTickInterval, - ), - network.Option{ - // Can be used to send keepalive messages to the client. - EnableTicker: cfg.EnableTicker, + network.Server{ + Network: cfg.Network, + Address: cfg.Address, + TickInterval: config.If( + cfg.TickInterval > 0, + cfg.TickInterval, + config.DefaultTickInterval, + ), + Options: network.Option{ + // Can be used to send keepalive messages to the client. + EnableTicker: cfg.EnableTicker, + }, + Proxy: proxies[name], + Logger: logger, + PluginRegistry: pluginRegistry, + PluginTimeout: conf.Plugin.Timeout, + EnableTLS: cfg.EnableTLS, + CertFile: cfg.CertFile, + KeyFile: cfg.KeyFile, + HandshakeTimeout: cfg.HandshakeTimeout, }, - proxies[name], - logger, - pluginRegistry, - conf.Plugin.Timeout, - cfg.EnableTLS, - cfg.CertFile, - cfg.KeyFile, - cfg.HandshakeTimeout, ) span.AddEvent("Create server", trace.WithAttributes( @@ -903,7 +920,10 @@ var runCmd = &cobra.Command{ Proxies: proxies, Servers: servers, } - grpcServer = api.NewGRPCServer(apiObj, &api.HealthChecker{Servers: servers}) + grpcServer = api.NewGRPCServer(api.GRPCServer{ + API: apiObj, + HealthChecker: &api.HealthChecker{Servers: servers}, + }) if grpcServer != nil { go grpcServer.Start() logger.Info().Str("address", apiOptions.HTTPAddress).Msg("Started the HTTP API") diff --git a/config/config.go b/config/config.go index 1cf060a1..a38af813 100644 --- a/config/config.go +++ b/config/config.go @@ -36,8 +36,8 @@ type Config struct { globalDefaults GlobalConfig pluginDefaults PluginConfig - globalConfigFile string - pluginConfigFile string + GlobalConfigFile string + PluginConfigFile string GlobalKoanf *koanf.Koanf PluginKoanf *koanf.Koanf @@ -48,19 +48,19 @@ type Config struct { var _ IConfig = (*Config)(nil) -func NewConfig(ctx context.Context, globalConfigFile, pluginConfigFile string) *Config { +func NewConfig(ctx context.Context, config Config) *Config { _, span := otel.Tracer(TracerName).Start(ctx, "Create new config") defer span.End() - span.SetAttributes(attribute.String("globalConfigFile", globalConfigFile)) - span.SetAttributes(attribute.String("pluginConfigFile", pluginConfigFile)) + span.SetAttributes(attribute.String("GlobalConfigFile", config.GlobalConfigFile)) + span.SetAttributes(attribute.String("PluginConfigFile", config.PluginConfigFile)) return &Config{ GlobalKoanf: koanf.New("."), PluginKoanf: koanf.New("."), globalDefaults: GlobalConfig{}, pluginDefaults: PluginConfig{}, - globalConfigFile: globalConfigFile, - pluginConfigFile: pluginConfigFile, + GlobalConfigFile: config.GlobalConfigFile, + PluginConfigFile: config.PluginConfigFile, } } @@ -159,7 +159,7 @@ func (c *Config) LoadDefaults(ctx context.Context) { } //nolint:nestif - if contents, err := os.ReadFile(c.globalConfigFile); err == nil { + if contents, err := os.ReadFile(c.GlobalConfigFile); err == nil { gconf, err := yaml.Parser().Unmarshal(contents) if err != nil { span.RecordError(err) @@ -275,7 +275,7 @@ func loadEnvVars() *env.Env { func (c *Config) LoadGlobalConfigFile(ctx context.Context) { _, span := otel.Tracer(TracerName).Start(ctx, "Load global config file") - if err := c.GlobalKoanf.Load(file.Provider(c.globalConfigFile), yaml.Parser()); err != nil { + if err := c.GlobalKoanf.Load(file.Provider(c.GlobalConfigFile), yaml.Parser()); err != nil { span.RecordError(err) span.End() log.Fatal(fmt.Errorf("failed to load global configuration: %w", err)) @@ -288,7 +288,7 @@ func (c *Config) LoadGlobalConfigFile(ctx context.Context) { func (c *Config) LoadPluginConfigFile(ctx context.Context) { _, span := otel.Tracer(TracerName).Start(ctx, "Load plugin config file") - if err := c.PluginKoanf.Load(file.Provider(c.pluginConfigFile), yaml.Parser()); err != nil { + if err := c.PluginKoanf.Load(file.Provider(c.PluginConfigFile), yaml.Parser()); err != nil { span.RecordError(err) span.End() log.Fatal(fmt.Errorf("failed to load plugin configuration: %w", err)) diff --git a/config/config_test.go b/config/config_test.go index d3e15bca..b89a9212 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -13,10 +13,10 @@ var parentDir = "../" // TestNewConfig tests the NewConfig function. func TestNewConfig(t *testing.T) { config := NewConfig( - context.Background(), GlobalConfigFilename, PluginsConfigFilename) + context.Background(), Config{GlobalConfigFile: GlobalConfigFilename, PluginConfigFile: PluginsConfigFilename}) assert.NotNil(t, config) - assert.Equal(t, GlobalConfigFilename, config.globalConfigFile) - assert.Equal(t, PluginsConfigFilename, config.pluginConfigFile) + assert.Equal(t, GlobalConfigFilename, config.GlobalConfigFile) + assert.Equal(t, PluginsConfigFilename, config.PluginConfigFile) assert.Equal(t, GlobalConfig{}, config.globalDefaults) assert.Equal(t, PluginConfig{}, config.pluginDefaults) assert.Equal(t, GlobalConfig{}, config.Global) @@ -29,7 +29,8 @@ func TestNewConfig(t *testing.T) { // the other functions. func TestInitConfig(t *testing.T) { ctx := context.Background() - config := NewConfig(ctx, parentDir+GlobalConfigFilename, parentDir+PluginsConfigFilename) + config := NewConfig(ctx, + Config{GlobalConfigFile: parentDir + GlobalConfigFilename, PluginConfigFile: parentDir + PluginsConfigFilename}) config.InitConfig(ctx) assert.NotNil(t, config.Global) assert.NotEqual(t, GlobalConfig{}, config.Global) @@ -53,7 +54,8 @@ func TestInitConfig(t *testing.T) { // TestMergeGlobalConfig tests the MergeGlobalConfig function. func TestMergeGlobalConfig(t *testing.T) { ctx := context.Background() - config := NewConfig(ctx, parentDir+GlobalConfigFilename, parentDir+PluginsConfigFilename) + config := NewConfig(ctx, + Config{GlobalConfigFile: parentDir + GlobalConfigFilename, PluginConfigFile: parentDir + PluginsConfigFilename}) config.InitConfig(ctx) // The default log level is info. assert.Equal(t, DefaultLogLevel, config.Global.Loggers[Default].Level) diff --git a/config/getters_test.go b/config/getters_test.go index 0d81df36..f53a28a1 100644 --- a/config/getters_test.go +++ b/config/getters_test.go @@ -28,7 +28,8 @@ func TestGetDefaultConfigFilePath(t *testing.T) { // TestFilter tests the Filter function. func TestFilter(t *testing.T) { // Load config from the default config file. - conf := NewConfig(context.TODO(), "../gatewayd.yaml", "../gatewayd_plugins.yaml") + conf := NewConfig(context.TODO(), + Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"}) conf.InitConfig(context.TODO()) assert.NotEmpty(t, conf.Global) diff --git a/errors/errors.go b/errors/errors.go index 278d0284..b2f2a214 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -51,90 +51,127 @@ const ( ) var ( - ErrClientNotFound = NewGatewayDError( - ErrCodeClientNotFound, "client not found", nil) - ErrNilContext = NewGatewayDError( - ErrCodeNilContext, "context is nil", nil) - ErrClientNotConnected = NewGatewayDError( - ErrCodeClientNotConnected, "client is not connected", nil) - ErrClientConnectionFailed = NewGatewayDError( - ErrCodeClientConnectionFailed, "failed to create a new connection", nil) - ErrNetworkNotSupported = NewGatewayDError( - ErrCodeNetworkNotSupported, "network is not supported", nil) - ErrResolveFailed = NewGatewayDError( - ErrCodeResolveFailed, "failed to resolve address", nil) - ErrPoolExhausted = NewGatewayDError( - ErrCodePoolExhausted, "pool is exhausted", nil) + ErrClientNotFound = &GatewayDError{ + ErrCodeClientNotFound, "client not found", nil, + } + ErrNilContext = &GatewayDError{ + ErrCodeNilContext, "context is nil", nil, + } + ErrClientNotConnected = &GatewayDError{ + ErrCodeClientNotConnected, "client is not connected", nil, + } + ErrClientConnectionFailed = &GatewayDError{ + ErrCodeClientConnectionFailed, "failed to create a new connection", nil, + } + ErrNetworkNotSupported = &GatewayDError{ + ErrCodeNetworkNotSupported, "network is not supported", nil, + } + ErrResolveFailed = &GatewayDError{ + ErrCodeResolveFailed, "failed to resolve address", nil, + } + ErrPoolExhausted = &GatewayDError{ + ErrCodePoolExhausted, "pool is exhausted", nil, + } - ErrPluginNotReady = NewGatewayDError( - ErrCodePluginNotReady, "plugin is not ready", nil) - ErrFailedToStartPlugin = NewGatewayDError( - ErrCodeStartPluginFailed, "failed to start plugin", nil) - ErrFailedToGetRPCClient = NewGatewayDError( - ErrCodeGetRPCClientFailed, "failed to get RPC client", nil) - ErrFailedToDispensePlugin = NewGatewayDError( - ErrCodeDispensePluginFailed, "failed to dispense plugin", nil) - ErrFailedToMergePluginMetrics = NewGatewayDError( - ErrCodePluginMetricsMergeFailed, "failed to merge plugin metrics", nil) - ErrFailedToPingPlugin = NewGatewayDError( - ErrCodePluginPingFailed, "failed to ping plugin", nil) + ErrPluginNotReady = &GatewayDError{ + ErrCodePluginNotReady, "plugin is not ready", nil, + } + ErrFailedToStartPlugin = &GatewayDError{ + ErrCodeStartPluginFailed, "failed to start plugin", nil, + } + ErrFailedToGetRPCClient = &GatewayDError{ + ErrCodeGetRPCClientFailed, "failed to get RPC client", nil, + } + ErrFailedToDispensePlugin = &GatewayDError{ + ErrCodeDispensePluginFailed, "failed to dispense plugin", nil, + } + ErrFailedToMergePluginMetrics = &GatewayDError{ + ErrCodePluginMetricsMergeFailed, "failed to merge plugin metrics", nil, + } + ErrFailedToPingPlugin = &GatewayDError{ + ErrCodePluginPingFailed, "failed to ping plugin", nil, + } - ErrClientReceiveFailed = NewGatewayDError( - ErrCodeClientReceiveFailed, "couldn't receive data from the server", nil) - ErrClientSendFailed = NewGatewayDError( - ErrCodeClientSendFailed, "couldn't send data to the server", nil) + ErrClientReceiveFailed = &GatewayDError{ + ErrCodeClientReceiveFailed, "couldn't receive data from the server", nil, + } + ErrClientSendFailed = &GatewayDError{ + ErrCodeClientSendFailed, "couldn't send data to the server", nil, + } - ErrServerSendFailed = NewGatewayDError( - ErrCodeServerSendFailed, "couldn't send data to the client", nil) - ErrServerListenFailed = NewGatewayDError( - ErrCodeServerListenFailed, "couldn't listen on the server", nil) - ErrSplitHostPortFailed = NewGatewayDError( - ErrCodeSplitHostPortFailed, "failed to split host:port", nil) - ErrAcceptFailed = NewGatewayDError( - ErrCodeAcceptFailed, "failed to accept connection", nil) - ErrGetTLSConfigFailed = NewGatewayDError( - ErrCodeGetTLSConfigFailed, "failed to get TLS config", nil) - ErrUpgradeToTLSFailed = NewGatewayDError( - ErrCodeUpgradeToTLSFailed, "failed to upgrade to TLS", nil) + ErrServerSendFailed = &GatewayDError{ + ErrCodeServerSendFailed, "couldn't send data to the client", nil, + } + ErrServerListenFailed = &GatewayDError{ + ErrCodeServerListenFailed, "couldn't listen on the server", nil, + } + ErrSplitHostPortFailed = &GatewayDError{ + ErrCodeSplitHostPortFailed, "failed to split host:port", nil, + } + ErrAcceptFailed = &GatewayDError{ + ErrCodeAcceptFailed, "failed to accept connection", nil, + } + ErrGetTLSConfigFailed = &GatewayDError{ + ErrCodeGetTLSConfigFailed, "failed to get TLS config", nil, + } + ErrUpgradeToTLSFailed = &GatewayDError{ + ErrCodeUpgradeToTLSFailed, "failed to upgrade to TLS", nil, + } - ErrReadFailed = NewGatewayDError( - ErrCodeReadFailed, "failed to read from the client", nil) + ErrReadFailed = &GatewayDError{ + ErrCodeReadFailed, "failed to read from the client", nil, + } - ErrNilPointer = NewGatewayDError( - ErrCodeNilPointer, "nil pointer", nil) + ErrNilPointer = &GatewayDError{ + ErrCodeNilPointer, "nil pointer", nil, + } - ErrCastFailed = NewGatewayDError( - ErrCodeCastFailed, "failed to cast", nil) + ErrCastFailed = &GatewayDError{ + ErrCodeCastFailed, "failed to cast", nil, + } - ErrHookTerminatedConnection = NewGatewayDError( - ErrCodeHookTerminatedConnection, "hook terminated connection", nil) + ErrHookTerminatedConnection = &GatewayDError{ + ErrCodeHookTerminatedConnection, "hook terminated connection", nil, + } - ErrValidationFailed = NewGatewayDError( - ErrCodeValidationFailed, "validation failed", nil) - ErrLintingFailed = NewGatewayDError( - ErrCodeLintingFailed, "linting failed", nil) + ErrValidationFailed = &GatewayDError{ + ErrCodeValidationFailed, "validation failed", nil, + } + ErrLintingFailed = &GatewayDError{ + ErrCodeLintingFailed, "linting failed", nil, + } - ErrExtractFailed = NewGatewayDError( - ErrCodeExtractFailed, "failed to extract the archive", nil) - ErrDownloadFailed = NewGatewayDError( - ErrCodeDownloadFailed, "failed to download the file", nil) + ErrExtractFailed = &GatewayDError{ + ErrCodeExtractFailed, "failed to extract the archive", nil, + } + ErrDownloadFailed = &GatewayDError{ + ErrCodeDownloadFailed, "failed to download the file", nil, + } - ErrActionNotExist = NewGatewayDError( - ErrCodeKeyNotFound, "action does not exist", nil) - ErrRunningAction = NewGatewayDError( - ErrCodeRunError, "error running action", nil) - ErrAsyncAction = NewGatewayDError( - ErrCodeAsyncAction, "async action", nil) - ErrRunningActionTimeout = NewGatewayDError( - ErrCodeRunError, "timeout running action", nil) - ErrActionNotMatched = NewGatewayDError( - ErrCodeKeyNotFound, "no matching action", nil) - ErrPolicyNotMatched = NewGatewayDError( - ErrCodeKeyNotFound, "no matching policy", nil) - ErrEvalError = NewGatewayDError( - ErrCodeEvalError, "error evaluating expression", nil) - ErrMsgEncodeError = NewGatewayDError( - ErrCodeMsgEncodeError, "error encoding message", nil) + ErrActionNotExist = &GatewayDError{ + ErrCodeKeyNotFound, "action does not exist", nil, + } + ErrRunningAction = &GatewayDError{ + ErrCodeRunError, "error running action", nil, + } + ErrAsyncAction = &GatewayDError{ + ErrCodeAsyncAction, "async action", nil, + } + ErrRunningActionTimeout = &GatewayDError{ + ErrCodeRunError, "timeout running action", nil, + } + ErrActionNotMatched = &GatewayDError{ + ErrCodeKeyNotFound, "no matching action", nil, + } + ErrPolicyNotMatched = &GatewayDError{ + ErrCodeKeyNotFound, "no matching policy", nil, + } + ErrEvalError = &GatewayDError{ + ErrCodeEvalError, "error evaluating expression", nil, + } + ErrMsgEncodeError = &GatewayDError{ + ErrCodeMsgEncodeError, "error encoding message", nil, + } // Unwrapped errors. ErrLoggerRequired = errors.New("terminate action requires a logger parameter") diff --git a/errors/gatewayd_error.go b/errors/gatewayd_error.go index 1048e4b4..7461d4e9 100644 --- a/errors/gatewayd_error.go +++ b/errors/gatewayd_error.go @@ -10,15 +10,6 @@ type GatewayDError struct { OriginalError error } -// NewGatewayDError creates a new GatewayDError. -func NewGatewayDError(code ErrCode, message string, err error) *GatewayDError { - return &GatewayDError{ - Code: code, - Message: message, - OriginalError: err, - } -} - // Error returns the error message of the GatewayDError. func (e *GatewayDError) Error() string { if e.OriginalError == nil { diff --git a/errors/gatewayd_error_test.go b/errors/gatewayd_error_test.go deleted file mode 100644 index 631eb574..00000000 --- a/errors/gatewayd_error_test.go +++ /dev/null @@ -1,24 +0,0 @@ -package errors - -import ( - "io" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// TestNewGatewayDError tests the creation of a new GatewayDError. -func TestNewGatewayDError(t *testing.T) { - err := NewGatewayDError(ErrCodeUnknown, "test", nil) - assert.NotNil(t, err) - assert.Equal(t, ErrCodeUnknown, err.Code) - assert.Equal(t, "test", err.Error()) - assert.Equal(t, "test", err.Message) - require.NoError(t, err.OriginalError) - - assert.NotNil(t, err.Wrap(io.EOF)) - assert.Equal(t, io.EOF, err.OriginalError) - assert.Equal(t, io.EOF, err.Unwrap()) - assert.Equal(t, "test, OriginalError: EOF", err.Error()) -} diff --git a/metrics/merger.go b/metrics/merger.go index 902b4e60..33c06364 100644 --- a/metrics/merger.go +++ b/metrics/merger.go @@ -48,7 +48,7 @@ var _ IMerger = (*Merger)(nil) // NewMerger creates a new metrics merger. func NewMerger( - ctx context.Context, metricsMergerPeriod time.Duration, logger zerolog.Logger, + ctx context.Context, merger Merger, ) *Merger { mergerCtx, span := otel.Tracer(config.TracerName).Start(ctx, "NewMerger") defer span.End() @@ -56,10 +56,10 @@ func NewMerger( return &Merger{ scheduler: gocron.NewScheduler(time.UTC), ctx: mergerCtx, - Logger: logger, + Logger: merger.Logger, Addresses: map[string]string{}, OutputMetrics: []byte{}, - MetricsMergerPeriod: metricsMergerPeriod, + MetricsMergerPeriod: merger.MetricsMergerPeriod, } } diff --git a/metrics/merger_test.go b/metrics/merger_test.go index a792ecdf..4cce31e9 100644 --- a/metrics/merger_test.go +++ b/metrics/merger_test.go @@ -28,7 +28,9 @@ func TestMerger(t *testing.T) { }, ) - merger := NewMerger(context.Background(), 1, logger) + merger := NewMerger(context.Background(), Merger{ + MetricsMergerPeriod: 1, Logger: logger, + }) merger.Add("test", "/tmp/test.sock") // We need to give the merger some time to read the metrics. diff --git a/network/conn_wrapper.go b/network/conn_wrapper.go index add7bf1e..b3c70e45 100644 --- a/network/conn_wrapper.go +++ b/network/conn_wrapper.go @@ -30,11 +30,11 @@ type IConnWrapper interface { } type ConnWrapper struct { - netConn net.Conn + NetConn net.Conn tlsConn *tls.Conn - tlsConfig *tls.Config + TLSConfig *tls.Config isTLSEnabled bool - handshakeTimeout time.Duration + HandshakeTimeout time.Duration } var _ IConnWrapper = (*ConnWrapper)(nil) @@ -44,7 +44,7 @@ func (cw *ConnWrapper) Conn() net.Conn { if cw.tlsConn != nil { return net.Conn(cw.tlsConn) } - return cw.netConn + return cw.NetConn } // UpgradeToTLS upgrades the connection to TLS. @@ -58,12 +58,12 @@ func (cw *ConnWrapper) UpgradeToTLS(upgrader UpgraderFunc) *gerr.GatewayDError { } if upgrader != nil { - upgrader(cw.netConn) + upgrader(cw.NetConn) } - tlsConn := tls.Server(cw.netConn, cw.tlsConfig) + tlsConn := tls.Server(cw.NetConn, cw.TLSConfig) - ctx, cancel := context.WithTimeout(context.Background(), cw.handshakeTimeout) + ctx, cancel := context.WithTimeout(context.Background(), cw.HandshakeTimeout) defer cancel() if err := tlsConn.HandshakeContext(ctx); err != nil { @@ -79,7 +79,7 @@ func (cw *ConnWrapper) Close() error { if cw.tlsConn != nil { return cw.tlsConn.Close() } - return cw.netConn.Close() + return cw.NetConn.Close() } // Write writes data to the connection. @@ -87,7 +87,7 @@ func (cw *ConnWrapper) Write(data []byte) (int, error) { if cw.tlsConn != nil { return cw.tlsConn.Write(data) } - return cw.netConn.Write(data) + return cw.NetConn.Write(data) } // Read reads data from the connection. @@ -95,7 +95,7 @@ func (cw *ConnWrapper) Read(data []byte) (int, error) { if cw.tlsConn != nil { return cw.tlsConn.Read(data) } - return cw.netConn.Read(data) + return cw.NetConn.Read(data) } // RemoteAddr returns the remote address. @@ -103,7 +103,7 @@ func (cw *ConnWrapper) RemoteAddr() net.Addr { if cw.tlsConn != nil { return cw.tlsConn.RemoteAddr() } - return cw.netConn.RemoteAddr() + return cw.NetConn.RemoteAddr() } // LocalAddr returns the local address. @@ -111,7 +111,7 @@ func (cw *ConnWrapper) LocalAddr() net.Addr { if cw.tlsConn != nil { return cw.tlsConn.LocalAddr() } - return cw.netConn.LocalAddr() + return cw.NetConn.LocalAddr() } // IsTLSEnabled returns true if TLS is enabled. @@ -122,13 +122,13 @@ func (cw *ConnWrapper) IsTLSEnabled() bool { // NewConnWrapper creates a new connection wrapper. The connection // wrapper is used to upgrade the connection to TLS if need be. func NewConnWrapper( - conn net.Conn, tlsConfig *tls.Config, handshakeTimeout time.Duration, + connWrapper ConnWrapper, ) *ConnWrapper { return &ConnWrapper{ - netConn: conn, - tlsConfig: tlsConfig, - isTLSEnabled: tlsConfig != nil && tlsConfig.Certificates != nil, - handshakeTimeout: handshakeTimeout, + NetConn: connWrapper.NetConn, + TLSConfig: connWrapper.TLSConfig, + isTLSEnabled: connWrapper.TLSConfig != nil && connWrapper.TLSConfig.Certificates != nil, + HandshakeTimeout: connWrapper.HandshakeTimeout, } } diff --git a/network/proxy.go b/network/proxy.go index 89e315dc..3a6f47d8 100644 --- a/network/proxy.go +++ b/network/proxy.go @@ -33,18 +33,18 @@ type IProxy interface { IsHealthy(cl *Client) (*Client, *gerr.GatewayDError) IsExhausted() bool Shutdown() - AvailableConnections() []string - BusyConnections() []string + AvailableConnectionsString() []string + BusyConnectionsString() []string } type Proxy struct { - availableConnections pool.IPool + AvailableConnections pool.IPool busyConnections pool.IPool - logger zerolog.Logger - pluginRegistry *plugin.Registry + Logger zerolog.Logger + PluginRegistry *plugin.Registry scheduler *gocron.Scheduler ctx context.Context //nolint:containedctx - pluginTimeout time.Duration + PluginTimeout time.Duration HealthCheckPeriod time.Duration // ClientConfig is used for reconnection @@ -56,24 +56,21 @@ var _ IProxy = (*Proxy)(nil) // NewProxy creates a new proxy. func NewProxy( ctx context.Context, - connPool pool.IPool, pluginRegistry *plugin.Registry, - healthCheckPeriod time.Duration, - clientConfig *config.Client, logger zerolog.Logger, - pluginTimeout time.Duration, + pxy Proxy, ) *Proxy { proxyCtx, span := otel.Tracer(config.TracerName).Start(ctx, "NewProxy") defer span.End() proxy := Proxy{ - availableConnections: connPool, + AvailableConnections: pxy.AvailableConnections, busyConnections: pool.NewPool(proxyCtx, config.EmptyPoolCapacity), - logger: logger, - pluginRegistry: pluginRegistry, + Logger: pxy.Logger, + PluginRegistry: pxy.PluginRegistry, scheduler: gocron.NewScheduler(time.UTC), ctx: proxyCtx, - pluginTimeout: pluginTimeout, - ClientConfig: clientConfig, - HealthCheckPeriod: healthCheckPeriod, + PluginTimeout: pxy.PluginTimeout, + ClientConfig: pxy.ClientConfig, + HealthCheckPeriod: pxy.HealthCheckPeriod, } startDelay := time.Now().Add(proxy.HealthCheckPeriod) @@ -81,52 +78,54 @@ func NewProxy( if _, err := proxy.scheduler.Every(proxy.HealthCheckPeriod).SingletonMode().StartAt(startDelay).Do( func() { now := time.Now() - logger.Trace().Msg("Running the client health check to recycle connection(s).") - proxy.availableConnections.ForEach(func(_, value interface{}) bool { + proxy.Logger.Trace().Msg("Running the client health check to recycle connection(s).") + proxy.AvailableConnections.ForEach(func(_, value interface{}) bool { if client, ok := value.(*Client); ok { // Connection is probably dead by now. - proxy.availableConnections.Remove(client.ID) + proxy.AvailableConnections.Remove(client.ID) client.Close() // Create a new client. client = NewClient( - proxyCtx, proxy.ClientConfig, proxy.logger, + proxyCtx, proxy.ClientConfig, proxy.Logger, NewRetry( - proxy.ClientConfig.Retries, - config.If( - proxy.ClientConfig.Backoff > 0, - proxy.ClientConfig.Backoff, - config.DefaultBackoff, - ), - proxy.ClientConfig.BackoffMultiplier, - proxy.ClientConfig.DisableBackoffCaps, - proxy.logger, + Retry{ + Retries: proxy.ClientConfig.Retries, + Backoff: config.If( + proxy.ClientConfig.Backoff > 0, + proxy.ClientConfig.Backoff, + config.DefaultBackoff, + ), + BackoffMultiplier: proxy.ClientConfig.BackoffMultiplier, + DisableBackoffCaps: proxy.ClientConfig.DisableBackoffCaps, + Logger: proxy.Logger, + }, ), ) if client != nil && client.ID != "" { - if err := proxy.availableConnections.Put(client.ID, client); err != nil { - proxy.logger.Err(err).Msg("Failed to update the client connection") + if err := proxy.AvailableConnections.Put(client.ID, client); err != nil { + proxy.Logger.Err(err).Msg("Failed to update the client connection") // Close the client, because we don't want to have orphaned connections. client.Close() } } else { - proxy.logger.Error().Msg("Failed to create a new client connection") + proxy.Logger.Error().Msg("Failed to create a new client connection") } } return true }) - logger.Trace().Str("duration", time.Since(now).String()).Msg( + proxy.Logger.Trace().Str("duration", time.Since(now).String()).Msg( "Finished the client health check") metrics.ProxyHealthChecks.Inc() }, ); err != nil { - proxy.logger.Error().Err(err).Msg("Failed to schedule the client health check") + proxy.Logger.Error().Err(err).Msg("Failed to schedule the client health check") sentry.CaptureException(err) span.RecordError(err) } // Start the scheduler. proxy.scheduler.StartAsync() - logger.Info().Fields( + proxy.Logger.Info().Fields( map[string]interface{}{ "startDelay": startDelay.Format(time.RFC3339), "healthCheckPeriod": proxy.HealthCheckPeriod.String(), @@ -144,7 +143,7 @@ func (pr *Proxy) Connect(conn *ConnWrapper) *gerr.GatewayDError { var clientID string // Get the first available client from the pool. - pr.availableConnections.ForEach(func(key, _ interface{}) bool { + pr.AvailableConnections.ForEach(func(key, _ interface{}) bool { if cid, ok := key.(string); ok { clientID = cid return false // stop the loop. @@ -159,13 +158,13 @@ func (pr *Proxy) Connect(conn *ConnWrapper) *gerr.GatewayDError { return gerr.ErrPoolExhausted } // Get the client from the pool with the given clientID. - if cl, ok := pr.availableConnections.Pop(clientID).(*Client); ok { + if cl, ok := pr.AvailableConnections.Pop(clientID).(*Client); ok { client = cl } client, err := pr.IsHealthy(client) if err != nil { - pr.logger.Error().Err(err).Msg("Failed to connect to the client") + pr.Logger.Error().Err(err).Msg("Failed to connect to the client") span.RecordError(err) } @@ -185,15 +184,15 @@ func (pr *Proxy) Connect(conn *ConnWrapper) *gerr.GatewayDError { if client.ID != "" { fields["client"] = client.ID[:7] } - pr.logger.Debug().Fields(fields).Msg("Client has been assigned") + pr.Logger.Debug().Fields(fields).Msg("Client has been assigned") - pr.logger.Debug().Fields( + pr.Logger.Debug().Fields( map[string]interface{}{ "function": "proxy.connect", - "count": pr.availableConnections.Size(), + "count": pr.AvailableConnections.Size(), }, ).Msg("Available client connections") - pr.logger.Debug().Fields( + pr.Logger.Debug().Fields( map[string]interface{}{ "function": "proxy.connect", "count": pr.busyConnections.Size(), @@ -213,7 +212,7 @@ func (pr *Proxy) Disconnect(conn *ConnWrapper) *gerr.GatewayDError { if client == nil { // If this ever happens, it means that the client connection // is pre-empted from the busy connections pool. - pr.logger.Debug().Msg("Client connection is pre-empted from the busy connections pool") + pr.Logger.Debug().Msg("Client connection is pre-empted from the busy connections pool") span.RecordError(gerr.ErrClientNotFound) return gerr.ErrClientNotFound } @@ -221,32 +220,32 @@ func (pr *Proxy) Disconnect(conn *ConnWrapper) *gerr.GatewayDError { if client, ok := client.(*Client); ok { // Recycle the server connection by reconnecting. if err := client.Reconnect(); err != nil { - pr.logger.Error().Err(err).Msg("Failed to reconnect to the client") + pr.Logger.Error().Err(err).Msg("Failed to reconnect to the client") span.RecordError(err) } // If the client is not in the pool, put it back. - if err := pr.availableConnections.Put(client.ID, client); err != nil { - pr.logger.Error().Err(err).Msg("Failed to put the client back in the pool") + if err := pr.AvailableConnections.Put(client.ID, client); err != nil { + pr.Logger.Error().Err(err).Msg("Failed to put the client back in the pool") span.RecordError(err) } } else { // This should never happen, but if it does, // then there are some serious issues with the pool. - pr.logger.Error().Msg("Failed to cast the client to the Client type") + pr.Logger.Error().Msg("Failed to cast the client to the Client type") span.RecordError(gerr.ErrCastFailed) return gerr.ErrCastFailed } metrics.ProxiedConnections.Dec() - pr.logger.Debug().Fields( + pr.Logger.Debug().Fields( map[string]interface{}{ "function": "proxy.disconnect", - "count": pr.availableConnections.Size(), + "count": pr.AvailableConnections.Size(), }, ).Msg("Available client connections") - pr.logger.Debug().Fields( + pr.Logger.Debug().Fields( map[string]interface{}{ "function": "proxy.disconnect", "count": pr.busyConnections.Size(), @@ -286,10 +285,10 @@ func (pr *Proxy) PassThroughToServer(conn *ConnWrapper, stack *Stack) *gerr.Gate span.AddEvent("Received traffic from client") // Run the OnTrafficFromClient hooks. - pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), pr.pluginTimeout) + pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), pr.PluginTimeout) defer cancel() - result, err := pr.pluginRegistry.Run( + result, err := pr.PluginRegistry.Run( pluginTimeoutCtx, trafficData( conn.Conn(), @@ -303,7 +302,7 @@ func (pr *Proxy) PassThroughToServer(conn *ConnWrapper, stack *Stack) *gerr.Gate origErr), v1.HookName_HOOK_NAME_ON_TRAFFIC_FROM_CLIENT) if err != nil { - pr.logger.Error().Err(err).Msg("Error running hook") + pr.Logger.Error().Err(err).Msg("Error running hook") span.RecordError(err) } span.AddEvent("Ran the OnTrafficFromClient hooks") @@ -322,10 +321,10 @@ func (pr *Proxy) PassThroughToServer(conn *ConnWrapper, stack *Stack) *gerr.Gate // Acknowledge the SSL request: // https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-SSL if sent, err := conn.Write([]byte{'S'}); err != nil { - pr.logger.Error().Err(err).Msg("Failed to acknowledge the SSL request") + pr.Logger.Error().Err(err).Msg("Failed to acknowledge the SSL request") span.RecordError(err) } else { - pr.logger.Debug().Fields( + pr.Logger.Debug().Fields( map[string]interface{}{ "function": "upgradeToTLS", "local": LocalAddr(conn.Conn()), @@ -335,13 +334,13 @@ func (pr *Proxy) PassThroughToServer(conn *ConnWrapper, stack *Stack) *gerr.Gate ).Msg("Sent data to database") } }); err != nil { - pr.logger.Error().Err(err).Msg("Failed to perform the TLS handshake") + pr.Logger.Error().Err(err).Msg("Failed to perform the TLS handshake") span.RecordError(err) } // Check if the TLS handshake was successful. if conn.IsTLSEnabled() { - pr.logger.Debug().Fields( + pr.Logger.Debug().Fields( map[string]interface{}{ "local": LocalAddr(conn.Conn()), "remote": RemoteAddr(conn.Conn()), @@ -350,7 +349,7 @@ func (pr *Proxy) PassThroughToServer(conn *ConnWrapper, stack *Stack) *gerr.Gate span.AddEvent("Performed the TLS handshake") metrics.TLSConnections.Inc() } else { - pr.logger.Error().Fields( + pr.Logger.Error().Fields( map[string]interface{}{ "local": LocalAddr(conn.Conn()), "remote": RemoteAddr(conn.Conn()), @@ -365,7 +364,7 @@ func (pr *Proxy) PassThroughToServer(conn *ConnWrapper, stack *Stack) *gerr.Gate } else if !conn.IsTLSEnabled() && IsPostgresSSLRequest(request) { // Client sent a SSL request, but the server does not support SSL. - pr.logger.Warn().Fields( + pr.Logger.Warn().Fields( map[string]interface{}{ "local": LocalAddr(conn.Conn()), "remote": RemoteAddr(conn.Conn()), @@ -377,7 +376,7 @@ func (pr *Proxy) PassThroughToServer(conn *ConnWrapper, stack *Stack) *gerr.Gate // so we need to switch to a plaintext connection: // https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-SSL if _, err := conn.Write([]byte{'N'}); err != nil { - pr.logger.Warn().Err(err).Msg("Server does not support SSL, but SSL was required by the client") + pr.Logger.Warn().Err(err).Msg("Server does not support SSL, but SSL was required by the client") span.RecordError(err) } @@ -392,7 +391,7 @@ func (pr *Proxy) PassThroughToServer(conn *ConnWrapper, stack *Stack) *gerr.Gate // If the hook wants to terminate the connection, do it. if terminate, resp := pr.shouldTerminate(result); terminate { if resp != nil { - pr.logger.Trace().Fields( + pr.Logger.Trace().Fields( map[string]interface{}{ "function": "proxy.passthrough", "result": resp, @@ -431,11 +430,11 @@ func (pr *Proxy) PassThroughToServer(conn *ConnWrapper, stack *Stack) *gerr.Gate _, err = pr.sendTrafficToServer(client, request) span.AddEvent("Sent traffic to server") - pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), pr.pluginTimeout) + pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), pr.PluginTimeout) defer cancel() // Run the OnTrafficToServer hooks. - _, err = pr.pluginRegistry.Run( + _, err = pr.PluginRegistry.Run( pluginTimeoutCtx, trafficData( conn.Conn(), @@ -449,7 +448,7 @@ func (pr *Proxy) PassThroughToServer(conn *ConnWrapper, stack *Stack) *gerr.Gate err), v1.HookName_HOOK_NAME_ON_TRAFFIC_TO_SERVER) if err != nil { - pr.logger.Error().Err(err).Msg("Error running hook") + pr.Logger.Error().Err(err).Msg("Error running hook") span.RecordError(err) } span.AddEvent("Ran the OnTrafficToServer hooks") @@ -497,7 +496,7 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate if client.RemoteAddr() != "" { fields["remoteAddr"] = client.RemoteAddr() } - pr.logger.Debug().Fields(fields).Msg("No data to send to client") + pr.Logger.Debug().Fields(fields).Msg("No data to send to client") span.AddEvent("No data to send to client") span.RecordError(err) @@ -506,7 +505,7 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate return err } - pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), pr.pluginTimeout) + pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), pr.PluginTimeout) defer cancel() // Get the last request from the stack. @@ -517,7 +516,7 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate } // Run the OnTrafficFromServer hooks. - result, err := pr.pluginRegistry.Run( + result, err := pr.PluginRegistry.Run( pluginTimeoutCtx, trafficData( conn.Conn(), @@ -535,7 +534,7 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate err), v1.HookName_HOOK_NAME_ON_TRAFFIC_FROM_SERVER) if err != nil { - pr.logger.Error().Err(err).Msg("Error running hook") + pr.Logger.Error().Err(err).Msg("Error running hook") span.RecordError(err) } span.AddEvent("Ran the OnTrafficFromServer hooks") @@ -552,10 +551,10 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate span.AddEvent("Sent traffic to client") // Run the OnTrafficToClient hooks. - pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), pr.pluginTimeout) + pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), pr.PluginTimeout) defer cancel() - _, err = pr.pluginRegistry.Run( + _, err = pr.PluginRegistry.Run( pluginTimeoutCtx, trafficData( conn.Conn(), @@ -574,7 +573,7 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate ), v1.HookName_HOOK_NAME_ON_TRAFFIC_TO_CLIENT) if err != nil { - pr.logger.Error().Err(err).Msg("Error running hook") + pr.Logger.Error().Err(err).Msg("Error running hook") span.RecordError(err) } @@ -593,13 +592,13 @@ func (pr *Proxy) IsHealthy(client *Client) (*Client, *gerr.GatewayDError) { defer span.End() if pr.IsExhausted() { - pr.logger.Error().Msg("No more available connections") + pr.Logger.Error().Msg("No more available connections") span.RecordError(gerr.ErrPoolExhausted) return client, gerr.ErrPoolExhausted } if !client.IsConnected() { - pr.logger.Error().Msg("Client is disconnected") + pr.Logger.Error().Msg("Client is disconnected") span.RecordError(gerr.ErrClientNotConnected) } @@ -610,7 +609,7 @@ func (pr *Proxy) IsHealthy(client *Client) (*Client, *gerr.GatewayDError) { func (pr *Proxy) IsExhausted() bool { _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "IsExhausted") defer span.End() - return pr.availableConnections.Size() == 0 && pr.availableConnections.Cap() > 0 + return pr.AvailableConnections.Size() == 0 && pr.AvailableConnections.Cap() > 0 } // Shutdown closes all connections and clears the connection pools. @@ -618,7 +617,7 @@ func (pr *Proxy) Shutdown() { _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "Shutdown") defer span.End() - pr.availableConnections.ForEach(func(_, value interface{}) bool { + pr.AvailableConnections.ForEach(func(_, value interface{}) bool { if client, ok := value.(*Client); ok { if client.IsConnected() { client.Close() @@ -626,18 +625,18 @@ func (pr *Proxy) Shutdown() { } return true }) - pr.availableConnections.Clear() - pr.logger.Debug().Msg("All available connections have been closed") + pr.AvailableConnections.Clear() + pr.Logger.Debug().Msg("All available connections have been closed") pr.busyConnections.ForEach(func(key, value interface{}) bool { if conn, ok := key.(net.Conn); ok { // This will stop all the Conn.Read() and Conn.Write() calls. if err := conn.SetDeadline(time.Now()); err != nil { - pr.logger.Error().Err(err).Msg("Error setting the deadline") + pr.Logger.Error().Err(err).Msg("Error setting the deadline") span.RecordError(err) } if err := conn.Close(); err != nil { - pr.logger.Error().Err(err).Msg("Failed to close the connection") + pr.Logger.Error().Err(err).Msg("Failed to close the connection") span.RecordError(err) } } @@ -651,16 +650,16 @@ func (pr *Proxy) Shutdown() { pr.busyConnections.Clear() pr.scheduler.Stop() pr.scheduler.Clear() - pr.logger.Debug().Msg("All busy connections have been closed") + pr.Logger.Debug().Msg("All busy connections have been closed") } -// AvailableConnections returns a list of available connections. -func (pr *Proxy) AvailableConnections() []string { +// AvailableConnectionsString returns a list of available connections. +func (pr *Proxy) AvailableConnectionsString() []string { _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "AvailableConnections") defer span.End() connections := make([]string, 0) - pr.availableConnections.ForEach(func(_, value interface{}) bool { + pr.AvailableConnections.ForEach(func(_, value interface{}) bool { if cl, ok := value.(*Client); ok { connections = append(connections, cl.LocalAddr()) } @@ -669,9 +668,9 @@ func (pr *Proxy) AvailableConnections() []string { return connections } -// BusyConnections returns a list of busy connections. -func (pr *Proxy) BusyConnections() []string { - _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "BusyConnections") +// BusyConnectionsString returns a list of busy connections. +func (pr *Proxy) BusyConnectionsString() []string { + _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "BusyConnectionsString") defer span.End() connections := make([]string, 0) @@ -696,7 +695,7 @@ func (pr *Proxy) receiveTrafficFromClient(conn net.Conn) ([]byte, *gerr.GatewayD chunk := make([]byte, pr.ClientConfig.ReceiveChunkSize) read, err := conn.Read(chunk) if read == 0 || err != nil { - pr.logger.Debug().Err(err).Msg("Error reading from client") + pr.Logger.Debug().Err(err).Msg("Error reading from client") span.RecordError(err) metrics.BytesReceivedFromClient.Observe(float64(read)) @@ -718,7 +717,7 @@ func (pr *Proxy) receiveTrafficFromClient(conn net.Conn) ([]byte, *gerr.GatewayD } length := len(buffer.Bytes()) - pr.logger.Debug().Fields( + pr.Logger.Debug().Fields( map[string]interface{}{ "length": length, "local": LocalAddr(conn), @@ -740,17 +739,17 @@ func (pr *Proxy) sendTrafficToServer(client *Client, request []byte) (int, *gerr defer span.End() if len(request) == 0 { - pr.logger.Trace().Msg("Empty request") + pr.Logger.Trace().Msg("Empty request") return 0, nil } // Send the request to the server. sent, err := client.Send(request) if err != nil { - pr.logger.Error().Err(err).Msg("Error sending request to database") + pr.Logger.Error().Err(err).Msg("Error sending request to database") span.RecordError(err) } - pr.logger.Debug().Fields( + pr.Logger.Debug().Fields( map[string]interface{}{ "function": "proxy.passthrough", "length": sent, @@ -786,7 +785,7 @@ func (pr *Proxy) receiveTrafficFromServer(client *Client) (int, []byte, *gerr.Ga fields["remote"] = client.RemoteAddr() } - pr.logger.Debug().Fields(fields).Msg("Received data from database") + pr.Logger.Debug().Fields(fields).Msg("Received data from database") span.AddEvent("Received data from database") @@ -812,7 +811,7 @@ func (pr *Proxy) sendTrafficToClient( written, origErr := conn.Write(response[:received]) if origErr != nil { - pr.logger.Error().Err(origErr).Msg("Error writing to client") + pr.Logger.Error().Err(origErr).Msg("Error writing to client") span.RecordError(origErr) return gerr.ErrServerSendFailed.Wrap(origErr) } @@ -820,7 +819,7 @@ func (pr *Proxy) sendTrafficToClient( sent += written } - pr.logger.Debug().Fields( + pr.Logger.Debug().Fields( map[string]interface{}{ "function": "proxy.passthrough", "length": sent, @@ -849,7 +848,7 @@ func (pr *Proxy) shouldTerminate(result map[string]interface{}) (bool, map[strin outputs, ok := result[sdkAct.Outputs].([]*sdkAct.Output) if !ok { - pr.logger.Error().Msg("Failed to cast the outputs to the []*act.Output type") + pr.Logger.Error().Msg("Failed to cast the outputs to the []*act.Output type") return false, result } @@ -860,19 +859,19 @@ func (pr *Proxy) shouldTerminate(result map[string]interface{}) (bool, map[strin if slices.Contains(keys, sdkAct.Terminal) { var actionResult map[string]interface{} for _, output := range outputs { - actRes, err := pr.pluginRegistry.ActRegistry().Run( + actRes, err := pr.PluginRegistry.ActRegistry.Run( output, act.WithResult(result)) // If the action is async and we received a sentinel error, // don't log the error. if err != nil && !errors.Is(err, gerr.ErrAsyncAction) { - pr.logger.Error().Err(err).Msg("Error running policy") + pr.Logger.Error().Err(err).Msg("Error running policy") } // The terminate action should return a map. if v, ok := actRes.(map[string]interface{}); ok { actionResult = v } } - pr.logger.Debug().Fields( + pr.Logger.Debug().Fields( map[string]interface{}{ "function": "proxy.passthrough", "reason": "terminate", @@ -892,7 +891,7 @@ func (pr *Proxy) getPluginModifiedRequest(result map[string]interface{}) []byte // If the hook modified the request, use the modified request. if modRequest, errMsg := extractFieldValue(result, "request"); errMsg != "" { - pr.logger.Error().Str("error", errMsg).Msg("Error in hook") + pr.Logger.Error().Str("error", errMsg).Msg("Error in hook") } else if modRequest != nil { return modRequest } @@ -908,7 +907,7 @@ func (pr *Proxy) getPluginModifiedResponse(result map[string]interface{}) ([]byt // If the hook returns a response, use it instead of the original response. if modResponse, errMsg := extractFieldValue(result, "response"); errMsg != "" { - pr.logger.Error().Str("error", errMsg).Msg("Error in hook") + pr.Logger.Error().Str("error", errMsg).Msg("Error in hook") } else if modResponse != nil { return modResponse, len(modResponse) } @@ -918,7 +917,7 @@ func (pr *Proxy) getPluginModifiedResponse(result map[string]interface{}) ([]byt func (pr *Proxy) isConnectionHealthy(conn net.Conn) bool { if n, err := conn.Read([]byte{}); n == 0 && err != nil { - pr.logger.Debug().Fields( + pr.Logger.Debug().Fields( map[string]interface{}{ "remote": RemoteAddr(conn), "local": LocalAddr(conn), diff --git a/network/proxy_test.go b/network/proxy_test.go index c4f6afe0..1e183e90 100644 --- a/network/proxy_test.go +++ b/network/proxy_test.go @@ -46,30 +46,40 @@ func TestNewProxy(t *testing.T) { // Create a new act registry actRegistry := act.NewActRegistry( - act.BuiltinSignals(), act.BuiltinPolicies(), act.BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + act.Registry{ + Signals: act.BuiltinSignals(), + Policies: act.BuiltinPolicies(), + Actions: act.BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) // Create a proxy with a fixed buffer newPool proxy := NewProxy( context.Background(), - newPool, - plugin.NewRegistry( - context.Background(), - actRegistry, - config.Loose, - logger, - false, - ), - config.DefaultHealthCheckPeriod, - nil, - logger, - config.DefaultPluginTimeout) + Proxy{ + AvailableConnections: newPool, + PluginRegistry: plugin.NewRegistry( + context.Background(), + plugin.Registry{ + ActRegistry: actRegistry, + Compatibility: config.Loose, + Logger: logger, + }, + ), + HealthCheckPeriod: config.DefaultHealthCheckPeriod, + Logger: logger, + PluginTimeout: config.DefaultPluginTimeout, + }, + ) defer proxy.Shutdown() assert.NotNil(t, proxy) assert.Equal(t, 0, proxy.busyConnections.Size(), "Proxy should have no connected clients") - assert.Equal(t, 1, proxy.availableConnections.Size()) - if c, ok := proxy.availableConnections.Pop(client.ID).(*Client); ok { + assert.Equal(t, 1, proxy.AvailableConnections.Size()) + if c, ok := proxy.AvailableConnections.Pop(client.ID).(*Client); ok { assert.NotEqual(t, "", c.ID) } assert.False(t, proxy.IsExhausted()) @@ -92,25 +102,35 @@ func BenchmarkNewProxy(b *testing.B) { // Create a new act registry actRegistry := act.NewActRegistry( - act.BuiltinSignals(), act.BuiltinPolicies(), act.BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + act.Registry{ + Signals: act.BuiltinSignals(), + Policies: act.BuiltinPolicies(), + Actions: act.BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) // Create a proxy with a fixed buffer newPool for i := 0; i < b.N; i++ { proxy := NewProxy( context.Background(), - newPool, - plugin.NewRegistry( - context.Background(), - actRegistry, - config.Loose, - logger, - false, - ), - config.DefaultHealthCheckPeriod, - nil, - logger, - config.DefaultPluginTimeout) + Proxy{ + AvailableConnections: newPool, + PluginRegistry: plugin.NewRegistry( + context.Background(), + plugin.Registry{ + ActRegistry: actRegistry, + Compatibility: config.Loose, + Logger: logger, + }, + ), + HealthCheckPeriod: config.DefaultHealthCheckPeriod, + Logger: logger, + PluginTimeout: config.DefaultPluginTimeout, + }, + ) proxy.Shutdown() } } @@ -141,24 +161,35 @@ func BenchmarkProxyConnectDisconnect(b *testing.B) { // Create a new act registry actRegistry := act.NewActRegistry( - act.BuiltinSignals(), act.BuiltinPolicies(), act.BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + act.Registry{ + Signals: act.BuiltinSignals(), + Policies: act.BuiltinPolicies(), + Actions: act.BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) // Create a proxy with a fixed buffer newPool proxy := NewProxy( context.Background(), - newPool, - plugin.NewRegistry( - context.Background(), - actRegistry, - config.Loose, - logger, - false, - ), - config.DefaultHealthCheckPeriod, - &clientConfig, - logger, - config.DefaultPluginTimeout) + Proxy{ + AvailableConnections: newPool, + PluginRegistry: plugin.NewRegistry( + context.Background(), + plugin.Registry{ + ActRegistry: actRegistry, + Compatibility: config.Loose, + Logger: logger, + }, + ), + HealthCheckPeriod: config.DefaultHealthCheckPeriod, + ClientConfig: &clientConfig, + Logger: logger, + PluginTimeout: config.DefaultPluginTimeout, + }, + ) defer proxy.Shutdown() conn := testConnection{} @@ -196,24 +227,35 @@ func BenchmarkProxyPassThrough(b *testing.B) { // Create a new act registry actRegistry := act.NewActRegistry( - act.BuiltinSignals(), act.BuiltinPolicies(), act.BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + act.Registry{ + Signals: act.BuiltinSignals(), + Policies: act.BuiltinPolicies(), + Actions: act.BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) // Create a proxy with a fixed buffer newPool proxy := NewProxy( context.Background(), - newPool, - plugin.NewRegistry( - context.Background(), - actRegistry, - config.Loose, - logger, - false, - ), - config.DefaultHealthCheckPeriod, - &clientConfig, - logger, - config.DefaultPluginTimeout) + Proxy{ + AvailableConnections: newPool, + PluginRegistry: plugin.NewRegistry( + context.Background(), + plugin.Registry{ + ActRegistry: actRegistry, + Compatibility: config.Loose, + Logger: logger, + }, + ), + HealthCheckPeriod: config.DefaultHealthCheckPeriod, + ClientConfig: &clientConfig, + Logger: logger, + PluginTimeout: config.DefaultPluginTimeout, + }, + ) defer proxy.Shutdown() conn := testConnection{} @@ -256,24 +298,35 @@ func BenchmarkProxyIsHealthyAndIsExhausted(b *testing.B) { // Create a new act registry actRegistry := act.NewActRegistry( - act.BuiltinSignals(), act.BuiltinPolicies(), act.BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + act.Registry{ + Signals: act.BuiltinSignals(), + Policies: act.BuiltinPolicies(), + Actions: act.BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) // Create a proxy with a fixed buffer newPool proxy := NewProxy( context.Background(), - newPool, - plugin.NewRegistry( - context.Background(), - actRegistry, - config.Loose, - logger, - false, - ), - config.DefaultHealthCheckPeriod, - &clientConfig, - logger, - config.DefaultPluginTimeout) + Proxy{ + AvailableConnections: newPool, + PluginRegistry: plugin.NewRegistry( + context.Background(), + plugin.Registry{ + ActRegistry: actRegistry, + Compatibility: config.Loose, + Logger: logger, + }, + ), + HealthCheckPeriod: config.DefaultHealthCheckPeriod, + ClientConfig: &clientConfig, + Logger: logger, + PluginTimeout: config.DefaultPluginTimeout, + }, + ) defer proxy.Shutdown() conn := testConnection{} @@ -287,7 +340,7 @@ func BenchmarkProxyIsHealthyAndIsExhausted(b *testing.B) { } } -func BenchmarkProxyAvailableAndBusyConnections(b *testing.B) { +func BenchmarkProxyAvailableAndBusyConnectionsString(b *testing.B) { logger := logging.NewLogger(context.Background(), logging.LoggerConfig{ Output: []config.LogOutput{config.Console}, TimeFormat: zerolog.TimeFormatUnix, @@ -314,24 +367,35 @@ func BenchmarkProxyAvailableAndBusyConnections(b *testing.B) { // Create a new act registry actRegistry := act.NewActRegistry( - act.BuiltinSignals(), act.BuiltinPolicies(), act.BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + act.Registry{ + Signals: act.BuiltinSignals(), + Policies: act.BuiltinPolicies(), + Actions: act.BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) // Create a proxy with a fixed buffer newPool proxy := NewProxy( context.Background(), - newPool, - plugin.NewRegistry( - context.Background(), - actRegistry, - config.Loose, - logger, - false, - ), - config.DefaultHealthCheckPeriod, - &clientConfig, - logger, - config.DefaultPluginTimeout) + Proxy{ + AvailableConnections: newPool, + PluginRegistry: plugin.NewRegistry( + context.Background(), + plugin.Registry{ + ActRegistry: actRegistry, + Compatibility: config.Loose, + Logger: logger, + }, + ), + HealthCheckPeriod: config.DefaultHealthCheckPeriod, + ClientConfig: &clientConfig, + Logger: logger, + PluginTimeout: config.DefaultPluginTimeout, + }, + ) defer proxy.Shutdown() conn := testConnection{} @@ -340,7 +404,7 @@ func BenchmarkProxyAvailableAndBusyConnections(b *testing.B) { // Connect to the proxy for i := 0; i < b.N; i++ { - proxy.AvailableConnections() - proxy.BusyConnections() + proxy.AvailableConnectionsString() + proxy.BusyConnectionsString() } } diff --git a/network/retry.go b/network/retry.go index e0d6e604..1d206777 100644 --- a/network/retry.go +++ b/network/retry.go @@ -20,11 +20,11 @@ type IRetry interface { } type Retry struct { - logger zerolog.Logger Retries int Backoff time.Duration BackoffMultiplier float64 DisableBackoffCaps bool + Logger zerolog.Logger } var _ IRetry = (*Retry)(nil) @@ -77,14 +77,14 @@ func (r *Retry) Retry(callback RetryCallback) (any, error) { } if retry > 0 { - r.logger.Debug().Fields( + r.Logger.Debug().Fields( map[string]interface{}{ "retry": retry, "delay": backoffDuration.String(), }, ).Msg("Trying to run callback again") } else { - r.logger.Trace().Msg("First attempt to run callback") + r.Logger.Trace().Msg("First attempt to run callback") } // Try and retry the callback. @@ -96,24 +96,20 @@ func (r *Retry) Retry(callback RetryCallback) (any, error) { time.Sleep(backoffDuration) } - r.logger.Error().Err(err).Msgf("Failed to run callback after %d retries", retry) + r.Logger.Error().Err(err).Msgf("Failed to run callback after %d retries", retry) return nil, err } func NewRetry( - retries int, - backoff time.Duration, - backoffMultiplier float64, - disableBackoffCaps bool, - logger zerolog.Logger, + rty Retry, ) *Retry { retry := Retry{ - Retries: retries, - Backoff: backoff, - BackoffMultiplier: backoffMultiplier, - DisableBackoffCaps: disableBackoffCaps, - logger: logger, + Retries: rty.Retries, + Backoff: rty.Backoff, + BackoffMultiplier: rty.BackoffMultiplier, + DisableBackoffCaps: rty.DisableBackoffCaps, + Logger: rty.Logger, } // If the number of retries is less than 0, set it to 0 to disable retries. diff --git a/network/retry_test.go b/network/retry_test.go index 38d98a43..bd9a3e97 100644 --- a/network/retry_test.go +++ b/network/retry_test.go @@ -30,7 +30,7 @@ func TestRetry(t *testing.T) { assert.ErrorContains(t, err, "callback is nil") }) t.Run("retry without timeout", func(t *testing.T) { - retry := NewRetry(0, 0, 0, false, logger) + retry := NewRetry(Retry{0, 0, 0, false, logger}) assert.Equal(t, 0, retry.Retries) assert.Equal(t, time.Duration(0), retry.Backoff) assert.Equal(t, float64(0), retry.BackoffMultiplier) @@ -50,11 +50,13 @@ func TestRetry(t *testing.T) { }) t.Run("retry with timeout", func(t *testing.T) { retry := NewRetry( - config.DefaultRetries, - config.DefaultBackoff, - config.DefaultBackoffMultiplier, - config.DefaultDisableBackoffCaps, - logger, + Retry{ + config.DefaultRetries, + config.DefaultBackoff, + config.DefaultBackoffMultiplier, + config.DefaultDisableBackoffCaps, + logger, + }, ) assert.Equal(t, config.DefaultRetries, retry.Retries) assert.Equal(t, config.DefaultBackoff, retry.Backoff) diff --git a/network/server.go b/network/server.go index 00a293a0..7025f584 100644 --- a/network/server.go +++ b/network/server.go @@ -48,11 +48,11 @@ type IServer interface { } type Server struct { - proxy IProxy - logger zerolog.Logger - pluginRegistry *plugin.Registry + Proxy IProxy + Logger zerolog.Logger + PluginRegistry *plugin.Registry ctx context.Context //nolint:containedctx - pluginTimeout time.Duration + PluginTimeout time.Duration mu *sync.RWMutex Network string // tcp/udp/unix @@ -84,17 +84,17 @@ func (s *Server) OnBoot() Action { _, span := otel.Tracer("gatewayd").Start(s.ctx, "OnBoot") defer span.End() - s.logger.Debug().Msg("GatewayD is booting...") + s.Logger.Debug().Msg("GatewayD is booting...") - pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.pluginTimeout) + pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.PluginTimeout) defer cancel() // Run the OnBooting hooks. - _, err := s.pluginRegistry.Run( + _, err := s.PluginRegistry.Run( pluginTimeoutCtx, map[string]interface{}{"status": fmt.Sprint(s.Status)}, v1.HookName_HOOK_NAME_ON_BOOTING) if err != nil { - s.logger.Error().Err(err).Msg("Failed to run OnBooting hook") + s.Logger.Error().Err(err).Msg("Failed to run OnBooting hook") span.RecordError(err) } span.AddEvent("Ran the OnBooting hooks") @@ -105,20 +105,20 @@ func (s *Server) OnBoot() Action { s.mu.Unlock() // Run the OnBooted hooks. - pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), s.pluginTimeout) + pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), s.PluginTimeout) defer cancel() - _, err = s.pluginRegistry.Run( + _, err = s.PluginRegistry.Run( pluginTimeoutCtx, map[string]interface{}{"status": fmt.Sprint(s.Status)}, v1.HookName_HOOK_NAME_ON_BOOTED) if err != nil { - s.logger.Error().Err(err).Msg("Failed to run OnBooted hook") + s.Logger.Error().Err(err).Msg("Failed to run OnBooted hook") span.RecordError(err) } span.AddEvent("Ran the OnBooted hooks") - s.logger.Debug().Msg("GatewayD booted") + s.Logger.Debug().Msg("GatewayD booted") return None } @@ -129,10 +129,10 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) { _, span := otel.Tracer("gatewayd").Start(s.ctx, "OnOpen") defer span.End() - s.logger.Debug().Str("from", RemoteAddr(conn.Conn())).Msg( + s.Logger.Debug().Str("from", RemoteAddr(conn.Conn())).Msg( "GatewayD is opening a connection") - pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.pluginTimeout) + pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.PluginTimeout) defer cancel() // Run the OnOpening hooks. onOpeningData := map[string]interface{}{ @@ -141,10 +141,10 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) { "remote": RemoteAddr(conn.Conn()), }, } - _, err := s.pluginRegistry.Run( + _, err := s.PluginRegistry.Run( pluginTimeoutCtx, onOpeningData, v1.HookName_HOOK_NAME_ON_OPENING) if err != nil { - s.logger.Error().Err(err).Msg("Failed to run OnOpening hook") + s.Logger.Error().Err(err).Msg("Failed to run OnOpening hook") span.RecordError(err) } span.AddEvent("Ran the OnOpening hooks") @@ -152,7 +152,7 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) { // Use the proxy to connect to the backend. Close the connection if the pool is exhausted. // This effectively get a connection from the pool and puts both the incoming and the server // connections in the pool of the busy connections. - if err := s.proxy.Connect(conn); err != nil { + if err := s.Proxy.Connect(conn); err != nil { if errors.Is(err, gerr.ErrPoolExhausted) { span.RecordError(err) return nil, Close @@ -160,13 +160,13 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) { // This should never happen. // TODO: Send error to client or retry connection - s.logger.Error().Err(err).Msg("Failed to connect to proxy") + s.Logger.Error().Err(err).Msg("Failed to connect to proxy") span.RecordError(err) return nil, None } // Run the OnOpened hooks. - pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), s.pluginTimeout) + pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), s.PluginTimeout) defer cancel() onOpenedData := map[string]interface{}{ @@ -175,10 +175,10 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) { "remote": RemoteAddr(conn.Conn()), }, } - _, err = s.pluginRegistry.Run( + _, err = s.PluginRegistry.Run( pluginTimeoutCtx, onOpenedData, v1.HookName_HOOK_NAME_ON_OPENED) if err != nil { - s.logger.Error().Err(err).Msg("Failed to run OnOpened hook") + s.Logger.Error().Err(err).Msg("Failed to run OnOpened hook") span.RecordError(err) } span.AddEvent("Ran the OnOpened hooks") @@ -194,11 +194,11 @@ func (s *Server) OnClose(conn *ConnWrapper, err error) Action { _, span := otel.Tracer("gatewayd").Start(s.ctx, "OnClose") defer span.End() - s.logger.Debug().Str("from", RemoteAddr(conn.Conn())).Msg( + s.Logger.Debug().Str("from", RemoteAddr(conn.Conn())).Msg( "GatewayD is closing a connection") // Run the OnClosing hooks. - pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.pluginTimeout) + pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.PluginTimeout) defer cancel() data := map[string]interface{}{ @@ -211,10 +211,10 @@ func (s *Server) OnClose(conn *ConnWrapper, err error) Action { if err != nil { data["error"] = err.Error() } - _, gatewaydErr := s.pluginRegistry.Run( + _, gatewaydErr := s.PluginRegistry.Run( pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_CLOSING) if gatewaydErr != nil { - s.logger.Error().Err(gatewaydErr).Msg("Failed to run OnClosing hook") + s.Logger.Error().Err(gatewaydErr).Msg("Failed to run OnClosing hook") span.RecordError(gatewaydErr) } span.AddEvent("Ran the OnClosing hooks") @@ -228,8 +228,8 @@ func (s *Server) OnClose(conn *ConnWrapper, err error) Action { // Disconnect the connection from the proxy. This effectively removes the mapping between // the incoming and the server connections in the pool of the busy connections and either // recycles or disconnects the connections. - if err := s.proxy.Disconnect(conn); err != nil { - s.logger.Error().Err(err).Msg("Failed to disconnect the server connection") + if err := s.Proxy.Disconnect(conn); err != nil { + s.Logger.Error().Err(err).Msg("Failed to disconnect the server connection") span.RecordError(err) return Close } @@ -240,13 +240,13 @@ func (s *Server) OnClose(conn *ConnWrapper, err error) Action { // Close the incoming connection. if err := conn.Close(); err != nil { - s.logger.Error().Err(err).Msg("Failed to close the incoming connection") + s.Logger.Error().Err(err).Msg("Failed to close the incoming connection") span.RecordError(err) return Close } // Run the OnClosed hooks. - pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), s.pluginTimeout) + pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), s.PluginTimeout) defer cancel() data = map[string]interface{}{ @@ -259,10 +259,10 @@ func (s *Server) OnClose(conn *ConnWrapper, err error) Action { if err != nil { data["error"] = err.Error() } - _, gatewaydErr = s.pluginRegistry.Run( + _, gatewaydErr = s.PluginRegistry.Run( pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_CLOSED) if gatewaydErr != nil { - s.logger.Error().Err(gatewaydErr).Msg("Failed to run OnClosed hook") + s.Logger.Error().Err(gatewaydErr).Msg("Failed to run OnClosed hook") span.RecordError(gatewaydErr) } span.AddEvent("Ran the OnClosed hooks") @@ -279,7 +279,7 @@ func (s *Server) OnTraffic(conn *ConnWrapper, stopConnection chan struct{}) Acti defer span.End() // Run the OnTraffic hooks. - pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.pluginTimeout) + pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.PluginTimeout) defer cancel() onTrafficData := map[string]interface{}{ @@ -288,10 +288,10 @@ func (s *Server) OnTraffic(conn *ConnWrapper, stopConnection chan struct{}) Acti "remote": RemoteAddr(conn.Conn()), }, } - _, err := s.pluginRegistry.Run( + _, err := s.PluginRegistry.Run( pluginTimeoutCtx, onTrafficData, v1.HookName_HOOK_NAME_ON_TRAFFIC) if err != nil { - s.logger.Error().Err(err).Msg("Failed to run OnTraffic hook") + s.Logger.Error().Err(err).Msg("Failed to run OnTraffic hook") span.RecordError(err) } span.AddEvent("Ran the OnTraffic hooks") @@ -302,9 +302,9 @@ func (s *Server) OnTraffic(conn *ConnWrapper, stopConnection chan struct{}) Acti // If there is an error, log it and close the connection. go func(server *Server, conn *ConnWrapper, stopConnection chan struct{}, stack *Stack) { for { - server.logger.Trace().Msg("Passing through traffic from client to server") - if err := server.proxy.PassThroughToServer(conn, stack); err != nil { - server.logger.Trace().Err(err).Msg("Failed to pass through traffic") + server.Logger.Trace().Msg("Passing through traffic from client to server") + if err := server.Proxy.PassThroughToServer(conn, stack); err != nil { + server.Logger.Trace().Err(err).Msg("Failed to pass through traffic") span.RecordError(err) stopConnection <- struct{}{} break @@ -316,9 +316,9 @@ func (s *Server) OnTraffic(conn *ConnWrapper, stopConnection chan struct{}) Acti // If there is an error, log it and close the connection. go func(server *Server, conn *ConnWrapper, stopConnection chan struct{}, stack *Stack) { for { - server.logger.Trace().Msg("Passing through traffic from server to client") - if err := server.proxy.PassThroughToClient(conn, stack); err != nil { - server.logger.Trace().Err(err).Msg("Failed to pass through traffic") + server.Logger.Trace().Msg("Passing through traffic from server to client") + if err := server.Proxy.PassThroughToClient(conn, stack); err != nil { + server.Logger.Trace().Err(err).Msg("Failed to pass through traffic") span.RecordError(err) stopConnection <- struct{}{} break @@ -337,23 +337,23 @@ func (s *Server) OnShutdown() { _, span := otel.Tracer("gatewayd").Start(s.ctx, "OnShutdown") defer span.End() - s.logger.Debug().Msg("GatewayD is shutting down") + s.Logger.Debug().Msg("GatewayD is shutting down") - pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.pluginTimeout) + pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.PluginTimeout) defer cancel() // Run the OnShutdown hooks. - _, err := s.pluginRegistry.Run( + _, err := s.PluginRegistry.Run( pluginTimeoutCtx, map[string]interface{}{"connections": s.CountConnections()}, v1.HookName_HOOK_NAME_ON_SHUTDOWN) if err != nil { - s.logger.Error().Err(err).Msg("Failed to run OnShutdown hook") + s.Logger.Error().Err(err).Msg("Failed to run OnShutdown hook") span.RecordError(err) } span.AddEvent("Ran the OnShutdown hooks") // Shutdown the proxy. - s.proxy.Shutdown() + s.Proxy.Shutdown() // Set the server status to stopped. This is used to shutdown the server gracefully in OnClose. s.mu.Lock() @@ -366,19 +366,19 @@ func (s *Server) OnTick() (time.Duration, Action) { _, span := otel.Tracer("gatewayd").Start(s.ctx, "OnTick") defer span.End() - s.logger.Debug().Msg("GatewayD is ticking...") - s.logger.Info().Str("count", strconv.Itoa(s.CountConnections())).Msg( + s.Logger.Debug().Msg("GatewayD is ticking...") + s.Logger.Info().Str("count", strconv.Itoa(s.CountConnections())).Msg( "Active client connections") - pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.pluginTimeout) + pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.PluginTimeout) defer cancel() // Run the OnTick hooks. - _, err := s.pluginRegistry.Run( + _, err := s.PluginRegistry.Run( pluginTimeoutCtx, map[string]interface{}{"connections": s.CountConnections()}, v1.HookName_HOOK_NAME_ON_TICK) if err != nil { - s.logger.Error().Err(err).Msg("Failed to run OnTick hook") + s.Logger.Error().Err(err).Msg("Failed to run OnTick hook") span.RecordError(err) } span.AddEvent("Ran the OnTick hooks") @@ -397,16 +397,16 @@ func (s *Server) Run() *gerr.GatewayDError { _, span := otel.Tracer("gatewayd").Start(s.ctx, "Run") defer span.End() - s.logger.Info().Str("pid", strconv.Itoa(os.Getpid())).Msg("GatewayD is running") + s.Logger.Info().Str("pid", strconv.Itoa(os.Getpid())).Msg("GatewayD is running") // Try to resolve the address and log an error if it can't be resolved - addr, err := Resolve(s.Network, s.Address, s.logger) + addr, err := Resolve(s.Network, s.Address, s.Logger) if err != nil { - s.logger.Error().Err(err).Msg("Failed to resolve address") + s.Logger.Error().Err(err).Msg("Failed to resolve address") span.RecordError(err) } - pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.pluginTimeout) + pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.PluginTimeout) defer cancel() // Run the OnRun hooks. // Since Run is blocking, we need to run OnRun before it. @@ -414,17 +414,17 @@ func (s *Server) Run() *gerr.GatewayDError { if err != nil && err.Unwrap() != nil { onRunData["error"] = err.OriginalError.Error() } - result, err := s.pluginRegistry.Run( + result, err := s.PluginRegistry.Run( pluginTimeoutCtx, onRunData, v1.HookName_HOOK_NAME_ON_RUN) if err != nil { - s.logger.Error().Err(err).Msg("Failed to run the hook") + s.Logger.Error().Err(err).Msg("Failed to run the hook") span.RecordError(err) } span.AddEvent("Ran the OnRun hooks") if result != nil { if errMsg, ok := result["error"].(string); ok && errMsg != "" { - s.logger.Error().Str("error", errMsg).Msg("Error in hook") + s.Logger.Error().Str("error", errMsg).Msg("Error in hook") } if address, ok := result["address"].(string); ok { @@ -438,7 +438,7 @@ func (s *Server) Run() *gerr.GatewayDError { listener, origErr := net.Listen(s.Network, addr) if origErr != nil { - s.logger.Error().Err(origErr).Msg("Server failed to start listening") + s.Logger.Error().Err(origErr).Msg("Server failed to start listening") return gerr.ErrServerListenFailed.Wrap(origErr) } s.mu.Lock() @@ -447,26 +447,26 @@ func (s *Server) Run() *gerr.GatewayDError { defer s.listener.Close() if s.listener == nil { - s.logger.Error().Msg("Server is not properly initialized") + s.Logger.Error().Msg("Server is not properly initialized") return nil } var port string s.host, port, origErr = net.SplitHostPort(s.listener.Addr().String()) if origErr != nil { - s.logger.Error().Err(origErr).Msg("Failed to split host and port") + s.Logger.Error().Err(origErr).Msg("Failed to split host and port") return gerr.ErrSplitHostPortFailed.Wrap(origErr) } if s.port, origErr = strconv.Atoi(port); origErr != nil { - s.logger.Error().Err(origErr).Msg("Failed to convert port to integer") + s.Logger.Error().Err(origErr).Msg("Failed to convert port to integer") return gerr.ErrCastFailed.Wrap(origErr) } go func(server *Server) { <-server.stopServer server.OnShutdown() - server.logger.Debug().Msg("Server stopped") + server.Logger.Debug().Msg("Server stopped") }(s) go func(server *Server) { @@ -498,18 +498,18 @@ func (s *Server) Run() *gerr.GatewayDError { if s.EnableTLS { tlsConfig, origErr = CreateTLSConfig(s.CertFile, s.KeyFile) if origErr != nil { - s.logger.Error().Err(origErr).Msg("Failed to create TLS config") + s.Logger.Error().Err(origErr).Msg("Failed to create TLS config") return gerr.ErrGetTLSConfigFailed.Wrap(origErr) } - s.logger.Info().Msg("TLS is enabled") + s.Logger.Info().Msg("TLS is enabled") } else { - s.logger.Debug().Msg("TLS is disabled") + s.Logger.Debug().Msg("TLS is disabled") } for { select { case <-s.stopServer: - s.logger.Info().Msg("Server stopped") + s.Logger.Info().Msg("Server stopped") return nil default: netConn, err := s.listener.Accept() @@ -517,15 +517,19 @@ func (s *Server) Run() *gerr.GatewayDError { if !s.running.Load() { return nil } - s.logger.Error().Err(err).Msg("Failed to accept connection") + s.Logger.Error().Err(err).Msg("Failed to accept connection") return gerr.ErrAcceptFailed.Wrap(err) } - conn := NewConnWrapper(netConn, tlsConfig, s.HandshakeTimeout) + conn := NewConnWrapper(ConnWrapper{ + NetConn: netConn, + TLSConfig: tlsConfig, + HandshakeTimeout: s.HandshakeTimeout, + }) if out, action := s.OnOpen(conn); action != None { if _, err := conn.Write(out); err != nil { - s.logger.Error().Err(err).Msg("Failed to write to connection") + s.Logger.Error().Err(err).Msg("Failed to write to connection") } _ = conn.Close() if action == Shutdown { @@ -570,7 +574,7 @@ func (s *Server) Shutdown() { defer span.End() // Shutdown the proxy. - s.proxy.Shutdown() + s.Proxy.Shutdown() // Set the server status to stopped. This is used to shutdown the server gracefully in OnClose. s.mu.Lock() @@ -582,22 +586,22 @@ func (s *Server) Shutdown() { s.running.Store(false) if s.listener != nil { if err = s.listener.Close(); err != nil { - s.logger.Error().Err(err).Msg("Failed to close listener") + s.Logger.Error().Err(err).Msg("Failed to close listener") } } else { - s.logger.Error().Msg("Listener is not initialized") + s.Logger.Error().Msg("Listener is not initialized") } select { case <-s.stopServer: - s.logger.Info().Msg("Server stopped") + s.Logger.Info().Msg("Server stopped") default: s.stopServer <- struct{}{} close(s.stopServer) } if err != nil { - s.logger.Error().Err(err).Msg("Failed to shutdown server") + s.Logger.Error().Err(err).Msg("Failed to shutdown server") span.RecordError(err) } } @@ -616,16 +620,7 @@ func (s *Server) IsRunning() bool { // NewServer creates a new server. func NewServer( ctx context.Context, - network, address string, - tickInterval time.Duration, - options Option, - proxy IProxy, - logger zerolog.Logger, - pluginRegistry *plugin.Registry, - pluginTimeout time.Duration, - enableTLS bool, - certFile, keyFile string, - handshakeTimeout time.Duration, + srv Server, ) *Server { serverCtx, span := otel.Tracer(config.TracerName).Start(ctx, "NewServer") defer span.End() @@ -633,19 +628,19 @@ func NewServer( // Create the server. server := Server{ ctx: serverCtx, - Network: network, - Address: address, - Options: options, - TickInterval: tickInterval, + Network: srv.Network, + Address: srv.Address, + Options: srv.Options, + TickInterval: srv.TickInterval, Status: config.Stopped, - EnableTLS: enableTLS, - CertFile: certFile, - KeyFile: keyFile, - HandshakeTimeout: handshakeTimeout, - proxy: proxy, - logger: logger, - pluginRegistry: pluginRegistry, - pluginTimeout: pluginTimeout, + EnableTLS: srv.EnableTLS, + CertFile: srv.CertFile, + KeyFile: srv.KeyFile, + HandshakeTimeout: srv.HandshakeTimeout, + Proxy: srv.Proxy, + Logger: srv.Logger, + PluginRegistry: srv.PluginRegistry, + PluginTimeout: srv.PluginTimeout, mu: &sync.RWMutex{}, connections: 0, running: &atomic.Bool{}, @@ -653,19 +648,19 @@ func NewServer( } // Try to resolve the address and log an error if it can't be resolved. - addr, err := Resolve(server.Network, server.Address, logger) + addr, err := Resolve(server.Network, server.Address, srv.Logger) if err != nil { - logger.Error().Err(err).Msg("Failed to resolve address") + srv.Logger.Error().Err(err).Msg("Failed to resolve address") span.AddEvent(err.Error()) } if addr != "" { server.Address = addr - logger.Debug().Str("address", addr).Msg("Resolved address") - logger.Info().Str("address", addr).Msg("GatewayD is listening") + srv.Logger.Debug().Str("address", addr).Msg("Resolved address") + srv.Logger.Info().Str("address", addr).Msg("GatewayD is listening") } else { - logger.Error().Msg("Failed to resolve address") - logger.Warn().Str("address", server.Address).Msg( + srv.Logger.Error().Msg("Failed to resolve address") + srv.Logger.Warn().Str("address", server.Address).Msg( "GatewayD is listening on an unresolved address") } diff --git a/network/server_test.go b/network/server_test.go index 9cc55731..b090e126 100644 --- a/network/server_test.go +++ b/network/server_test.go @@ -40,27 +40,34 @@ func TestRunServer(t *testing.T) { }) actRegistry := act.NewActRegistry( - act.BuiltinSignals(), act.BuiltinPolicies(), act.BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + act.Registry{ + Signals: act.BuiltinSignals(), + Policies: act.BuiltinPolicies(), + Actions: act.BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) pluginRegistry := plugin.NewRegistry( context.Background(), - actRegistry, - config.Loose, - logger, - false, - ) + plugin.Registry{ + ActRegistry: actRegistry, + Compatibility: config.Loose, + Logger: logger, + }) pluginRegistry.AddHook(v1.HookName_HOOK_NAME_ON_TRAFFIC_FROM_CLIENT, 1, onIncomingTraffic) pluginRegistry.AddHook(v1.HookName_HOOK_NAME_ON_TRAFFIC_TO_SERVER, 1, onIncomingTraffic) pluginRegistry.AddHook(v1.HookName_HOOK_NAME_ON_TRAFFIC_FROM_SERVER, 1, onOutgoingTraffic) pluginRegistry.AddHook(v1.HookName_HOOK_NAME_ON_TRAFFIC_TO_CLIENT, 1, onOutgoingTraffic) - assert.NotNil(t, pluginRegistry.ActRegistry()) - assert.NotNil(t, pluginRegistry.ActRegistry().Signals) - assert.NotNil(t, pluginRegistry.ActRegistry().Policies) - assert.NotNil(t, pluginRegistry.ActRegistry().Actions) - assert.Equal(t, config.DefaultPolicy, pluginRegistry.ActRegistry().DefaultPolicy.Name) - assert.Equal(t, config.DefaultPolicy, pluginRegistry.ActRegistry().DefaultSignal.Name) + assert.NotNil(t, pluginRegistry.ActRegistry) + assert.NotNil(t, pluginRegistry.ActRegistry.Signals) + assert.NotNil(t, pluginRegistry.ActRegistry.Policies) + assert.NotNil(t, pluginRegistry.ActRegistry.Actions) + assert.Equal(t, config.DefaultPolicy, pluginRegistry.ActRegistry.DefaultPolicy.Name) + assert.Equal(t, config.DefaultPolicy, pluginRegistry.ActRegistry.DefaultSignal.Name) clientConfig := config.Client{ Network: "tcp", @@ -87,30 +94,32 @@ func TestRunServer(t *testing.T) { // Create a proxy with a fixed buffer newPool. proxy := NewProxy( context.Background(), - newPool, - pluginRegistry, - config.DefaultHealthCheckPeriod, - &clientConfig, - logger, - config.DefaultPluginTimeout) + Proxy{ + AvailableConnections: newPool, + PluginRegistry: pluginRegistry, + HealthCheckPeriod: config.DefaultHealthCheckPeriod, + ClientConfig: &clientConfig, + Logger: logger, + PluginTimeout: config.DefaultPluginTimeout, + }, + ) // Create a server. server := NewServer( context.Background(), - "tcp", - "127.0.0.1:15432", - config.DefaultTickInterval, - Option{ - EnableTicker: true, + Server{ + Network: "tcp", + Address: "127.0.0.1:15432", + TickInterval: config.DefaultTickInterval, + Options: Option{ + EnableTicker: true, + }, + Proxy: proxy, + Logger: logger, + PluginRegistry: pluginRegistry, + PluginTimeout: config.DefaultPluginTimeout, + HandshakeTimeout: config.DefaultHandshakeTimeout, }, - proxy, - logger, - pluginRegistry, - config.DefaultPluginTimeout, - false, - "", - "", - config.DefaultHandshakeTimeout, ) assert.NotNil(t, server) assert.Zero(t, server.connections) @@ -177,7 +186,7 @@ func TestRunServer(t *testing.T) { // AuthenticationOk. assert.Equal(t, uint8(0x52), data[0]) - assert.Equal(t, 2, proxy.availableConnections.Size()) + assert.Equal(t, 2, proxy.AvailableConnections.Size()) assert.Equal(t, 1, proxy.busyConnections.Size()) // Terminate the connection. diff --git a/plugin/plugin_registry.go b/plugin/plugin_registry.go index 23aabec1..07c7ebc5 100644 --- a/plugin/plugin_registry.go +++ b/plugin/plugin_registry.go @@ -50,7 +50,6 @@ type IRegistry interface { LoadPlugins(ctx context.Context, plugins []config.Plugin, startTimeout time.Duration) RegisterHooks(ctx context.Context, pluginID sdkPlugin.Identifier) Apply(hookName string, result *v1.Struct) ([]*sdkAct.Output, bool) - ActRegistry() *act.Registry // Hook management IHook @@ -58,10 +57,10 @@ type IRegistry interface { type Registry struct { plugins pool.IPool - actRegistry *act.Registry + ActRegistry *act.Registry hooks map[v1.HookName]map[sdkPlugin.Priority]sdkPlugin.Method ctx context.Context //nolint:containedctx - devMode bool + DevMode bool Logger zerolog.Logger Compatibility config.CompatibilityPolicy @@ -73,22 +72,19 @@ var _ IRegistry = (*Registry)(nil) // NewRegistry creates a new plugin registry. func NewRegistry( ctx context.Context, - actRegistry *act.Registry, - compatibility config.CompatibilityPolicy, - logger zerolog.Logger, - devMode bool, + registry Registry, ) *Registry { regCtx, span := otel.Tracer(config.TracerName).Start(ctx, "Create new registry") defer span.End() return &Registry{ plugins: pool.NewPool(regCtx, config.EmptyPoolCapacity), - actRegistry: actRegistry, hooks: map[v1.HookName]map[sdkPlugin.Priority]sdkPlugin.Method{}, + ActRegistry: registry.ActRegistry, ctx: regCtx, - devMode: devMode, - Logger: logger, - Compatibility: compatibility, + DevMode: registry.DevMode, + Logger: registry.Logger, + Compatibility: registry.Compatibility, } } @@ -365,7 +361,7 @@ func (reg *Registry) Apply(hookName string, result *v1.Struct) ([]*sdkAct.Output // Apply policies to the signals. // The outputs contain the verdicts of the policies and their metadata. // And using this list, the caller can take further actions. - outputs := applyPolicies(hookName, signals, reg.Logger, reg.ActRegistry()) + outputs := applyPolicies(hookName, signals, reg.Logger, reg.ActRegistry) // If no policies are found, return a default output. // Note: this should never happen, as the default policy is always loaded. @@ -437,7 +433,7 @@ func (reg *Registry) LoadPlugins( } var secureConfig *goplugin.SecureConfig - if !reg.devMode { + if !reg.DevMode { // Checksum of the plugin. if plugin.ID.Checksum == "" { reg.Logger.Debug().Str("name", plugin.ID.Name).Msg( @@ -725,10 +721,3 @@ func (reg *Registry) RegisterHooks(ctx context.Context, pluginID sdkPlugin.Ident reg.AddHook(hookName, pluginImpl.Priority, hookMethod) } } - -// ActRegistry returns the act registry. -func (reg *Registry) ActRegistry() *act.Registry { - _, span := otel.Tracer(config.TracerName).Start(reg.ctx, "ActRegistry") - defer span.End() - return reg.actRegistry -} diff --git a/plugin/plugin_registry_test.go b/plugin/plugin_registry_test.go index 08c541a6..9bb81801 100644 --- a/plugin/plugin_registry_test.go +++ b/plugin/plugin_registry_test.go @@ -27,15 +27,22 @@ func NewPluginRegistry(t *testing.T) *Registry { } logger := logging.NewLogger(context.Background(), cfg) actRegistry := act.NewActRegistry( - act.BuiltinSignals(), act.BuiltinPolicies(), act.BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + act.Registry{ + Signals: act.BuiltinSignals(), + Policies: act.BuiltinPolicies(), + Actions: act.BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) reg := NewRegistry( context.Background(), - actRegistry, - config.Loose, - logger, - false, - ) + Registry{ + ActRegistry: actRegistry, + Compatibility: config.Loose, + Logger: logger, + }) return reg } @@ -129,15 +136,23 @@ func BenchmarkHookRun(b *testing.B) { } logger := logging.NewLogger(context.Background(), cfg) actRegistry := act.NewActRegistry( - act.BuiltinSignals(), act.BuiltinPolicies(), act.BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger) + act.Registry{ + Signals: act.BuiltinSignals(), + Policies: act.BuiltinPolicies(), + Actions: act.BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) + reg := NewRegistry( context.Background(), - actRegistry, - config.Loose, - logger, - false, - ) + Registry{ + ActRegistry: actRegistry, + Compatibility: config.Loose, + Logger: logger, + }) hookFunction := func( _ context.Context, args *v1.Struct, _ ...grpc.CallOption, ) (*v1.Struct, error) { diff --git a/plugin/utils_test.go b/plugin/utils_test.go index bb3bf8f1..690c492e 100644 --- a/plugin/utils_test.go +++ b/plugin/utils_test.go @@ -91,9 +91,15 @@ func Test_getSignals_empty(t *testing.T) { func Test_applyPolicies(t *testing.T) { logger := zerolog.Logger{} actRegistry := act.NewActRegistry( - act.BuiltinSignals(), act.BuiltinPolicies(), act.BuiltinActions(), - config.DefaultPolicy, config.DefaultPolicyTimeout, config.DefaultActionTimeout, logger, - ) + act.Registry{ + Signals: act.BuiltinSignals(), + Policies: act.BuiltinPolicies(), + Actions: act.BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) output := applyPolicies( "onTrafficFromClient", []sdkAct.Signal{*sdkAct.Passthrough()}, logger, actRegistry)