Skip to content

Commit

Permalink
Add Consistent Hashing Load Balancing Strategy (#592)
Browse files Browse the repository at this point in the history
* Added ConsistentHash
* changed net conn into Iconnwrapper
* added test cases for onsistentHash
* fixed lint issues
* replace RWMutex to Mutex and added test case for concurency access
* added github.com/spaolacci/murmur3 into depguard
  • Loading branch information
sinadarbouy authored Aug 19, 2024
1 parent bfb0a81 commit 54d0c69
Show file tree
Hide file tree
Showing 17 changed files with 385 additions and 53 deletions.
1 change: 1 addition & 0 deletions .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ linters-settings:
- "golang.org/x/text/cases"
- "golang.org/x/text/language"
- "github.com/redis/go-redis/v9"
- "github.com/spaolacci/murmur3"
test:
files:
- $test
Expand Down
21 changes: 11 additions & 10 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -910,16 +910,17 @@ var runCmd = &cobra.Command{
// Can be used to send keepalive messages to the client.
EnableTicker: cfg.EnableTicker,
},
Proxies: serverProxies,
Logger: logger,
PluginRegistry: pluginRegistry,
PluginTimeout: conf.Plugin.Timeout,
EnableTLS: cfg.EnableTLS,
CertFile: cfg.CertFile,
KeyFile: cfg.KeyFile,
HandshakeTimeout: cfg.HandshakeTimeout,
LoadbalancerStrategyName: cfg.LoadBalancer.Strategy,
LoadbalancerRules: cfg.LoadBalancer.LoadBalancingRules,
Proxies: serverProxies,
Logger: logger,
PluginRegistry: pluginRegistry,
PluginTimeout: conf.Plugin.Timeout,
EnableTLS: cfg.EnableTLS,
CertFile: cfg.CertFile,
KeyFile: cfg.KeyFile,
HandshakeTimeout: cfg.HandshakeTimeout,
LoadbalancerStrategyName: cfg.LoadBalancer.Strategy,
LoadbalancerRules: cfg.LoadBalancer.LoadBalancingRules,
LoadbalancerConsistentHash: cfg.LoadBalancer.ConsistentHash,
},
)

Expand Down
5 changes: 5 additions & 0 deletions config/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,14 @@ type LoadBalancingRule struct {
Distribution []Distribution `json:"distribution"`
}

type ConsistentHash struct {
UseSourceIP bool `json:"useSourceIp"`
}

type LoadBalancer struct {
Strategy string `json:"strategy"`
LoadBalancingRules []LoadBalancingRule `json:"loadBalancingRules"`
ConsistentHash *ConsistentHash `json:"consistentHash,omitempty"`
}

