diff --git a/api/api_helpers_test.go b/api/api_helpers_test.go index 9e383333..1a015ce5 100644 --- a/api/api_helpers_test.go +++ b/api/api_helpers_test.go @@ -49,7 +49,7 @@ func getAPIConfig() *API { context.Background(), network.Server{ Logger: logger, - Proxy: defaultProxy, + Proxies: []network.IProxy{defaultProxy}, PluginRegistry: pluginReg, PluginTimeout: config.DefaultPluginTimeout, Network: "tcp", diff --git a/api/api_test.go b/api/api_test.go index 692a8d36..13048aac 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -333,7 +333,7 @@ func TestGetServers(t *testing.T) { Options: network.Option{ EnableTicker: false, }, - Proxy: proxy, + Proxies: []network.IProxy{proxy}, Logger: zerolog.Logger{}, PluginRegistry: pluginRegistry, PluginTimeout: config.DefaultPluginTimeout, diff --git a/api/healthcheck_test.go b/api/healthcheck_test.go index 2aa2a5ce..0efd03f1 100644 --- a/api/healthcheck_test.go +++ b/api/healthcheck_test.go @@ -69,7 +69,7 @@ func Test_Healthchecker(t *testing.T) { Options: network.Option{ EnableTicker: false, }, - Proxy: proxy, + Proxies: []network.IProxy{proxy}, Logger: zerolog.Logger{}, PluginRegistry: pluginRegistry, PluginTimeout: config.DefaultPluginTimeout, diff --git a/cmd/run.go b/cmd/run.go index 089ba50d..d3efff31 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -871,6 +871,18 @@ var runCmd = &cobra.Command{ // Create and initialize servers. for name, cfg := range conf.Global.Servers { logger := loggers[name] + + var serverProxies []network.IProxy + for _, proxyName := range cfg.Proxies { + proxy, exists := proxies[proxyName] + if !exists { + // This may occur if a proxy referenced in the server configuration does not exist. + logger.Error().Str("proxyName", proxyName).Msg("failed to find proxy configuration") + return + } + serverProxies = append(serverProxies, proxy) + } + servers[name] = network.NewServer( runCtx, network.Server{ @@ -885,14 +897,15 @@ var runCmd = &cobra.Command{ // Can be used to send keepalive messages to the client. EnableTicker: cfg.EnableTicker, }, - Proxy: proxies[name], - Logger: logger, - PluginRegistry: pluginRegistry, - PluginTimeout: conf.Plugin.Timeout, - EnableTLS: cfg.EnableTLS, - CertFile: cfg.CertFile, - KeyFile: cfg.KeyFile, - HandshakeTimeout: cfg.HandshakeTimeout, + Proxies: serverProxies, + Logger: logger, + PluginRegistry: pluginRegistry, + PluginTimeout: conf.Plugin.Timeout, + EnableTLS: cfg.EnableTLS, + CertFile: cfg.CertFile, + KeyFile: cfg.KeyFile, + HandshakeTimeout: cfg.HandshakeTimeout, + LoadbalancerStrategyName: cfg.LoadBalancer.Strategy, }, ) diff --git a/config/config.go b/config/config.go index 0903fe58..b197af3b 100644 --- a/config/config.go +++ b/config/config.go @@ -160,6 +160,8 @@ func (c *Config) LoadDefaults(ctx context.Context) *gerr.GatewayDError { CertFile: "", KeyFile: "", HandshakeTimeout: DefaultHandshakeTimeout, + Proxies: []string{Default}, + LoadBalancer: LoadBalancer{Strategy: DefaultLoadBalancerStrategy}, } c.globalDefaults = GlobalConfig{ @@ -413,7 +415,7 @@ func (c *Config) ValidateGlobalConfig(ctx context.Context) *gerr.GatewayDError { } var errors []*gerr.GatewayDError - configObjects := []string{"loggers", "metrics", "clients", "pools", "proxies", "servers"} + configObjects := []string{"loggers", "metrics", "servers"} sort.Strings(configObjects) var seenConfigObjects []string @@ -441,7 +443,9 @@ func (c *Config) ValidateGlobalConfig(ctx context.Context) *gerr.GatewayDError { seenConfigObjects = append(seenConfigObjects, "metrics") } + clientConfigGroups := make(map[string]bool) for configGroup := range globalConfig.Clients { + clientConfigGroups[configGroup] = true if globalConfig.Clients[configGroup] == nil { err := fmt.Errorf("\"clients.%s\" is nil or empty", configGroup) span.RecordError(err) @@ -449,10 +453,6 @@ func (c *Config) ValidateGlobalConfig(ctx context.Context) *gerr.GatewayDError { } } - if len(globalConfig.Clients) > 1 { - seenConfigObjects = append(seenConfigObjects, "clients") - } - for configGroup := range globalConfig.Pools { if globalConfig.Pools[configGroup] == nil { err := fmt.Errorf("\"pools.%s\" is nil or empty", configGroup) @@ -461,10 +461,6 @@ func (c *Config) ValidateGlobalConfig(ctx context.Context) *gerr.GatewayDError { } } - if len(globalConfig.Pools) > 1 { - seenConfigObjects = append(seenConfigObjects, "pools") - } - for configGroup := range globalConfig.Proxies { if globalConfig.Proxies[configGroup] == nil { err := fmt.Errorf("\"proxies.%s\" is nil or empty", configGroup) @@ -473,10 +469,6 @@ func (c *Config) ValidateGlobalConfig(ctx context.Context) *gerr.GatewayDError { } } - if len(globalConfig.Proxies) > 1 { - seenConfigObjects = append(seenConfigObjects, "proxies") - } - for configGroup := range globalConfig.Servers { if globalConfig.Servers[configGroup] == nil { err := fmt.Errorf("\"servers.%s\" is nil or empty", configGroup) @@ -489,6 +481,50 @@ func (c *Config) ValidateGlobalConfig(ctx context.Context) *gerr.GatewayDError { seenConfigObjects = append(seenConfigObjects, "servers") } + // ValidateClientsPoolsProxies checks if all configGroups in globalConfig.Pools and globalConfig.Proxies + // are referenced in globalConfig.Clients. + if len(globalConfig.Clients) != len(globalConfig.Pools) || len(globalConfig.Clients) != len(globalConfig.Proxies) { + err := goerrors.New("clients, pools, and proxies do not have the same number of objects") + span.RecordError(err) + errors = append(errors, gerr.ErrValidationFailed.Wrap(err)) + } + + // Check if all proxies are referenced in client configuration + for configGroup := range globalConfig.Proxies { + if !clientConfigGroups[configGroup] { + err := fmt.Errorf(`"proxies.%s" not referenced in client configuration`, configGroup) + span.RecordError(err) + errors = append(errors, gerr.ErrValidationFailed.Wrap(err)) + } + } + + // Check if all pools are referenced in client configuration + for configGroup := range globalConfig.Pools { + if !clientConfigGroups[configGroup] { + err := fmt.Errorf(`"pools.%s" not referenced in client configuration`, configGroup) + span.RecordError(err) + errors = append(errors, gerr.ErrValidationFailed.Wrap(err)) + } + } + + // Each server configuration should have at least one proxy defined. + // Each proxy in the server configuration should be referenced in proxies configuration. + for serverName, server := range globalConfig.Servers { + if len(server.Proxies) == 0 { + err := fmt.Errorf(`"servers.%s" has no proxies defined`, serverName) + span.RecordError(err) + errors = append(errors, gerr.ErrValidationFailed.Wrap(err)) + continue + } + for _, proxyName := range server.Proxies { + if _, exists := c.globalDefaults.Proxies[proxyName]; !exists { + err := fmt.Errorf(`"servers.%s" references a non-existent proxy "%s"`, serverName, proxyName) + span.RecordError(err) + errors = append(errors, gerr.ErrValidationFailed.Wrap(err)) + } + } + } + sort.Strings(seenConfigObjects) if len(seenConfigObjects) > 0 && !reflect.DeepEqual(configObjects, seenConfigObjects) { diff --git a/config/constants.go b/config/constants.go index 4591de2c..dec79124 100644 --- a/config/constants.go +++ b/config/constants.go @@ -89,10 +89,11 @@ const ( DefaultHealthCheckPeriod = 60 * time.Second // This must match PostgreSQL authentication timeout. // Server constants. - DefaultListenNetwork = "tcp" - DefaultListenAddress = "0.0.0.0:15432" - DefaultTickInterval = 5 * time.Second - DefaultHandshakeTimeout = 5 * time.Second + DefaultListenNetwork = "tcp" + DefaultListenAddress = "0.0.0.0:15432" + DefaultTickInterval = 5 * time.Second + DefaultHandshakeTimeout = 5 * time.Second + DefaultLoadBalancerStrategy = "ROUND_ROBIN" // Utility constants. DefaultSeed = 1000 @@ -124,3 +125,8 @@ const ( DefaultRedisAddress = "localhost:6379" DefaultRedisChannel = "gatewayd-actions" ) + +// Load balancing strategies. +const ( + RoundRobinStrategy = "ROUND_ROBIN" +) diff --git a/config/types.go b/config/types.go index 7a49fd63..94809b1b 100644 --- a/config/types.go +++ b/config/types.go @@ -96,6 +96,10 @@ type Proxy struct { HealthCheckPeriod time.Duration `json:"healthCheckPeriod" jsonschema:"oneof_type=string;integer"` } +type LoadBalancer struct { + Strategy string `json:"strategy"` +} + type Server struct { EnableTicker bool `json:"enableTicker"` TickInterval time.Duration `json:"tickInterval" jsonschema:"oneof_type=string;integer"` @@ -105,6 +109,8 @@ type Server struct { CertFile string `json:"certFile"` KeyFile string `json:"keyFile"` HandshakeTimeout time.Duration `json:"handshakeTimeout" jsonschema:"oneof_type=string;integer"` + Proxies []string `json:"proxies"` + LoadBalancer LoadBalancer `json:"loadBalancer"` } type API struct { diff --git a/errors/errors.go b/errors/errors.go index c9868159..f4d8bc43 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -53,6 +53,8 @@ const ( ErrCodeMsgEncodeError ErrCodeConfigParseError ErrCodePublishAsyncAction + ErrCodeLoadBalancerStrategyNotFound + ErrCodeNoProxiesAvailable ) var ( @@ -194,6 +196,14 @@ var ( ErrCodePublishAsyncAction, "error publishing async action", nil, } + ErrLoadBalancerStrategyNotFound = &GatewayDError{ + ErrCodeLoadBalancerStrategyNotFound, "The specified load balancer strategy does not exist.", nil, + } + + ErrNoProxiesAvailable = &GatewayDError{ + ErrCodeNoProxiesAvailable, "No proxies available to select.", nil, + } + // Unwrapped errors. ErrLoggerRequired = errors.New("terminate action requires a logger parameter") ) diff --git a/gatewayd.yaml b/gatewayd.yaml index d407f1a7..88f20265 100644 --- a/gatewayd.yaml +++ b/gatewayd.yaml @@ -46,19 +46,43 @@ clients: backoff: 1s # duration backoffMultiplier: 2.0 # 0 means no backoff disableBackoffCaps: false + default-2: + network: tcp + address: localhost:5433 + tcpKeepAlive: False + tcpKeepAlivePeriod: 30s # duration + receiveChunkSize: 8192 + receiveDeadline: 0s # duration, 0ms/0s means no deadline + receiveTimeout: 0s # duration, 0ms/0s means no timeout + sendDeadline: 0s # duration, 0ms/0s means no deadline + dialTimeout: 60s # duration + # Retry configuration + retries: 3 # 0 means no retry and fail immediately on the first attempt + backoff: 1s # duration + backoffMultiplier: 2.0 # 0 means no backoff + disableBackoffCaps: false pools: default: size: 10 + default-2: + size: 10 proxies: default: healthCheckPeriod: 60s # duration + default-2: + healthCheckPeriod: 60s # duration servers: default: network: tcp address: 0.0.0.0:15432 + proxies: + - "default" + - "default-2" + loadBalancer: + strategy: ROUND_ROBIN enableTicker: False tickInterval: 5s # duration enableTLS: False diff --git a/network/loadbalancer.go b/network/loadbalancer.go new file mode 100644 index 00000000..76c57d2b --- /dev/null +++ b/network/loadbalancer.go @@ -0,0 +1,19 @@ +package network + +import ( + "github.com/gatewayd-io/gatewayd/config" + gerr "github.com/gatewayd-io/gatewayd/errors" +) + +type LoadBalancerStrategy interface { + NextProxy() (IProxy, *gerr.GatewayDError) +} + +func NewLoadBalancerStrategy(server *Server) (LoadBalancerStrategy, *gerr.GatewayDError) { + switch server.LoadbalancerStrategyName { + case config.RoundRobinStrategy: + return NewRoundRobin(server), nil + default: + return nil, gerr.ErrLoadBalancerStrategyNotFound + } +} diff --git a/network/loadbalancer_test.go b/network/loadbalancer_test.go new file mode 100644 index 00000000..1d588b2a --- /dev/null +++ b/network/loadbalancer_test.go @@ -0,0 +1,44 @@ +package network + +import ( + "errors" + "testing" + + "github.com/gatewayd-io/gatewayd/config" + gerr "github.com/gatewayd-io/gatewayd/errors" +) + +// TestNewLoadBalancerStrategy tests the NewLoadBalancerStrategy function to ensure it correctly +// initializes the load balancer strategy based on the strategy name provided in the server configuration. +// It covers both valid and invalid strategy names. +func TestNewLoadBalancerStrategy(t *testing.T) { + serverValid := &Server{ + LoadbalancerStrategyName: config.RoundRobinStrategy, + Proxies: []IProxy{MockProxy{}}, + } + + // Test case 1: Valid strategy name + strategy, err := NewLoadBalancerStrategy(serverValid) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + _, ok := strategy.(*RoundRobin) + if !ok { + t.Errorf("Expected strategy to be of type RoundRobin") + } + + // Test case 2: InValid strategy name + serverInvalid := &Server{ + LoadbalancerStrategyName: "InvalidStrategy", + Proxies: []IProxy{MockProxy{}}, + } + + strategy, err = NewLoadBalancerStrategy(serverInvalid) + if !errors.Is(err, gerr.ErrLoadBalancerStrategyNotFound) { + t.Errorf("Expected ErrLoadBalancerStrategyNotFound, got %v", err) + } + if strategy != nil { + t.Errorf("Expected strategy to be nil for invalid strategy name") + } +} diff --git a/network/network_helpers_test.go b/network/network_helpers_test.go index 618cade2..e8e44cf5 100644 --- a/network/network_helpers_test.go +++ b/network/network_helpers_test.go @@ -5,6 +5,7 @@ import ( "strings" "testing" + gerr "github.com/gatewayd-io/gatewayd/errors" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/require" @@ -16,6 +17,11 @@ type WriteBuffer struct { msgStart int } +// MockProxy implements the IProxy interface for testing purposes. +type MockProxy struct { + name string +} + // writeStartupMsg writes a PostgreSQL startup message to the buffer. func writeStartupMsg(buf *WriteBuffer, user, database, appName string) { // Write startup message header @@ -154,3 +160,51 @@ func CollectAndComparePrometheusMetrics(t *testing.T) { require.NoError(t, testutil.GatherAndCompare(prometheus.DefaultGatherer, strings.NewReader(want), metrics...)) } + +// Connect is a mock implementation of the Connect method in the IProxy interface. +func (m MockProxy) Connect(_ *ConnWrapper) *gerr.GatewayDError { + return nil +} + +// Disconnect is a mock implementation of the Disconnect method in the IProxy interface. +func (m MockProxy) Disconnect(_ *ConnWrapper) *gerr.GatewayDError { + return nil +} + +// PassThroughToServer is a mock implementation of the PassThroughToServer method in the IProxy interface. +func (m MockProxy) PassThroughToServer(_ *ConnWrapper, _ *Stack) *gerr.GatewayDError { + return nil +} + +// PassThroughToClient is a mock implementation of the PassThroughToClient method in the IProxy interface. +func (m MockProxy) PassThroughToClient(_ *ConnWrapper, _ *Stack) *gerr.GatewayDError { + return nil +} + +// IsHealthy is a mock implementation of the IsHealthy method in the IProxy interface. +func (m MockProxy) IsHealthy(_ *Client) (*Client, *gerr.GatewayDError) { + return nil, nil +} + +// IsExhausted is a mock implementation of the IsExhausted method in the IProxy interface. +func (m MockProxy) IsExhausted() bool { + return false +} + +// Shutdown is a mock implementation of the Shutdown method in the IProxy interface. +func (m MockProxy) Shutdown() {} + +// AvailableConnectionsString is a mock implementation of the AvailableConnectionsString method in the IProxy interface. +func (m MockProxy) AvailableConnectionsString() []string { + return nil +} + +// BusyConnectionsString is a mock implementation of the BusyConnectionsString method in the IProxy interface. +func (m MockProxy) BusyConnectionsString() []string { + return nil +} + +// GetName returns the name of the MockProxy. +func (m MockProxy) GetName() string { + return m.name +} diff --git a/network/round-robin.go b/network/round-robin.go new file mode 100644 index 00000000..0057f432 --- /dev/null +++ b/network/round-robin.go @@ -0,0 +1,26 @@ +package network + +import ( + "errors" + "sync/atomic" + + gerr "github.com/gatewayd-io/gatewayd/errors" +) + +type RoundRobin struct { + proxies []IProxy + next atomic.Uint32 +} + +func NewRoundRobin(server *Server) *RoundRobin { + return &RoundRobin{proxies: server.Proxies} +} + +func (r *RoundRobin) NextProxy() (IProxy, *gerr.GatewayDError) { + proxiesLen := uint32(len(r.proxies)) + if proxiesLen == 0 { + return nil, gerr.ErrNoProxiesAvailable.Wrap(errors.New("proxy list is empty")) + } + nextIndex := r.next.Add(1) + return r.proxies[nextIndex%proxiesLen], nil +} diff --git a/network/round-robin_test.go b/network/round-robin_test.go new file mode 100644 index 00000000..ec430dc9 --- /dev/null +++ b/network/round-robin_test.go @@ -0,0 +1,117 @@ +package network + +import ( + "math" + "sync" + "testing" +) + +// TestNewRoundRobin tests the NewRoundRobin function to ensure that it correctly initializes +// the round-robin load balancer with the expected number of proxies. +func TestNewRoundRobin(t *testing.T) { + proxies := []IProxy{ + MockProxy{name: "proxy1"}, + MockProxy{name: "proxy2"}, + MockProxy{name: "proxy3"}, + } + server := &Server{Proxies: proxies} + rr := NewRoundRobin(server) + + if len(rr.proxies) != len(proxies) { + t.Errorf("expected %d proxies, got %d", len(proxies), len(rr.proxies)) + } +} + +// TestRoundRobin_NextProxy tests the NextProxy method of the round-robin load balancer to ensure +// that it returns proxies in the expected order. +func TestRoundRobin_NextProxy(t *testing.T) { + proxies := []IProxy{ + MockProxy{name: "proxy1"}, + MockProxy{name: "proxy2"}, + MockProxy{name: "proxy3"}, + } + server := &Server{Proxies: proxies} + roundRobin := NewRoundRobin(server) + + expectedOrder := []string{"proxy2", "proxy3", "proxy1", "proxy2", "proxy3"} + + for testIndex, expected := range expectedOrder { + proxy, err := roundRobin.NextProxy() + if err != nil { + t.Fatalf("test %d: unexpected error from NextProxy: %v", testIndex, err) + } + mockProxy, ok := proxy.(MockProxy) + if !ok { + t.Fatalf("test %d: expected proxy of type MockProxy, got %T", testIndex, proxy) + } + if mockProxy.GetName() != expected { + t.Errorf("test %d: expected proxy name %s, got %s", testIndex, expected, mockProxy.GetName()) + } + } +} + +// TestRoundRobin_ConcurrentAccess tests the thread safety of the NextProxy method in the round-robin load balancer +// by invoking it concurrently from multiple goroutines and ensuring that the internal state is updated correctly. +func TestRoundRobin_ConcurrentAccess(t *testing.T) { + proxies := []IProxy{ + MockProxy{name: "proxy1"}, + MockProxy{name: "proxy2"}, + MockProxy{name: "proxy3"}, + } + server := &Server{Proxies: proxies} + roundRobin := NewRoundRobin(server) + + var waitGroup sync.WaitGroup + numGoroutines := 100 + waitGroup.Add(numGoroutines) + + for range numGoroutines { + go func() { + defer waitGroup.Done() + _, _ = roundRobin.NextProxy() + }() + } + + waitGroup.Wait() + nextIndex := roundRobin.next.Load() + if nextIndex != uint32(numGoroutines) { + t.Errorf("expected next index to be %d, got %d", numGoroutines, nextIndex) + } +} + +// TestNextProxyOverflow verifies that the round-robin proxy selection correctly handles +// the overflow of the internal counter. It sets the counter to a value close to the maximum +// uint32 value and ensures that the proxy selection wraps around as expected when the +// counter overflows. +func TestNextProxyOverflow(t *testing.T) { + // Create a server with a few mock proxies + server := &Server{ + Proxies: []IProxy{ + &MockProxy{}, + &MockProxy{}, + &MockProxy{}, + }, + } + roundRobin := NewRoundRobin(server) + + // Set the next value to near the max uint32 value to force an overflow + roundRobin.next.Store(math.MaxUint32 - 1) + + // Call NextProxy multiple times to trigger the overflow + for range 4 { + proxy, err := roundRobin.NextProxy() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if proxy == nil { + t.Fatal("Expected a proxy, got nil") + } + } + + // After overflow, next value should wrap around + expectedNextValue := uint32(2) // (MaxUint32 - 1 + 4) % ProxiesLen = 2 + actualNextValue := roundRobin.next.Load() + if actualNextValue != expectedNextValue { + t.Fatalf("Expected next value to be %v, got %v", expectedNextValue, actualNextValue) + } +} diff --git a/network/server.go b/network/server.go index 7025f584..1d3c6b99 100644 --- a/network/server.go +++ b/network/server.go @@ -48,7 +48,7 @@ type IServer interface { } type Server struct { - Proxy IProxy + Proxies []IProxy Logger zerolog.Logger PluginRegistry *plugin.Registry ctx context.Context //nolint:containedctx @@ -73,6 +73,11 @@ type Server struct { connections uint32 running *atomic.Bool stopServer chan struct{} + + // loadbalancer + loadbalancerStrategy LoadBalancerStrategy + LoadbalancerStrategyName string + connectionToProxyMap map[*ConnWrapper]IProxy } var _ IServer = (*Server)(nil) @@ -149,10 +154,18 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) { } span.AddEvent("Ran the OnOpening hooks") + // Attempt to retrieve the next proxy. + proxy, err := s.loadbalancerStrategy.NextProxy() + if err != nil { + span.RecordError(err) + s.Logger.Error().Err(err).Msg("failed to retrieve next proxy") + return nil, Close + } + // Use the proxy to connect to the backend. Close the connection if the pool is exhausted. // This effectively get a connection from the pool and puts both the incoming and the server // connections in the pool of the busy connections. - if err := s.Proxy.Connect(conn); err != nil { + if err := proxy.Connect(conn); err != nil { if errors.Is(err, gerr.ErrPoolExhausted) { span.RecordError(err) return nil, Close @@ -165,6 +178,9 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) { return nil, None } + // Assign connection to proxy + s.connectionToProxyMap[conn] = proxy + // Run the OnOpened hooks. pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), s.PluginTimeout) defer cancel() @@ -225,15 +241,27 @@ func (s *Server) OnClose(conn *ConnWrapper, err error) Action { span.AddEvent("Shutting down the server") return Shutdown } + + // Find the proxy associated with the given connection + proxy, exists := s.GetProxyForConnection(conn) + if !exists { + // Log an error and return Close if no matching proxy is found + s.Logger.Error().Msg("Failed to find proxy to disconnect it") + return Close + } + // Disconnect the connection from the proxy. This effectively removes the mapping between // the incoming and the server connections in the pool of the busy connections and either // recycles or disconnects the connections. - if err := s.Proxy.Disconnect(conn); err != nil { + if err := proxy.Disconnect(conn); err != nil { s.Logger.Error().Err(err).Msg("Failed to disconnect the server connection") span.RecordError(err) return Close } + // remove a connection from proxy connention map + s.RemoveConnectionFromMap(conn) + if conn.IsTLSEnabled() { metrics.TLSConnections.Dec() } @@ -303,7 +331,16 @@ func (s *Server) OnTraffic(conn *ConnWrapper, stopConnection chan struct{}) Acti go func(server *Server, conn *ConnWrapper, stopConnection chan struct{}, stack *Stack) { for { server.Logger.Trace().Msg("Passing through traffic from client to server") - if err := server.Proxy.PassThroughToServer(conn, stack); err != nil { + + // Find the proxy associated with the given connection + proxy, exists := server.GetProxyForConnection(conn) + if !exists { + server.Logger.Error().Msg("Failed to find proxy that matches the connection") + stopConnection <- struct{}{} + break + } + + if err := proxy.PassThroughToServer(conn, stack); err != nil { server.Logger.Trace().Err(err).Msg("Failed to pass through traffic") span.RecordError(err) stopConnection <- struct{}{} @@ -317,7 +354,15 @@ func (s *Server) OnTraffic(conn *ConnWrapper, stopConnection chan struct{}) Acti go func(server *Server, conn *ConnWrapper, stopConnection chan struct{}, stack *Stack) { for { server.Logger.Trace().Msg("Passing through traffic from server to client") - if err := server.Proxy.PassThroughToClient(conn, stack); err != nil { + + // Find the proxy associated with the given connection + proxy, exists := server.GetProxyForConnection(conn) + if !exists { + server.Logger.Error().Msg("Failed to find proxy that matches the connection") + stopConnection <- struct{}{} + break + } + if err := proxy.PassThroughToClient(conn, stack); err != nil { server.Logger.Trace().Err(err).Msg("Failed to pass through traffic") span.RecordError(err) stopConnection <- struct{}{} @@ -352,8 +397,10 @@ func (s *Server) OnShutdown() { } span.AddEvent("Ran the OnShutdown hooks") - // Shutdown the proxy. - s.Proxy.Shutdown() + // Shutdown proxies. + for _, proxy := range s.Proxies { + proxy.Shutdown() + } // Set the server status to stopped. This is used to shutdown the server gracefully in OnClose. s.mu.Lock() @@ -573,8 +620,10 @@ func (s *Server) Shutdown() { _, span := otel.Tracer("gatewayd").Start(s.ctx, "Shutdown") defer span.End() - // Shutdown the proxy. - s.Proxy.Shutdown() + for _, proxy := range s.Proxies { + // Shutdown the proxy. + proxy.Shutdown() + } // Set the server status to stopped. This is used to shutdown the server gracefully in OnClose. s.mu.Lock() @@ -627,24 +676,26 @@ func NewServer( // Create the server. server := Server{ - ctx: serverCtx, - Network: srv.Network, - Address: srv.Address, - Options: srv.Options, - TickInterval: srv.TickInterval, - Status: config.Stopped, - EnableTLS: srv.EnableTLS, - CertFile: srv.CertFile, - KeyFile: srv.KeyFile, - HandshakeTimeout: srv.HandshakeTimeout, - Proxy: srv.Proxy, - Logger: srv.Logger, - PluginRegistry: srv.PluginRegistry, - PluginTimeout: srv.PluginTimeout, - mu: &sync.RWMutex{}, - connections: 0, - running: &atomic.Bool{}, - stopServer: make(chan struct{}), + ctx: serverCtx, + Network: srv.Network, + Address: srv.Address, + Options: srv.Options, + TickInterval: srv.TickInterval, + Status: config.Stopped, + EnableTLS: srv.EnableTLS, + CertFile: srv.CertFile, + KeyFile: srv.KeyFile, + HandshakeTimeout: srv.HandshakeTimeout, + Proxies: srv.Proxies, + Logger: srv.Logger, + PluginRegistry: srv.PluginRegistry, + PluginTimeout: srv.PluginTimeout, + mu: &sync.RWMutex{}, + connections: 0, + running: &atomic.Bool{}, + stopServer: make(chan struct{}), + connectionToProxyMap: make(map[*ConnWrapper]IProxy), + LoadbalancerStrategyName: srv.LoadbalancerStrategyName, } // Try to resolve the address and log an error if it can't be resolved. @@ -664,6 +715,12 @@ func NewServer( "GatewayD is listening on an unresolved address") } + st, err := NewLoadBalancerStrategy(&server) + if err != nil { + srv.Logger.Error().Err(err).Msg("Failed to create a loadbalancer strategy") + } + server.loadbalancerStrategy = st + return &server } @@ -673,3 +730,14 @@ func (s *Server) CountConnections() int { defer s.mu.RUnlock() return int(s.connections) } + +// GetProxyForConnection returns the proxy associated with the given connection. +func (s *Server) GetProxyForConnection(conn *ConnWrapper) (IProxy, bool) { + proxy, exists := s.connectionToProxyMap[conn] + return proxy, exists +} + +// RemoveConnectionFromMap removes the given connection from the connection-to-proxy map. +func (s *Server) RemoveConnectionFromMap(conn *ConnWrapper) { + delete(s.connectionToProxyMap, conn) +} diff --git a/network/server_test.go b/network/server_test.go index b090e126..c655afaa 100644 --- a/network/server_test.go +++ b/network/server_test.go @@ -114,11 +114,12 @@ func TestRunServer(t *testing.T) { Options: Option{ EnableTicker: true, }, - Proxy: proxy, - Logger: logger, - PluginRegistry: pluginRegistry, - PluginTimeout: config.DefaultPluginTimeout, - HandshakeTimeout: config.DefaultHandshakeTimeout, + Proxies: []IProxy{proxy}, + Logger: logger, + PluginRegistry: pluginRegistry, + PluginTimeout: config.DefaultPluginTimeout, + HandshakeTimeout: config.DefaultHandshakeTimeout, + LoadbalancerStrategyName: config.RoundRobinStrategy, }, ) assert.NotNil(t, server)