Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Consistent Hashing Load Balancing Strategy #592

Merged
merged 6 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -910,16 +910,17 @@ var runCmd = &cobra.Command{
// Can be used to send keepalive messages to the client.
EnableTicker: cfg.EnableTicker,
},
Proxies: serverProxies,
Logger: logger,
PluginRegistry: pluginRegistry,
PluginTimeout: conf.Plugin.Timeout,
EnableTLS: cfg.EnableTLS,
CertFile: cfg.CertFile,
KeyFile: cfg.KeyFile,
HandshakeTimeout: cfg.HandshakeTimeout,
LoadbalancerStrategyName: cfg.LoadBalancer.Strategy,
LoadbalancerRules: cfg.LoadBalancer.LoadBalancingRules,
Proxies: serverProxies,
Logger: logger,
PluginRegistry: pluginRegistry,
PluginTimeout: conf.Plugin.Timeout,
EnableTLS: cfg.EnableTLS,
CertFile: cfg.CertFile,
KeyFile: cfg.KeyFile,
HandshakeTimeout: cfg.HandshakeTimeout,
LoadbalancerStrategyName: cfg.LoadBalancer.Strategy,
LoadbalancerRules: cfg.LoadBalancer.LoadBalancingRules,
LoadbalancerConsistentHash: cfg.LoadBalancer.ConsistentHash,
},
)

Expand Down
5 changes: 5 additions & 0 deletions config/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,14 @@ type LoadBalancingRule struct {
Distribution []Distribution `json:"distribution"`
}

type ConsistentHash struct {
UseSourceIP bool `json:"useSourceIp"`
}

type LoadBalancer struct {
Strategy string `json:"strategy"`
LoadBalancingRules []LoadBalancingRule `json:"loadBalancingRules"`
ConsistentHash *ConsistentHash `json:"consistentHash,omitempty"`
}