type Server struct {
Expand Down
2 changes: 2 additions & 0 deletions gatewayd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ servers:
loadBalancer:
# Load balancer strategies can be found in config/constants.go
strategy: ROUND_ROBIN # ROUND_ROBIN, RANDOM, WEIGHTED_ROUND_ROBIN
consistentHash:
useSourceIp: true
# Optional configuration for strategies that support rules (e.g., WEIGHTED_ROUND_ROBIN)
# loadBalancingRules:
# - condition: "DEFAULT" # Currently, only the "DEFAULT" condition is supported
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,10 @@ require (
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/skeema/knownhosts v1.2.1 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
Expand Down
3 changes: 3 additions & 0 deletions go.sum

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

87 changes: 87 additions & 0 deletions network/consistenthash.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package network

import (
"fmt"
"net"
"sync"

gerr "github.com/gatewayd-io/gatewayd/errors"
"github.com/spaolacci/murmur3"
)

// ConsistentHash implements a load balancing strategy based on consistent hashing.
// It routes client connections to specific proxies by hashing the client's IP address or the full connection address.
type ConsistentHash struct {
originalStrategy LoadBalancerStrategy
useSourceIP bool
hashMap map[uint64]IProxy
mu sync.Mutex
}

// NewConsistentHash creates a new ConsistentHash instance. It requires a server configuration and an original
// load balancing strategy. The consistent hash can use either the source IP or the full connection address
// as the key for hashing.
func NewConsistentHash(server *Server, originalStrategy LoadBalancerStrategy) *ConsistentHash {
return &ConsistentHash{
originalStrategy: originalStrategy,
useSourceIP: server.LoadbalancerConsistentHash.UseSourceIP,
hashMap: make(map[uint64]IProxy),
}
}

// NextProxy selects the appropriate proxy for a given client connection. It first tries to find an existing
// proxy in the hash map based on the hashed key (either the source IP or the full address). If no match is found,
// it falls back to the original load balancing strategy, adds the selected proxy to the hash map, and returns it.
func (ch *ConsistentHash) NextProxy(conn IConnWrapper) (IProxy, *gerr.GatewayDError) {
ch.mu.Lock()
defer ch.mu.Unlock()

var key string

if ch.useSourceIP {
sourceIP, err := extractIPFromConn(conn)
if err != nil {
return nil, gerr.ErrNoProxiesAvailable.Wrap(err)
}
key = sourceIP
} else {
key = conn.LocalAddr().String() // Fallback to use full address as the key if `useSourceIp` is false
}

hash := hashKey(key)

proxy, exists := ch.hashMap[hash]

if exists {
return proxy, nil
}

// If no hash exists, fallback to the original strategy
proxy, err := ch.originalStrategy.NextProxy(conn)
if err != nil {
return nil, gerr.ErrNoProxiesAvailable.Wrap(err)
}

// Add the selected proxy to the hash map for future requests
ch.hashMap[hash] = proxy

return proxy, nil
}

// hashKey hashes a given key using the MurmurHash3 algorithm. It is used to generate consistent hash values
// for IP addresses or connection strings.
func hashKey(key string) uint64 {
return murmur3.Sum64([]byte(key))
}

// extractIPFromConn extracts the IP address from the connection's local address. It splits the address
// into IP and port components and returns the IP part. This is useful for hashing based on the source IP.
func extractIPFromConn(con IConnWrapper) (string, error) {
addr := con.LocalAddr().String()
// addr will be in the format "IP:port"
ip, _, err := net.SplitHostPort(addr)
if err != nil {
return "", fmt.Errorf("failed to split host and port from address %s: %w", addr, err)
}
return ip, nil
}
154 changes: 154 additions & 0 deletions network/consistenthash_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package network

import (
"net"
"sync"
"testing"

"github.com/gatewayd-io/gatewayd/config"
"github.com/stretchr/testify/assert"
)

// TestNewConsistentHash verifies that a new ConsistentHash instance is properly created.
// It checks that the original load balancing strategy is preserved, that the useSourceIp
// setting is correctly applied, and that the hashMap is initialized.
func TestNewConsistentHash(t *testing.T) {
server := &Server{
LoadbalancerConsistentHash: &config.ConsistentHash{UseSourceIP: true},
}
originalStrategy := NewRandom(server)
consistentHash := NewConsistentHash(server, originalStrategy)

assert.NotNil(t, consistentHash)
assert.Equal(t, originalStrategy, consistentHash.originalStrategy)
assert.True(t, consistentHash.useSourceIP)
assert.NotNil(t, consistentHash.hashMap)
}

// TestConsistentHashNextProxyUseSourceIpExists ensures that when useSourceIp is enabled,
// and the hashed IP exists in the hashMap, the correct proxy is returned.
// It mocks a connection with a specific IP and verifies the proxy retrieval from the hashMap.
func TestConsistentHashNextProxyUseSourceIpExists(t *testing.T) {
proxies := []IProxy{
MockProxy{name: "proxy1"},
MockProxy{name: "proxy2"},
MockProxy{name: "proxy3"},
}
server := &Server{
Proxies: proxies,
LoadbalancerConsistentHash: &config.ConsistentHash{UseSourceIP: true},
}
originalStrategy := NewRandom(server)
consistentHash := NewConsistentHash(server, originalStrategy)
mockConn := new(MockConnWrapper)

// Mock LocalAddr to return a specific IP:port format
mockAddr := &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}
mockConn.On("LocalAddr").Return(mockAddr)

key := "192.168.1.1"
hash := hashKey(key)

consistentHash.hashMap[hash] = proxies[2]

proxy, err := consistentHash.NextProxy(mockConn)
assert.Nil(t, err)
assert.Equal(t, proxies[2], proxy)

// Clean up
mockConn.AssertExpectations(t)
}

// TestConsistentHashNextProxyUseFullAddress verifies the behavior when useSourceIp is disabled.
// It ensures that the full connection address is used for hashing, and the correct proxy is returned
// and cached in the hashMap. The test also checks that the hash value is computed based on the full address.
func TestConsistentHashNextProxyUseFullAddress(t *testing.T) {
mockConn := new(MockConnWrapper)
proxies := []IProxy{
MockProxy{name: "proxy1"},
MockProxy{name: "proxy2"},
MockProxy{name: "proxy3"},
}
server := &Server{
Proxies: proxies,
LoadbalancerConsistentHash: &config.ConsistentHash{
UseSourceIP: false,
},
}
mockStrategy := NewRoundRobin(server)

// Mock LocalAddr to return full address
mockAddr := &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}
mockConn.On("LocalAddr").Return(mockAddr)

consistentHash := NewConsistentHash(server, mockStrategy)

proxy, err := consistentHash.NextProxy(mockConn)
assert.Nil(t, err)
assert.NotNil(t, proxy)
assert.Equal(t, proxies[1], proxy)

// Hash should be calculated using the full address and cached in hashMap
hash := hashKey("192.168.1.1:1234")
cachedProxy, exists := consistentHash.hashMap[hash]

assert.True(t, exists)
assert.Equal(t, proxies[1], cachedProxy)

// Clean up
mockConn.AssertExpectations(t)
}

