diff --git a/config/constants.go b/config/constants.go index 83d9dca4..dec79124 100644 --- a/config/constants.go +++ b/config/constants.go @@ -125,3 +125,8 @@ const ( DefaultRedisAddress = "localhost:6379" DefaultRedisChannel = "gatewayd-actions" ) + +// Load balancing strategies. +const ( + RoundRobinStrategy = "ROUND_ROBIN" +) diff --git a/errors/errors.go b/errors/errors.go index 82f9bb87..f4d8bc43 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -54,6 +54,7 @@ const ( ErrCodeConfigParseError ErrCodePublishAsyncAction ErrCodeLoadBalancerStrategyNotFound + ErrCodeNoProxiesAvailable ) var ( @@ -196,7 +197,11 @@ var ( } ErrLoadBalancerStrategyNotFound = &GatewayDError{ - ErrorCodeLoadBalancerStrategyNotFound, "The specified load balancer strategy does not exist.", nil, + ErrCodeLoadBalancerStrategyNotFound, "The specified load balancer strategy does not exist.", nil, + } + + ErrNoProxiesAvailable = &GatewayDError{ + ErrCodeNoProxiesAvailable, "No proxies available to select.", nil, } // Unwrapped errors. diff --git a/gatewayd.yaml b/gatewayd.yaml index 65356412..88f20265 100644 --- a/gatewayd.yaml +++ b/gatewayd.yaml @@ -83,14 +83,6 @@ servers: - "default-2" loadBalancer: strategy: ROUND_ROBIN - # Not yet implemented. - # loadBalancingRules: - # - condition: "default" - # percentages: - # - proxy: "default" - # percentage: 70 - # - proxy: "default-2" - # percentage: 30 enableTicker: False tickInterval: 5s # duration enableTLS: False diff --git a/network/constants.go b/network/constants.go deleted file mode 100644 index d9d916e7..00000000 --- a/network/constants.go +++ /dev/null @@ -1,6 +0,0 @@ -package network - -// Load balancing strategies. -const ( - RoundRobinStrategy = "ROUND_ROBIN" -) diff --git a/network/loadbalancer.go b/network/loadbalancer.go index a572bfe5..76c57d2b 100644 --- a/network/loadbalancer.go +++ b/network/loadbalancer.go @@ -1,16 +1,17 @@ package network import ( + "github.com/gatewayd-io/gatewayd/config" gerr "github.com/gatewayd-io/gatewayd/errors" ) type LoadBalancerStrategy interface { - GetNextProxy() IProxy + NextProxy() (IProxy, *gerr.GatewayDError) } func NewLoadBalancerStrategy(server *Server) (LoadBalancerStrategy, *gerr.GatewayDError) { switch server.LoadbalancerStrategyName { - case RoundRobinStrategy: + case config.RoundRobinStrategy: return NewRoundRobin(server), nil default: return nil, gerr.ErrLoadBalancerStrategyNotFound diff --git a/network/loadbalancer_test.go b/network/loadbalancer_test.go index ab988631..1d588b2a 100644 --- a/network/loadbalancer_test.go +++ b/network/loadbalancer_test.go @@ -4,12 +4,16 @@ import ( "errors" "testing" + "github.com/gatewayd-io/gatewayd/config" gerr "github.com/gatewayd-io/gatewayd/errors" ) +// TestNewLoadBalancerStrategy tests the NewLoadBalancerStrategy function to ensure it correctly +// initializes the load balancer strategy based on the strategy name provided in the server configuration. +// It covers both valid and invalid strategy names. func TestNewLoadBalancerStrategy(t *testing.T) { serverValid := &Server{ - LoadbalancerStrategyName: RoundRobinStrategy, + LoadbalancerStrategyName: config.RoundRobinStrategy, Proxies: []IProxy{MockProxy{}}, } diff --git a/network/network_helpers_test.go b/network/network_helpers_test.go index 618cade2..e8e44cf5 100644 --- a/network/network_helpers_test.go +++ b/network/network_helpers_test.go @@ -5,6 +5,7 @@ import ( "strings" "testing" + gerr "github.com/gatewayd-io/gatewayd/errors" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/require" @@ -16,6 +17,11 @@ type WriteBuffer struct { msgStart int } +// MockProxy implements the IProxy interface for testing purposes. +type MockProxy struct { + name string +} + // writeStartupMsg writes a PostgreSQL startup message to the buffer. func writeStartupMsg(buf *WriteBuffer, user, database, appName string) { // Write startup message header @@ -154,3 +160,51 @@ func CollectAndComparePrometheusMetrics(t *testing.T) { require.NoError(t, testutil.GatherAndCompare(prometheus.DefaultGatherer, strings.NewReader(want), metrics...)) } + +// Connect is a mock implementation of the Connect method in the IProxy interface. +func (m MockProxy) Connect(_ *ConnWrapper) *gerr.GatewayDError { + return nil +} + +// Disconnect is a mock implementation of the Disconnect method in the IProxy interface. +func (m MockProxy) Disconnect(_ *ConnWrapper) *gerr.GatewayDError { + return nil +} + +// PassThroughToServer is a mock implementation of the PassThroughToServer method in the IProxy interface. +func (m MockProxy) PassThroughToServer(_ *ConnWrapper, _ *Stack) *gerr.GatewayDError { + return nil +} + +// PassThroughToClient is a mock implementation of the PassThroughToClient method in the IProxy interface. +func (m MockProxy) PassThroughToClient(_ *ConnWrapper, _ *Stack) *gerr.GatewayDError { + return nil +} + +// IsHealthy is a mock implementation of the IsHealthy method in the IProxy interface. +func (m MockProxy) IsHealthy(_ *Client) (*Client, *gerr.GatewayDError) { + return nil, nil +} + +// IsExhausted is a mock implementation of the IsExhausted method in the IProxy interface. +func (m MockProxy) IsExhausted() bool { + return false +} + +// Shutdown is a mock implementation of the Shutdown method in the IProxy interface. +func (m MockProxy) Shutdown() {} + +// AvailableConnectionsString is a mock implementation of the AvailableConnectionsString method in the IProxy interface. +func (m MockProxy) AvailableConnectionsString() []string { + return nil +} + +// BusyConnectionsString is a mock implementation of the BusyConnectionsString method in the IProxy interface. +func (m MockProxy) BusyConnectionsString() []string { + return nil +} + +// GetName returns the name of the MockProxy. +func (m MockProxy) GetName() string { + return m.name +} diff --git a/network/round-robin.go b/network/round-robin.go index a18e6ca8..0057f432 100644 --- a/network/round-robin.go +++ b/network/round-robin.go @@ -1,6 +1,11 @@ package network -import "sync/atomic" +import ( + "errors" + "sync/atomic" + + gerr "github.com/gatewayd-io/gatewayd/errors" +) type RoundRobin struct { proxies []IProxy @@ -11,10 +16,11 @@ func NewRoundRobin(server *Server) *RoundRobin { return &RoundRobin{proxies: server.Proxies} } -func (r *RoundRobin) GetNextProxy() IProxy { +func (r *RoundRobin) NextProxy() (IProxy, *gerr.GatewayDError) { proxiesLen := uint32(len(r.proxies)) - + if proxiesLen == 0 { + return nil, gerr.ErrNoProxiesAvailable.Wrap(errors.New("proxy list is empty")) + } nextIndex := r.next.Add(1) - - return r.proxies[nextIndex%proxiesLen] + return r.proxies[nextIndex%proxiesLen], nil } diff --git a/network/round-robin_test.go b/network/round-robin_test.go index 0c00a309..ec430dc9 100644 --- a/network/round-robin_test.go +++ b/network/round-robin_test.go @@ -1,36 +1,13 @@ package network import ( + "math" "sync" "testing" - - gerr "github.com/gatewayd-io/gatewayd/errors" ) -// MockProxy implements IProxy interface for testing. -type MockProxy struct { - name string -} - -func (m MockProxy) Connect(_ *ConnWrapper) *gerr.GatewayDError { return nil } -func (m MockProxy) Disconnect(_ *ConnWrapper) *gerr.GatewayDError { return nil } -func (m MockProxy) PassThroughToServer(_ *ConnWrapper, _ *Stack) *gerr.GatewayDError { - return nil -} - -func (m MockProxy) PassThroughToClient(_ *ConnWrapper, _ *Stack) *gerr.GatewayDError { - return nil -} -func (m MockProxy) IsHealthy(_ *Client) (*Client, *gerr.GatewayDError) { return nil, nil } -func (m MockProxy) IsExhausted() bool { return false } -func (m MockProxy) Shutdown() {} -func (m MockProxy) AvailableConnectionsString() []string { return nil } -func (m MockProxy) BusyConnectionsString() []string { return nil } - -func (m MockProxy) GetName() string { - return m.name -} - +// TestNewRoundRobin tests the NewRoundRobin function to ensure that it correctly initializes +// the round-robin load balancer with the expected number of proxies. func TestNewRoundRobin(t *testing.T) { proxies := []IProxy{ MockProxy{name: "proxy1"}, @@ -45,7 +22,9 @@ func TestNewRoundRobin(t *testing.T) { } } -func TestRoundRobin_GetNextProxy(t *testing.T) { +// TestRoundRobin_NextProxy tests the NextProxy method of the round-robin load balancer to ensure +// that it returns proxies in the expected order. +func TestRoundRobin_NextProxy(t *testing.T) { proxies := []IProxy{ MockProxy{name: "proxy1"}, MockProxy{name: "proxy2"}, @@ -57,7 +36,10 @@ func TestRoundRobin_GetNextProxy(t *testing.T) { expectedOrder := []string{"proxy2", "proxy3", "proxy1", "proxy2", "proxy3"} for testIndex, expected := range expectedOrder { - proxy := roundRobin.GetNextProxy() + proxy, err := roundRobin.NextProxy() + if err != nil { + t.Fatalf("test %d: unexpected error from NextProxy: %v", testIndex, err) + } mockProxy, ok := proxy.(MockProxy) if !ok { t.Fatalf("test %d: expected proxy of type MockProxy, got %T", testIndex, proxy) @@ -68,6 +50,8 @@ func TestRoundRobin_GetNextProxy(t *testing.T) { } } +// TestRoundRobin_ConcurrentAccess tests the thread safety of the NextProxy method in the round-robin load balancer +// by invoking it concurrently from multiple goroutines and ensuring that the internal state is updated correctly. func TestRoundRobin_ConcurrentAccess(t *testing.T) { proxies := []IProxy{ MockProxy{name: "proxy1"}, @@ -84,7 +68,7 @@ func TestRoundRobin_ConcurrentAccess(t *testing.T) { for range numGoroutines { go func() { defer waitGroup.Done() - _ = roundRobin.GetNextProxy() + _, _ = roundRobin.NextProxy() }() } @@ -94,3 +78,40 @@ func TestRoundRobin_ConcurrentAccess(t *testing.T) { t.Errorf("expected next index to be %d, got %d", numGoroutines, nextIndex) } } + +// TestNextProxyOverflow verifies that the round-robin proxy selection correctly handles +// the overflow of the internal counter. It sets the counter to a value close to the maximum +// uint32 value and ensures that the proxy selection wraps around as expected when the +// counter overflows. +func TestNextProxyOverflow(t *testing.T) { + // Create a server with a few mock proxies + server := &Server{ + Proxies: []IProxy{ + &MockProxy{}, + &MockProxy{}, + &MockProxy{}, + }, + } + roundRobin := NewRoundRobin(server) + + // Set the next value to near the max uint32 value to force an overflow + roundRobin.next.Store(math.MaxUint32 - 1) + + // Call NextProxy multiple times to trigger the overflow + for range 4 { + proxy, err := roundRobin.NextProxy() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if proxy == nil { + t.Fatal("Expected a proxy, got nil") + } + } + + // After overflow, next value should wrap around + expectedNextValue := uint32(2) // (MaxUint32 - 1 + 4) % ProxiesLen = 2 + actualNextValue := roundRobin.next.Load() + if actualNextValue != expectedNextValue { + t.Fatalf("Expected next value to be %v, got %v", expectedNextValue, actualNextValue) + } +} diff --git a/network/server.go b/network/server.go index ac6bf98b..1d3c6b99 100644 --- a/network/server.go +++ b/network/server.go @@ -154,8 +154,13 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) { } span.AddEvent("Ran the OnOpening hooks") - // get proxy - proxy := s.loadbalancerStrategy.GetNextProxy() + // Attempt to retrieve the next proxy. + proxy, err := s.loadbalancerStrategy.NextProxy() + if err != nil { + span.RecordError(err) + s.Logger.Error().Err(err).Msg("failed to retrieve next proxy") + return nil, Close + } // Use the proxy to connect to the backend. Close the connection if the pool is exhausted. // This effectively get a connection from the pool and puts both the incoming and the server @@ -173,7 +178,7 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) { return nil, None } - // AssignConnectionToProxy + // Assign connection to proxy s.connectionToProxyMap[conn] = proxy // Run the OnOpened hooks. @@ -331,6 +336,7 @@ func (s *Server) OnTraffic(conn *ConnWrapper, stopConnection chan struct{}) Acti proxy, exists := server.GetProxyForConnection(conn) if !exists { server.Logger.Error().Msg("Failed to find proxy that matches the connection") + stopConnection <- struct{}{} break } @@ -353,6 +359,7 @@ func (s *Server) OnTraffic(conn *ConnWrapper, stopConnection chan struct{}) Acti proxy, exists := server.GetProxyForConnection(conn) if !exists { server.Logger.Error().Msg("Failed to find proxy that matches the connection") + stopConnection <- struct{}{} break } if err := proxy.PassThroughToClient(conn, stack); err != nil { diff --git a/network/server_test.go b/network/server_test.go index 94645208..c655afaa 100644 --- a/network/server_test.go +++ b/network/server_test.go @@ -119,7 +119,7 @@ func TestRunServer(t *testing.T) { PluginRegistry: pluginRegistry, PluginTimeout: config.DefaultPluginTimeout, HandshakeTimeout: config.DefaultHandshakeTimeout, - LoadbalancerStrategyName: RoundRobinStrategy, + LoadbalancerStrategyName: config.RoundRobinStrategy, }, ) assert.NotNil(t, server)