type Server struct {
Expand Down
2 changes: 2 additions & 0 deletions gatewayd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ servers:
loadBalancer:
# Load balancer strategies can be found in config/constants.go
strategy: ROUND_ROBIN # ROUND_ROBIN, RANDOM, WEIGHTED_ROUND_ROBIN
consistentHash:
useSourceIp: true
# Optional configuration for strategies that support rules (e.g., WEIGHTED_ROUND_ROBIN)
# loadBalancingRules:
# - condition: "DEFAULT" # Currently, only the "DEFAULT" condition is supported
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,10 @@ require (
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/skeema/knownhosts v1.2.1 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
Expand Down
3 changes: 3 additions & 0 deletions go.sum

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

88 changes: 88 additions & 0 deletions network/consistenthash.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package network

import (
"fmt"
"net"
"sync"

gerr "github.com/gatewayd-io/gatewayd/errors"
"github.com/spaolacci/murmur3"

Check failure on line 9 in network/consistenthash.go

View workflow job for this annotation

GitHub Actions / Test GatewayD

import 'github.com/spaolacci/murmur3' is not allowed from list 'main' (depguard)
)

// ConsistentHash implements a load balancing strategy based on consistent hashing.
// It routes client connections to specific proxies by hashing the client's IP address or the full connection address.
type ConsistentHash struct {
originalStrategy LoadBalancerStrategy
useSourceIP bool
hashMap map[uint64]IProxy
hashMapMutex sync.RWMutex
}

// NewConsistentHash creates a new ConsistentHash instance. It requires a server configuration and an original
// load balancing strategy. The consistent hash can use either the source IP or the full connection address
// as the key for hashing.
func NewConsistentHash(server *Server, originalStrategy LoadBalancerStrategy) *ConsistentHash {
return &ConsistentHash{
originalStrategy: originalStrategy,
useSourceIP: server.LoadbalancerConsistentHash.UseSourceIP,
hashMap: make(map[uint64]IProxy),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you could use sync.Map, which handles locking behind the scenes.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At first, I messed up by just locking the load and store process. Now, I’ve added a sync.Mutex to lock the whole NextProxy connection. also added a test case to check for concurrency. (so I guess we can not use sync.Map here anymore)

}
}

// NextProxy selects the appropriate proxy for a given client connection. It first tries to find an existing
// proxy in the hash map based on the hashed key (either the source IP or the full address). If no match is found,
// it falls back to the original load balancing strategy, adds the selected proxy to the hash map, and returns it.
func (ch *ConsistentHash) NextProxy(conn IConnWrapper) (IProxy, *gerr.GatewayDError) {
var key string

if ch.useSourceIP {
sourceIP, err := extractIPFromConn(conn)
if err != nil {
return nil, gerr.ErrNoProxiesAvailable.Wrap(err)
}
key = sourceIP
} else {
key = conn.LocalAddr().String() // Fallback to use full address as the key if `useSourceIp` is false
}

hash := hashKey(key)

ch.hashMapMutex.RLock()
proxy, exists := ch.hashMap[hash]
ch.hashMapMutex.RUnlock()

if exists {
return proxy, nil
}

// If no hash exists, fallback to the original strategy
proxy, err := ch.originalStrategy.NextProxy(conn)
if err != nil {
return nil, gerr.ErrNoProxiesAvailable.Wrap(err)
}

// Add the selected proxy to the hash map for future requests
ch.hashMapMutex.Lock()
ch.hashMap[hash] = proxy
ch.hashMapMutex.Unlock()

return proxy, nil
}

// hashKey hashes a given key using the MurmurHash3 algorithm. It is used to generate consistent hash values
// for IP addresses or connection strings.
func hashKey(key string) uint64 {
return murmur3.Sum64([]byte(key))
}

// extractIPFromConn extracts the IP address from the connection's local address. It splits the address
// into IP and port components and returns the IP part. This is useful for hashing based on the source IP.
func extractIPFromConn(con IConnWrapper) (string, error) {
addr := con.LocalAddr().String()
// addr will be in the format "IP:port"
ip, _, err := net.SplitHostPort(addr)
if err != nil {
return "", fmt.Errorf("failed to split host and port from address %s: %w", addr, err)
}
return ip, nil
}
101 changes: 101 additions & 0 deletions network/consistenthash_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package network

import (
"net"
"testing"

"github.com/gatewayd-io/gatewayd/config"
"github.com/stretchr/testify/assert"
)

// TestNewConsistentHash verifies that a new ConsistentHash instance is properly created.
// It checks that the original load balancing strategy is preserved, that the useSourceIp
// setting is correctly applied, and that the hashMap is initialized.
func TestNewConsistentHash(t *testing.T) {
server := &Server{
LoadbalancerConsistentHash: &config.ConsistentHash{UseSourceIP: true},
}
originalStrategy := NewRandom(server)
consistentHash := NewConsistentHash(server, originalStrategy)

assert.NotNil(t, consistentHash)
assert.Equal(t, originalStrategy, consistentHash.originalStrategy)
assert.True(t, consistentHash.useSourceIP)
assert.NotNil(t, consistentHash.hashMap)
}

// TestConsistentHashNextProxyUseSourceIpExists ensures that when useSourceIp is enabled,
// and the hashed IP exists in the hashMap, the correct proxy is returned.
// It mocks a connection with a specific IP and verifies the proxy retrieval from the hashMap.
func TestConsistentHashNextProxyUseSourceIpExists(t *testing.T) {
proxies := []IProxy{
MockProxy{name: "proxy1"},
MockProxy{name: "proxy2"},
MockProxy{name: "proxy3"},
}
server := &Server{
Proxies: proxies,
LoadbalancerConsistentHash: &config.ConsistentHash{UseSourceIP: true},
}
originalStrategy := NewRandom(server)
consistentHash := NewConsistentHash(server, originalStrategy)
mockConn := new(MockConnWrapper)

// Mock LocalAddr to return a specific IP:port format
mockAddr := &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}
mockConn.On("LocalAddr").Return(mockAddr)

key := "192.168.1.1"
hash := hashKey(key)

consistentHash.hashMap[hash] = proxies[2]

proxy, err := consistentHash.NextProxy(mockConn)
assert.Nil(t, err)
assert.Equal(t, proxies[2], proxy)

// Clean up
mockConn.AssertExpectations(t)
}

// TestConsistentHashNextProxyUseFullAddress verifies the behavior when useSourceIp is disabled.
// It ensures that the full connection address is used for hashing, and the correct proxy is returned
// and cached in the hashMap. The test also checks that the hash value is computed based on the full address.
func TestConsistentHashNextProxyUseFullAddress(t *testing.T) {
mockConn := new(MockConnWrapper)
proxies := []IProxy{
MockProxy{name: "proxy1"},
MockProxy{name: "proxy2"},
MockProxy{name: "proxy3"},
}
server := &Server{
Proxies: proxies,
LoadbalancerConsistentHash: &config.ConsistentHash{
UseSourceIP: false,
},
}
mockStrategy := NewRoundRobin(server)

// Mock LocalAddr to return full address
mockAddr := &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}
mockConn.On("LocalAddr").Return(mockAddr)

consistentHash := NewConsistentHash(server, mockStrategy)

proxy, err := consistentHash.NextProxy(mockConn)
assert.Nil(t, err)
assert.NotNil(t, proxy)
assert.Equal(t, proxies[1], proxy)

// Hash should be calculated using the full address and cached in hashMap
hash := hashKey("192.168.1.1:1234")
consistentHash.hashMapMutex.RLock()
cachedProxy, exists := consistentHash.hashMap[hash]
consistentHash.hashMapMutex.RUnlock()

assert.True(t, exists)
assert.Equal(t, proxies[1], cachedProxy)

// Clean up
mockConn.AssertExpectations(t)
}
16 changes: 12 additions & 4 deletions network/loadbalancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,36 @@ import (
)