// TestConsistentHashNextProxyConcurrency tests the concurrency safety of the NextProxy method
// in the ConsistentHash struct. It ensures that multiple goroutines can concurrently call
// NextProxy without causing race conditions or inconsistent behavior.
func TestConsistentHashNextProxyConcurrency(t *testing.T) {
// Setup mocks
conn1 := new(MockConnWrapper)
conn2 := new(MockConnWrapper)
proxies := []IProxy{
MockProxy{name: "proxy1"},
MockProxy{name: "proxy2"},
MockProxy{name: "proxy3"},
}
server := &Server{
Proxies: proxies,
LoadbalancerConsistentHash: &config.ConsistentHash{UseSourceIP: true},
}
originalStrategy := NewRoundRobin(server)

// Mock IP addresses
mockAddr1 := &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}
mockAddr2 := &net.TCPAddr{IP: net.ParseIP("192.168.1.2"), Port: 1234}
conn1.On("LocalAddr").Return(mockAddr1)
conn2.On("LocalAddr").Return(mockAddr2)

// Initialize the ConsistentHash
consistentHash := NewConsistentHash(server, originalStrategy)

// Run the test concurrently
var waitGroup sync.WaitGroup
const numGoroutines = 100

for range numGoroutines {
waitGroup.Add(1)
go func() {
defer waitGroup.Done()
p, err := consistentHash.NextProxy(conn1)
assert.Nil(t, err)
assert.Equal(t, proxies[1], p)
}()
}

waitGroup.Wait()

// Ensure that the proxy is consistently the same
proxy, err := consistentHash.NextProxy(conn1)
assert.Nil(t, err)
assert.Equal(t, proxies[1], proxy)

// Ensure that connecting from a different address returns a different proxy
proxy, err = consistentHash.NextProxy(conn2)
assert.Nil(t, err)
assert.Equal(t, proxies[2], proxy)
}
16 changes: 12 additions & 4 deletions network/loadbalancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,36 @@ import (
)

type LoadBalancerStrategy interface {
NextProxy() (IProxy, *gerr.GatewayDError)
NextProxy(conn IConnWrapper) (IProxy, *gerr.GatewayDError)
}

// NewLoadBalancerStrategy returns a LoadBalancerStrategy based on the server's load balancer strategy name.
// If the server's load balancer strategy is weighted round-robin,
// it selects a load balancer rule before returning the strategy.
// Returns an error if the strategy is not found or if there are no load balancer rules when required.
func NewLoadBalancerStrategy(server *Server) (LoadBalancerStrategy, *gerr.GatewayDError) {
var strategy LoadBalancerStrategy
switch server.LoadbalancerStrategyName {
case config.RoundRobinStrategy:
return NewRoundRobin(server), nil
strategy = NewRoundRobin(server)
case config.RANDOMStrategy:
return NewRandom(server), nil
strategy = NewRandom(server)
case config.WeightedRoundRobinStrategy:
if server.LoadbalancerRules == nil {
return nil, gerr.ErrNoLoadBalancerRules
}
loadbalancerRule := selectLoadBalancerRule(server.LoadbalancerRules)
return NewWeightedRoundRobin(server, loadbalancerRule), nil
strategy = NewWeightedRoundRobin(server, loadbalancerRule)
default:
return nil, gerr.ErrLoadBalancerStrategyNotFound
}

// If consistent hashing is enabled, wrap the strategy
if server.LoadbalancerConsistentHash != nil {
strategy = NewConsistentHash(server, strategy)
}

return strategy, nil
}

// selectLoadBalancerRule selects and returns the first load balancer rule that matches the default condition.
Expand Down
Loading

0 comments on commit 54d0c69

Please sign in to comment.