From 54d0c695656f73f775265438e340ce31526f4230 Mon Sep 17 00:00:00 2001 From: sina Date: Mon, 19 Aug 2024 16:25:54 +0200 Subject: [PATCH] Add Consistent Hashing Load Balancing Strategy (#592) * Added ConsistentHash * changed net conn into Iconnwrapper * added test cases for onsistentHash * fixed lint issues * replace RWMutex to Mutex and added test case for concurency access * added github.com/spaolacci/murmur3 into depguard --- .golangci.yaml | 1 + cmd/run.go | 21 ++-- config/types.go | 5 + gatewayd.yaml | 2 + go.mod | 2 + go.sum | 3 + network/consistenthash.go | 87 ++++++++++++++++ network/consistenthash_test.go | 154 +++++++++++++++++++++++++++++ network/loadbalancer.go | 16 ++- network/network_helpers_test.go | 67 +++++++++++++ network/random.go | 2 +- network/random_test.go | 10 +- network/roundrobin.go | 2 +- network/roundrobin_test.go | 6 +- network/server.go | 54 +++++----- network/weightedroundrobin.go | 2 +- network/weightedroundrobin_test.go | 4 +- 17 files changed, 385 insertions(+), 53 deletions(-) create mode 100644 network/consistenthash.go create mode 100644 network/consistenthash_test.go diff --git a/.golangci.yaml b/.golangci.yaml index 28438ba8..d1b2553e 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -65,6 +65,7 @@ linters-settings: - "golang.org/x/text/cases" - "golang.org/x/text/language" - "github.com/redis/go-redis/v9" + - "github.com/spaolacci/murmur3" test: files: - $test diff --git a/cmd/run.go b/cmd/run.go index ff35644f..31bfe75b 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -910,16 +910,17 @@ var runCmd = &cobra.Command{ // Can be used to send keepalive messages to the client. EnableTicker: cfg.EnableTicker, }, - Proxies: serverProxies, - Logger: logger, - PluginRegistry: pluginRegistry, - PluginTimeout: conf.Plugin.Timeout, - EnableTLS: cfg.EnableTLS, - CertFile: cfg.CertFile, - KeyFile: cfg.KeyFile, - HandshakeTimeout: cfg.HandshakeTimeout, - LoadbalancerStrategyName: cfg.LoadBalancer.Strategy, - LoadbalancerRules: cfg.LoadBalancer.LoadBalancingRules, + Proxies: serverProxies, + Logger: logger, + PluginRegistry: pluginRegistry, + PluginTimeout: conf.Plugin.Timeout, + EnableTLS: cfg.EnableTLS, + CertFile: cfg.CertFile, + KeyFile: cfg.KeyFile, + HandshakeTimeout: cfg.HandshakeTimeout, + LoadbalancerStrategyName: cfg.LoadBalancer.Strategy, + LoadbalancerRules: cfg.LoadBalancer.LoadBalancingRules, + LoadbalancerConsistentHash: cfg.LoadBalancer.ConsistentHash, }, ) diff --git a/config/types.go b/config/types.go index 47e296c5..797bdf96 100644 --- a/config/types.go +++ b/config/types.go @@ -106,9 +106,14 @@ type LoadBalancingRule struct { Distribution []Distribution `json:"distribution"` } +type ConsistentHash struct { + UseSourceIP bool `json:"useSourceIp"` +} + type LoadBalancer struct { Strategy string `json:"strategy"` LoadBalancingRules []LoadBalancingRule `json:"loadBalancingRules"` + ConsistentHash *ConsistentHash `json:"consistentHash,omitempty"` } type Server struct { diff --git a/gatewayd.yaml b/gatewayd.yaml index dbf81067..fbb8452e 100644 --- a/gatewayd.yaml +++ b/gatewayd.yaml @@ -84,6 +84,8 @@ servers: loadBalancer: # Load balancer strategies can be found in config/constants.go strategy: ROUND_ROBIN # ROUND_ROBIN, RANDOM, WEIGHTED_ROUND_ROBIN + consistentHash: + useSourceIp: true # Optional configuration for strategies that support rules (e.g., WEIGHTED_ROUND_ROBIN) # loadBalancingRules: # - condition: "DEFAULT" # Currently, only the "DEFAULT" condition is supported diff --git a/go.mod b/go.mod index fe69fd0e..ef680e9c 100644 --- a/go.mod +++ b/go.mod @@ -126,8 +126,10 @@ require ( github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/skeema/knownhosts v1.2.1 // indirect + github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/spf13/afero v1.11.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/tetratelabs/wazero v1.7.2 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect diff --git a/go.sum b/go.sum index 30db0f5d..e4f0c611 100644 --- a/go.sum +++ b/go.sum @@ -477,6 +477,8 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/skeema/knownhosts v1.2.1 h1:SHWdIUa82uGZz+F+47k8SY4QhhI291cXCpopT1lK2AQ= github.com/skeema/knownhosts v1.2.1/go.mod h1:xYbVRSPxqBZFrdmDyMmsOs+uX1UZC3nTN3ThzgDxUwo= +github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= +github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= @@ -489,6 +491,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= diff --git a/network/consistenthash.go b/network/consistenthash.go new file mode 100644 index 00000000..0eef8419 --- /dev/null +++ b/network/consistenthash.go @@ -0,0 +1,87 @@ +package network + +import ( + "fmt" + "net" + "sync" + + gerr "github.com/gatewayd-io/gatewayd/errors" + "github.com/spaolacci/murmur3" +) + +// ConsistentHash implements a load balancing strategy based on consistent hashing. +// It routes client connections to specific proxies by hashing the client's IP address or the full connection address. +type ConsistentHash struct { + originalStrategy LoadBalancerStrategy + useSourceIP bool + hashMap map[uint64]IProxy + mu sync.Mutex +} + +// NewConsistentHash creates a new ConsistentHash instance. It requires a server configuration and an original +// load balancing strategy. The consistent hash can use either the source IP or the full connection address +// as the key for hashing. +func NewConsistentHash(server *Server, originalStrategy LoadBalancerStrategy) *ConsistentHash { + return &ConsistentHash{ + originalStrategy: originalStrategy, + useSourceIP: server.LoadbalancerConsistentHash.UseSourceIP, + hashMap: make(map[uint64]IProxy), + } +} + +// NextProxy selects the appropriate proxy for a given client connection. It first tries to find an existing +// proxy in the hash map based on the hashed key (either the source IP or the full address). If no match is found, +// it falls back to the original load balancing strategy, adds the selected proxy to the hash map, and returns it. +func (ch *ConsistentHash) NextProxy(conn IConnWrapper) (IProxy, *gerr.GatewayDError) { + ch.mu.Lock() + defer ch.mu.Unlock() + + var key string + + if ch.useSourceIP { + sourceIP, err := extractIPFromConn(conn) + if err != nil { + return nil, gerr.ErrNoProxiesAvailable.Wrap(err) + } + key = sourceIP + } else { + key = conn.LocalAddr().String() // Fallback to use full address as the key if `useSourceIp` is false + } + + hash := hashKey(key) + + proxy, exists := ch.hashMap[hash] + + if exists { + return proxy, nil + } + + // If no hash exists, fallback to the original strategy + proxy, err := ch.originalStrategy.NextProxy(conn) + if err != nil { + return nil, gerr.ErrNoProxiesAvailable.Wrap(err) + } + + // Add the selected proxy to the hash map for future requests + ch.hashMap[hash] = proxy + + return proxy, nil +} + +// hashKey hashes a given key using the MurmurHash3 algorithm. It is used to generate consistent hash values +// for IP addresses or connection strings. +func hashKey(key string) uint64 { + return murmur3.Sum64([]byte(key)) +} + +// extractIPFromConn extracts the IP address from the connection's local address. It splits the address +// into IP and port components and returns the IP part. This is useful for hashing based on the source IP. +func extractIPFromConn(con IConnWrapper) (string, error) { + addr := con.LocalAddr().String() + // addr will be in the format "IP:port" + ip, _, err := net.SplitHostPort(addr) + if err != nil { + return "", fmt.Errorf("failed to split host and port from address %s: %w", addr, err) + } + return ip, nil +} diff --git a/network/consistenthash_test.go b/network/consistenthash_test.go new file mode 100644 index 00000000..63e9a99f --- /dev/null +++ b/network/consistenthash_test.go @@ -0,0 +1,154 @@ +package network + +import ( + "net" + "sync" + "testing" + + "github.com/gatewayd-io/gatewayd/config" + "github.com/stretchr/testify/assert" +) + +// TestNewConsistentHash verifies that a new ConsistentHash instance is properly created. +// It checks that the original load balancing strategy is preserved, that the useSourceIp +// setting is correctly applied, and that the hashMap is initialized. +func TestNewConsistentHash(t *testing.T) { + server := &Server{ + LoadbalancerConsistentHash: &config.ConsistentHash{UseSourceIP: true}, + } + originalStrategy := NewRandom(server) + consistentHash := NewConsistentHash(server, originalStrategy) + + assert.NotNil(t, consistentHash) + assert.Equal(t, originalStrategy, consistentHash.originalStrategy) + assert.True(t, consistentHash.useSourceIP) + assert.NotNil(t, consistentHash.hashMap) +} + +// TestConsistentHashNextProxyUseSourceIpExists ensures that when useSourceIp is enabled, +// and the hashed IP exists in the hashMap, the correct proxy is returned. +// It mocks a connection with a specific IP and verifies the proxy retrieval from the hashMap. +func TestConsistentHashNextProxyUseSourceIpExists(t *testing.T) { + proxies := []IProxy{ + MockProxy{name: "proxy1"}, + MockProxy{name: "proxy2"}, + MockProxy{name: "proxy3"}, + } + server := &Server{ + Proxies: proxies, + LoadbalancerConsistentHash: &config.ConsistentHash{UseSourceIP: true}, + } + originalStrategy := NewRandom(server) + consistentHash := NewConsistentHash(server, originalStrategy) + mockConn := new(MockConnWrapper) + + // Mock LocalAddr to return a specific IP:port format + mockAddr := &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + mockConn.On("LocalAddr").Return(mockAddr) + + key := "192.168.1.1" + hash := hashKey(key) + + consistentHash.hashMap[hash] = proxies[2] + + proxy, err := consistentHash.NextProxy(mockConn) + assert.Nil(t, err) + assert.Equal(t, proxies[2], proxy) + + // Clean up + mockConn.AssertExpectations(t) +} + +// TestConsistentHashNextProxyUseFullAddress verifies the behavior when useSourceIp is disabled. +// It ensures that the full connection address is used for hashing, and the correct proxy is returned +// and cached in the hashMap. The test also checks that the hash value is computed based on the full address. +func TestConsistentHashNextProxyUseFullAddress(t *testing.T) { + mockConn := new(MockConnWrapper) + proxies := []IProxy{ + MockProxy{name: "proxy1"}, + MockProxy{name: "proxy2"}, + MockProxy{name: "proxy3"}, + } + server := &Server{ + Proxies: proxies, + LoadbalancerConsistentHash: &config.ConsistentHash{ + UseSourceIP: false, + }, + } + mockStrategy := NewRoundRobin(server) + + // Mock LocalAddr to return full address + mockAddr := &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + mockConn.On("LocalAddr").Return(mockAddr) + + consistentHash := NewConsistentHash(server, mockStrategy) + + proxy, err := consistentHash.NextProxy(mockConn) + assert.Nil(t, err) + assert.NotNil(t, proxy) + assert.Equal(t, proxies[1], proxy) + + // Hash should be calculated using the full address and cached in hashMap + hash := hashKey("192.168.1.1:1234") + cachedProxy, exists := consistentHash.hashMap[hash] + + assert.True(t, exists) + assert.Equal(t, proxies[1], cachedProxy) + + // Clean up + mockConn.AssertExpectations(t) +} + +// TestConsistentHashNextProxyConcurrency tests the concurrency safety of the NextProxy method +// in the ConsistentHash struct. It ensures that multiple goroutines can concurrently call +// NextProxy without causing race conditions or inconsistent behavior. +func TestConsistentHashNextProxyConcurrency(t *testing.T) { + // Setup mocks + conn1 := new(MockConnWrapper) + conn2 := new(MockConnWrapper) + proxies := []IProxy{ + MockProxy{name: "proxy1"}, + MockProxy{name: "proxy2"}, + MockProxy{name: "proxy3"}, + } + server := &Server{ + Proxies: proxies, + LoadbalancerConsistentHash: &config.ConsistentHash{UseSourceIP: true}, + } + originalStrategy := NewRoundRobin(server) + + // Mock IP addresses + mockAddr1 := &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + mockAddr2 := &net.TCPAddr{IP: net.ParseIP("192.168.1.2"), Port: 1234} + conn1.On("LocalAddr").Return(mockAddr1) + conn2.On("LocalAddr").Return(mockAddr2) + + // Initialize the ConsistentHash + consistentHash := NewConsistentHash(server, originalStrategy) + + // Run the test concurrently + var waitGroup sync.WaitGroup + const numGoroutines = 100 + + for range numGoroutines { + waitGroup.Add(1) + go func() { + defer waitGroup.Done() + p, err := consistentHash.NextProxy(conn1) + assert.Nil(t, err) + assert.Equal(t, proxies[1], p) + }() + } + + waitGroup.Wait() + + // Ensure that the proxy is consistently the same + proxy, err := consistentHash.NextProxy(conn1) + assert.Nil(t, err) + assert.Equal(t, proxies[1], proxy) + + // Ensure that connecting from a different address returns a different proxy + proxy, err = consistentHash.NextProxy(conn2) + assert.Nil(t, err) + assert.Equal(t, proxies[2], proxy) +} diff --git a/network/loadbalancer.go b/network/loadbalancer.go index 08dd324f..2ce6c5cf 100644 --- a/network/loadbalancer.go +++ b/network/loadbalancer.go @@ -6,7 +6,7 @@ import ( ) type LoadBalancerStrategy interface { - NextProxy() (IProxy, *gerr.GatewayDError) + NextProxy(conn IConnWrapper) (IProxy, *gerr.GatewayDError) } // NewLoadBalancerStrategy returns a LoadBalancerStrategy based on the server's load balancer strategy name. @@ -14,20 +14,28 @@ type LoadBalancerStrategy interface { // it selects a load balancer rule before returning the strategy. // Returns an error if the strategy is not found or if there are no load balancer rules when required. func NewLoadBalancerStrategy(server *Server) (LoadBalancerStrategy, *gerr.GatewayDError) { + var strategy LoadBalancerStrategy switch server.LoadbalancerStrategyName { case config.RoundRobinStrategy: - return NewRoundRobin(server), nil + strategy = NewRoundRobin(server) case config.RANDOMStrategy: - return NewRandom(server), nil + strategy = NewRandom(server) case config.WeightedRoundRobinStrategy: if server.LoadbalancerRules == nil { return nil, gerr.ErrNoLoadBalancerRules } loadbalancerRule := selectLoadBalancerRule(server.LoadbalancerRules) - return NewWeightedRoundRobin(server, loadbalancerRule), nil + strategy = NewWeightedRoundRobin(server, loadbalancerRule) default: return nil, gerr.ErrLoadBalancerStrategyNotFound } + + // If consistent hashing is enabled, wrap the strategy + if server.LoadbalancerConsistentHash != nil { + strategy = NewConsistentHash(server, strategy) + } + + return strategy, nil } // selectLoadBalancerRule selects and returns the first load balancer rule that matches the default condition. diff --git a/network/network_helpers_test.go b/network/network_helpers_test.go index e8e44cf5..02ae07eb 100644 --- a/network/network_helpers_test.go +++ b/network/network_helpers_test.go @@ -2,12 +2,15 @@ package network import ( "encoding/binary" + "fmt" + "net" "strings" "testing" gerr "github.com/gatewayd-io/gatewayd/errors" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -208,3 +211,67 @@ func (m MockProxy) BusyConnectionsString() []string { func (m MockProxy) GetName() string { return m.name } + +// Mock implementation of IConnWrapper. +type MockConnWrapper struct { + mock.Mock +} + +func (m *MockConnWrapper) Conn() net.Conn { + args := m.Called() + conn, ok := args.Get(0).(net.Conn) + if !ok { + panic(fmt.Sprintf("expected net.Conn but got %T", args.Get(0))) + } + return conn +} + +func (m *MockConnWrapper) UpgradeToTLS(upgrader UpgraderFunc) *gerr.GatewayDError { + args := m.Called(upgrader) + err, ok := args.Get(0).(*gerr.GatewayDError) + if !ok { + panic(fmt.Sprintf("expected *gerr.GatewayDError but got %T", args.Get(0))) + } + return err +} + +func (m *MockConnWrapper) Close() error { + args := m.Called() + if err := args.Error(0); err != nil { + return fmt.Errorf("failed to close connection: %w", err) + } + return nil +} + +func (m *MockConnWrapper) Write(data []byte) (int, error) { + args := m.Called(data) + return args.Int(0), args.Error(1) +} + +func (m *MockConnWrapper) Read(data []byte) (int, error) { + args := m.Called(data) + return args.Int(0), args.Error(1) +} + +func (m *MockConnWrapper) RemoteAddr() net.Addr { + args := m.Called() + addr, ok := args.Get(0).(net.Addr) + if !ok { + panic(fmt.Sprintf("expected net.Addr but got %T", args.Get(0))) + } + return addr +} + +func (m *MockConnWrapper) LocalAddr() net.Addr { + args := m.Called() + addr, ok := args.Get(0).(net.Addr) + if !ok { + panic(fmt.Sprintf("expected net.Addr but got %T", args.Get(0))) + } + return addr +} + +func (m *MockConnWrapper) IsTLSEnabled() bool { + args := m.Called() + return args.Bool(0) +} diff --git a/network/random.go b/network/random.go index 1c78f20d..5aea61c3 100644 --- a/network/random.go +++ b/network/random.go @@ -24,7 +24,7 @@ func NewRandom(server *Server) *Random { } // NextProxy returns a random proxy from the list. -func (r *Random) NextProxy() (IProxy, *gerr.GatewayDError) { +func (r *Random) NextProxy(_ IConnWrapper) (IProxy, *gerr.GatewayDError) { r.mu.Lock() defer r.mu.Unlock() diff --git a/network/random_test.go b/network/random_test.go index 65fc8f32..49bccd79 100644 --- a/network/random_test.go +++ b/network/random_test.go @@ -34,7 +34,7 @@ func TestGetNextProxy(t *testing.T) { server := &Server{Proxies: proxies} random := NewRandom(server) - proxy, err := random.NextProxy() + proxy, err := random.NextProxy(nil) assert.Nil(t, err) assert.Contains(t, proxies, proxy) @@ -44,7 +44,7 @@ func TestGetNextProxy(t *testing.T) { server := &Server{Proxies: []IProxy{}} random := NewRandom(server) - proxy, err := random.NextProxy() + proxy, err := random.NextProxy(nil) assert.Nil(t, proxy) assert.Equal(t, gerr.ErrNoProxiesAvailable.Message, err.Message) @@ -54,8 +54,8 @@ func TestGetNextProxy(t *testing.T) { server := &Server{Proxies: proxies} random := NewRandom(server) - proxy1, _ := random.NextProxy() - proxy2, _ := random.NextProxy() + proxy1, _ := random.NextProxy(nil) + proxy2, _ := random.NextProxy(nil) assert.Contains(t, proxies, proxy1) assert.Contains(t, proxies, proxy2) @@ -81,7 +81,7 @@ func TestConcurrencySafety(t *testing.T) { waitGroup.Add(1) go func() { defer waitGroup.Done() - proxy, _ := random.NextProxy() + proxy, _ := random.NextProxy(nil) proxyChan <- proxy }() } diff --git a/network/roundrobin.go b/network/roundrobin.go index 0057f432..c8235e32 100644 --- a/network/roundrobin.go +++ b/network/roundrobin.go @@ -16,7 +16,7 @@ func NewRoundRobin(server *Server) *RoundRobin { return &RoundRobin{proxies: server.Proxies} } -func (r *RoundRobin) NextProxy() (IProxy, *gerr.GatewayDError) { +func (r *RoundRobin) NextProxy(_ IConnWrapper) (IProxy, *gerr.GatewayDError) { proxiesLen := uint32(len(r.proxies)) if proxiesLen == 0 { return nil, gerr.ErrNoProxiesAvailable.Wrap(errors.New("proxy list is empty")) diff --git a/network/roundrobin_test.go b/network/roundrobin_test.go index ec430dc9..daed0f2c 100644 --- a/network/roundrobin_test.go +++ b/network/roundrobin_test.go @@ -36,7 +36,7 @@ func TestRoundRobin_NextProxy(t *testing.T) { expectedOrder := []string{"proxy2", "proxy3", "proxy1", "proxy2", "proxy3"} for testIndex, expected := range expectedOrder { - proxy, err := roundRobin.NextProxy() + proxy, err := roundRobin.NextProxy(nil) if err != nil { t.Fatalf("test %d: unexpected error from NextProxy: %v", testIndex, err) } @@ -68,7 +68,7 @@ func TestRoundRobin_ConcurrentAccess(t *testing.T) { for range numGoroutines { go func() { defer waitGroup.Done() - _, _ = roundRobin.NextProxy() + _, _ = roundRobin.NextProxy(nil) }() } @@ -99,7 +99,7 @@ func TestNextProxyOverflow(t *testing.T) { // Call NextProxy multiple times to trigger the overflow for range 4 { - proxy, err := roundRobin.NextProxy() + proxy, err := roundRobin.NextProxy(nil) if err != nil { t.Fatalf("Unexpected error: %v", err) } diff --git a/network/server.go b/network/server.go index 8271ed37..a1b50deb 100644 --- a/network/server.go +++ b/network/server.go @@ -75,10 +75,11 @@ type Server struct { stopServer chan struct{} // loadbalancer - loadbalancerStrategy LoadBalancerStrategy - LoadbalancerStrategyName string - LoadbalancerRules []config.LoadBalancingRule - connectionToProxyMap map[*ConnWrapper]IProxy + loadbalancerStrategy LoadBalancerStrategy + LoadbalancerStrategyName string + LoadbalancerRules []config.LoadBalancingRule + LoadbalancerConsistentHash *config.ConsistentHash + connectionToProxyMap map[*ConnWrapper]IProxy } var _ IServer = (*Server)(nil) @@ -156,7 +157,7 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) { span.AddEvent("Ran the OnOpening hooks") // Attempt to retrieve the next proxy. - proxy, err := s.loadbalancerStrategy.NextProxy() + proxy, err := s.loadbalancerStrategy.NextProxy(conn) if err != nil { span.RecordError(err) s.Logger.Error().Err(err).Msg("failed to retrieve next proxy") @@ -677,27 +678,28 @@ func NewServer( // Create the server. server := Server{ - ctx: serverCtx, - Network: srv.Network, - Address: srv.Address, - Options: srv.Options, - TickInterval: srv.TickInterval, - Status: config.Stopped, - EnableTLS: srv.EnableTLS, - CertFile: srv.CertFile, - KeyFile: srv.KeyFile, - HandshakeTimeout: srv.HandshakeTimeout, - Proxies: srv.Proxies, - Logger: srv.Logger, - PluginRegistry: srv.PluginRegistry, - PluginTimeout: srv.PluginTimeout, - mu: &sync.RWMutex{}, - connections: 0, - running: &atomic.Bool{}, - stopServer: make(chan struct{}), - connectionToProxyMap: make(map[*ConnWrapper]IProxy), - LoadbalancerStrategyName: srv.LoadbalancerStrategyName, - LoadbalancerRules: srv.LoadbalancerRules, + ctx: serverCtx, + Network: srv.Network, + Address: srv.Address, + Options: srv.Options, + TickInterval: srv.TickInterval, + Status: config.Stopped, + EnableTLS: srv.EnableTLS, + CertFile: srv.CertFile, + KeyFile: srv.KeyFile, + HandshakeTimeout: srv.HandshakeTimeout, + Proxies: srv.Proxies, + Logger: srv.Logger, + PluginRegistry: srv.PluginRegistry, + PluginTimeout: srv.PluginTimeout, + mu: &sync.RWMutex{}, + connections: 0, + running: &atomic.Bool{}, + stopServer: make(chan struct{}), + connectionToProxyMap: make(map[*ConnWrapper]IProxy), + LoadbalancerStrategyName: srv.LoadbalancerStrategyName, + LoadbalancerRules: srv.LoadbalancerRules, + LoadbalancerConsistentHash: srv.LoadbalancerConsistentHash, } // Try to resolve the address and log an error if it can't be resolved. diff --git a/network/weightedroundrobin.go b/network/weightedroundrobin.go index a7cdca30..0b10fcd8 100644 --- a/network/weightedroundrobin.go +++ b/network/weightedroundrobin.go @@ -50,7 +50,7 @@ func NewWeightedRoundRobin(server *Server, loadbalancerRule config.LoadBalancing // It adjusts the current weight of each proxy based on its effective weight and selects // the proxy with the highest current weight. The selected proxy's current weight is then // decreased by the total weight of all proxies to ensure balanced distribution over time. -func (r *WeightedRoundRobin) NextProxy() (IProxy, *gerr.GatewayDError) { +func (r *WeightedRoundRobin) NextProxy(_ IConnWrapper) (IProxy, *gerr.GatewayDError) { r.mu.Lock() defer r.mu.Unlock() diff --git a/network/weightedroundrobin_test.go b/network/weightedroundrobin_test.go index 1f5618ec..505bafa6 100644 --- a/network/weightedroundrobin_test.go +++ b/network/weightedroundrobin_test.go @@ -132,7 +132,7 @@ func TestWeightedRoundRobinNextProxy(t *testing.T) { // Simulate the distribution of requests using the WeightedRoundRobin algorithm. for range totalRequests { - proxy, err := weightedRR.NextProxy() + proxy, err := weightedRR.NextProxy(nil) require.Nil(t, err) mockProxy, ok := proxy.(MockProxy) @@ -187,7 +187,7 @@ func TestWeightedRoundRobinConcurrentAccess(t *testing.T) { defer waitGroup.Done() // Retrieve the next proxy using the WeightedRoundRobin algorithm. - proxy, err := weightedRR.NextProxy() + proxy, err := weightedRR.NextProxy(nil) if assert.Nil(t, err, "No error expected when getting a proxy") { // Safely update the proxy selection count using a mutex. mux.Lock()