type LoadBalancerStrategy interface {
NextProxy() (IProxy, *gerr.GatewayDError)
NextProxy(conn IConnWrapper) (IProxy, *gerr.GatewayDError)
}

// NewLoadBalancerStrategy returns a LoadBalancerStrategy based on the server's load balancer strategy name.
// If the server's load balancer strategy is weighted round-robin,
// it selects a load balancer rule before returning the strategy.
// Returns an error if the strategy is not found or if there are no load balancer rules when required.
func NewLoadBalancerStrategy(server *Server) (LoadBalancerStrategy, *gerr.GatewayDError) {
var strategy LoadBalancerStrategy
switch server.LoadbalancerStrategyName {
case config.RoundRobinStrategy:
return NewRoundRobin(server), nil
strategy = NewRoundRobin(server)
case config.RANDOMStrategy:
return NewRandom(server), nil
strategy = NewRandom(server)
case config.WeightedRoundRobinStrategy:
if server.LoadbalancerRules == nil {
return nil, gerr.ErrNoLoadBalancerRules
}
loadbalancerRule := selectLoadBalancerRule(server.LoadbalancerRules)
return NewWeightedRoundRobin(server, loadbalancerRule), nil
strategy = NewWeightedRoundRobin(server, loadbalancerRule)
default:
return nil, gerr.ErrLoadBalancerStrategyNotFound
}

// If consistent hashing is enabled, wrap the strategy
if server.LoadbalancerConsistentHash != nil {
strategy = NewConsistentHash(server, strategy)
}

return strategy, nil
}

// selectLoadBalancerRule selects and returns the first load balancer rule that matches the default condition.
Expand Down
67 changes: 67 additions & 0 deletions network/network_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ package network

import (
"encoding/binary"
"fmt"
"net"
"strings"
"testing"

gerr "github.com/gatewayd-io/gatewayd/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -208,3 +211,67 @@ func (m MockProxy) BusyConnectionsString() []string {
func (m MockProxy) GetName() string {
return m.name
}

// Mock implementation of IConnWrapper.
type MockConnWrapper struct {
mock.Mock
}

func (m *MockConnWrapper) Conn() net.Conn {
args := m.Called()
conn, ok := args.Get(0).(net.Conn)
if !ok {
panic(fmt.Sprintf("expected net.Conn but got %T", args.Get(0)))
}
return conn
}

func (m *MockConnWrapper) UpgradeToTLS(upgrader UpgraderFunc) *gerr.GatewayDError {
args := m.Called(upgrader)
err, ok := args.Get(0).(*gerr.GatewayDError)
if !ok {
panic(fmt.Sprintf("expected *gerr.GatewayDError but got %T", args.Get(0)))
}
return err
}

func (m *MockConnWrapper) Close() error {
args := m.Called()
if err := args.Error(0); err != nil {
return fmt.Errorf("failed to close connection: %w", err)
}
return nil
}

func (m *MockConnWrapper) Write(data []byte) (int, error) {
args := m.Called(data)
return args.Int(0), args.Error(1)
}

func (m *MockConnWrapper) Read(data []byte) (int, error) {
args := m.Called(data)
return args.Int(0), args.Error(1)
}

func (m *MockConnWrapper) RemoteAddr() net.Addr {
args := m.Called()
addr, ok := args.Get(0).(net.Addr)
if !ok {
panic(fmt.Sprintf("expected net.Addr but got %T", args.Get(0)))
}
return addr
}

func (m *MockConnWrapper) LocalAddr() net.Addr {
args := m.Called()
addr, ok := args.Get(0).(net.Addr)
if !ok {
panic(fmt.Sprintf("expected net.Addr but got %T", args.Get(0)))
}
return addr
}

func (m *MockConnWrapper) IsTLSEnabled() bool {
args := m.Called()
return args.Bool(0)
}
Loading
Loading