Skip to content

Commit

Permalink
fixed review problems
Browse files Browse the repository at this point in the history
  • Loading branch information
sinadarbouy committed Jul 23, 2024
1 parent 0e171fe commit 016dac1
Show file tree
Hide file tree
Showing 11 changed files with 145 additions and 56 deletions.
5 changes: 5 additions & 0 deletions config/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,8 @@ const (
DefaultRedisAddress = "localhost:6379"
DefaultRedisChannel = "gatewayd-actions"
)

// Load balancing strategies.
const (
RoundRobinStrategy = "ROUND_ROBIN"
)
7 changes: 6 additions & 1 deletion errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ const (
ErrCodeConfigParseError
ErrCodePublishAsyncAction
ErrCodeLoadBalancerStrategyNotFound
ErrCodeNoProxiesAvailable
)

var (
Expand Down Expand Up @@ -196,7 +197,11 @@ var (
}

ErrLoadBalancerStrategyNotFound = &GatewayDError{
ErrorCodeLoadBalancerStrategyNotFound, "The specified load balancer strategy does not exist.", nil,
ErrCodeLoadBalancerStrategyNotFound, "The specified load balancer strategy does not exist.", nil,
}

ErrNoProxiesAvailable = &GatewayDError{
ErrCodeNoProxiesAvailable, "No proxies available to select.", nil,
}

// Unwrapped errors.
Expand Down
8 changes: 0 additions & 8 deletions gatewayd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,6 @@ servers:
- "default-2"
loadBalancer:
strategy: ROUND_ROBIN
# Not yet implemented.
# loadBalancingRules:
# - condition: "default"
# percentages:
# - proxy: "default"
# percentage: 70
# - proxy: "default-2"
# percentage: 30
enableTicker: False
tickInterval: 5s # duration
enableTLS: False
Expand Down
6 changes: 0 additions & 6 deletions network/constants.go

This file was deleted.

5 changes: 3 additions & 2 deletions network/loadbalancer.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
package network

import (
"github.com/gatewayd-io/gatewayd/config"
gerr "github.com/gatewayd-io/gatewayd/errors"
)

type LoadBalancerStrategy interface {
GetNextProxy() IProxy
NextProxy() (IProxy, *gerr.GatewayDError)
}

func NewLoadBalancerStrategy(server *Server) (LoadBalancerStrategy, *gerr.GatewayDError) {
switch server.LoadbalancerStrategyName {
case RoundRobinStrategy:
case config.RoundRobinStrategy:
return NewRoundRobin(server), nil
default:
return nil, gerr.ErrLoadBalancerStrategyNotFound
Expand Down
6 changes: 5 additions & 1 deletion network/loadbalancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@ import (
"errors"
"testing"

"github.com/gatewayd-io/gatewayd/config"
gerr "github.com/gatewayd-io/gatewayd/errors"
)

// TestNewLoadBalancerStrategy tests the NewLoadBalancerStrategy function to ensure it correctly
// initializes the load balancer strategy based on the strategy name provided in the server configuration.
// It covers both valid and invalid strategy names.
func TestNewLoadBalancerStrategy(t *testing.T) {
serverValid := &Server{
LoadbalancerStrategyName: RoundRobinStrategy,
LoadbalancerStrategyName: config.RoundRobinStrategy,
Proxies: []IProxy{MockProxy{}},
}

Expand Down
54 changes: 54 additions & 0 deletions network/network_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strings"
"testing"

gerr "github.com/gatewayd-io/gatewayd/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/require"
Expand All @@ -16,6 +17,11 @@ type WriteBuffer struct {
msgStart int
}

// MockProxy implements the IProxy interface for testing purposes.
type MockProxy struct {
name string
}

// writeStartupMsg writes a PostgreSQL startup message to the buffer.
func writeStartupMsg(buf *WriteBuffer, user, database, appName string) {
// Write startup message header
Expand Down Expand Up @@ -154,3 +160,51 @@ func CollectAndComparePrometheusMetrics(t *testing.T) {
require.NoError(t,
testutil.GatherAndCompare(prometheus.DefaultGatherer, strings.NewReader(want), metrics...))
}

// Connect is a mock implementation of the Connect method in the IProxy interface.
func (m MockProxy) Connect(_ *ConnWrapper) *gerr.GatewayDError {
return nil
}

// Disconnect is a mock implementation of the Disconnect method in the IProxy interface.
func (m MockProxy) Disconnect(_ *ConnWrapper) *gerr.GatewayDError {
return nil
}

// PassThroughToServer is a mock implementation of the PassThroughToServer method in the IProxy interface.
func (m MockProxy) PassThroughToServer(_ *ConnWrapper, _ *Stack) *gerr.GatewayDError {
return nil
}

// PassThroughToClient is a mock implementation of the PassThroughToClient method in the IProxy interface.
func (m MockProxy) PassThroughToClient(_ *ConnWrapper, _ *Stack) *gerr.GatewayDError {
return nil
}

// IsHealthy is a mock implementation of the IsHealthy method in the IProxy interface.
func (m MockProxy) IsHealthy(_ *Client) (*Client, *gerr.GatewayDError) {
return nil, nil
}

// IsExhausted is a mock implementation of the IsExhausted method in the IProxy interface.
func (m MockProxy) IsExhausted() bool {
return false
}

// Shutdown is a mock implementation of the Shutdown method in the IProxy interface.
func (m MockProxy) Shutdown() {}

// AvailableConnectionsString is a mock implementation of the AvailableConnectionsString method in the IProxy interface.
func (m MockProxy) AvailableConnectionsString() []string {
return nil
}

// BusyConnectionsString is a mock implementation of the BusyConnectionsString method in the IProxy interface.
func (m MockProxy) BusyConnectionsString() []string {
return nil
}

// GetName returns the name of the MockProxy.
func (m MockProxy) GetName() string {
return m.name
}
16 changes: 11 additions & 5 deletions network/round-robin.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package network

import "sync/atomic"
import (
"errors"
"sync/atomic"

gerr "github.com/gatewayd-io/gatewayd/errors"
)

type RoundRobin struct {
proxies []IProxy
Expand All @@ -11,10 +16,11 @@ func NewRoundRobin(server *Server) *RoundRobin {
return &RoundRobin{proxies: server.Proxies}
}

func (r *RoundRobin) GetNextProxy() IProxy {
func (r *RoundRobin) NextProxy() (IProxy, *gerr.GatewayDError) {
proxiesLen := uint32(len(r.proxies))

if proxiesLen == 0 {
return nil, gerr.ErrNoProxiesAvailable.Wrap(errors.New("proxy list is empty"))
}
nextIndex := r.next.Add(1)

return r.proxies[nextIndex%proxiesLen]
return r.proxies[nextIndex%proxiesLen], nil
}
79 changes: 50 additions & 29 deletions network/round-robin_test.go
Original file line number Diff line number Diff line change
@@ -1,36 +1,13 @@
package network

import (
"math"
"sync"
"testing"

gerr "github.com/gatewayd-io/gatewayd/errors"
)

// MockProxy implements IProxy interface for testing.
type MockProxy struct {
name string
}

func (m MockProxy) Connect(_ *ConnWrapper) *gerr.GatewayDError { return nil }
func (m MockProxy) Disconnect(_ *ConnWrapper) *gerr.GatewayDError { return nil }
func (m MockProxy) PassThroughToServer(_ *ConnWrapper, _ *Stack) *gerr.GatewayDError {
return nil
}

func (m MockProxy) PassThroughToClient(_ *ConnWrapper, _ *Stack) *gerr.GatewayDError {
return nil
}
func (m MockProxy) IsHealthy(_ *Client) (*Client, *gerr.GatewayDError) { return nil, nil }
func (m MockProxy) IsExhausted() bool { return false }
func (m MockProxy) Shutdown() {}
func (m MockProxy) AvailableConnectionsString() []string { return nil }
func (m MockProxy) BusyConnectionsString() []string { return nil }

func (m MockProxy) GetName() string {
return m.name
}

// TestNewRoundRobin tests the NewRoundRobin function to ensure that it correctly initializes
// the round-robin load balancer with the expected number of proxies.
func TestNewRoundRobin(t *testing.T) {
proxies := []IProxy{
MockProxy{name: "proxy1"},
Expand All @@ -45,7 +22,9 @@ func TestNewRoundRobin(t *testing.T) {
}
}

func TestRoundRobin_GetNextProxy(t *testing.T) {
// TestRoundRobin_NextProxy tests the NextProxy method of the round-robin load balancer to ensure
// that it returns proxies in the expected order.
func TestRoundRobin_NextProxy(t *testing.T) {
proxies := []IProxy{
MockProxy{name: "proxy1"},
MockProxy{name: "proxy2"},
Expand All @@ -57,7 +36,10 @@ func TestRoundRobin_GetNextProxy(t *testing.T) {
expectedOrder := []string{"proxy2", "proxy3", "proxy1", "proxy2", "proxy3"}

for testIndex, expected := range expectedOrder {
proxy := roundRobin.GetNextProxy()
proxy, err := roundRobin.NextProxy()
if err != nil {
t.Fatalf("test %d: unexpected error from NextProxy: %v", testIndex, err)
}
mockProxy, ok := proxy.(MockProxy)
if !ok {
t.Fatalf("test %d: expected proxy of type MockProxy, got %T", testIndex, proxy)
Expand All @@ -68,6 +50,8 @@ func TestRoundRobin_GetNextProxy(t *testing.T) {
}
}

// TestRoundRobin_ConcurrentAccess tests the thread safety of the NextProxy method in the round-robin load balancer
// by invoking it concurrently from multiple goroutines and ensuring that the internal state is updated correctly.
func TestRoundRobin_ConcurrentAccess(t *testing.T) {
proxies := []IProxy{
MockProxy{name: "proxy1"},
Expand All @@ -84,7 +68,7 @@ func TestRoundRobin_ConcurrentAccess(t *testing.T) {
for range numGoroutines {
go func() {
defer waitGroup.Done()
_ = roundRobin.GetNextProxy()
_, _ = roundRobin.NextProxy()
}()
}

Expand All @@ -94,3 +78,40 @@ func TestRoundRobin_ConcurrentAccess(t *testing.T) {
t.Errorf("expected next index to be %d, got %d", numGoroutines, nextIndex)
}
}

// TestNextProxyOverflow verifies that the round-robin proxy selection correctly handles
// the overflow of the internal counter. It sets the counter to a value close to the maximum
// uint32 value and ensures that the proxy selection wraps around as expected when the
// counter overflows.
func TestNextProxyOverflow(t *testing.T) {
// Create a server with a few mock proxies
server := &Server{
Proxies: []IProxy{
&MockProxy{},
&MockProxy{},
&MockProxy{},
},
}
roundRobin := NewRoundRobin(server)

// Set the next value to near the max uint32 value to force an overflow
roundRobin.next.Store(math.MaxUint32 - 1)

// Call NextProxy multiple times to trigger the overflow
for range 4 {
proxy, err := roundRobin.NextProxy()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if proxy == nil {
t.Fatal("Expected a proxy, got nil")
}
}

// After overflow, next value should wrap around
expectedNextValue := uint32(2) // (MaxUint32 - 1 + 4) % ProxiesLen = 2
actualNextValue := roundRobin.next.Load()
if actualNextValue != expectedNextValue {
t.Fatalf("Expected next value to be %v, got %v", expectedNextValue, actualNextValue)
}
}
13 changes: 10 additions & 3 deletions network/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,13 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) {
}
span.AddEvent("Ran the OnOpening hooks")

// get proxy
proxy := s.loadbalancerStrategy.GetNextProxy()
// Attempt to retrieve the next proxy.
proxy, err := s.loadbalancerStrategy.NextProxy()
if err != nil {
span.RecordError(err)
s.Logger.Error().Err(err).Msg("failed to retrieve next proxy")
return nil, Close
}

// Use the proxy to connect to the backend. Close the connection if the pool is exhausted.
// This effectively get a connection from the pool and puts both the incoming and the server
Expand All @@ -173,7 +178,7 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) {
return nil, None
}

// AssignConnectionToProxy
// Assign connection to proxy
s.connectionToProxyMap[conn] = proxy

// Run the OnOpened hooks.
Expand Down Expand Up @@ -331,6 +336,7 @@ func (s *Server) OnTraffic(conn *ConnWrapper, stopConnection chan struct{}) Acti
proxy, exists := server.GetProxyForConnection(conn)
if !exists {
server.Logger.Error().Msg("Failed to find proxy that matches the connection")
stopConnection <- struct{}{}
break
}

Expand All @@ -353,6 +359,7 @@ func (s *Server) OnTraffic(conn *ConnWrapper, stopConnection chan struct{}) Acti
proxy, exists := server.GetProxyForConnection(conn)
if !exists {
server.Logger.Error().Msg("Failed to find proxy that matches the connection")
stopConnection <- struct{}{}
break
}
if err := proxy.PassThroughToClient(conn, stack); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion network/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func TestRunServer(t *testing.T) {
PluginRegistry: pluginRegistry,
PluginTimeout: config.DefaultPluginTimeout,
HandshakeTimeout: config.DefaultHandshakeTimeout,
LoadbalancerStrategyName: RoundRobinStrategy,
LoadbalancerStrategyName: config.RoundRobinStrategy,
},
)
assert.NotNil(t, server)
Expand Down

0 comments on commit 016dac1

Please sign in to comment.