From 5717cada8b6a0257ac583d218290208e8325aa78 Mon Sep 17 00:00:00 2001 From: sinadarbouy Date: Tue, 16 Jul 2024 22:07:36 +0200 Subject: [PATCH] Added tests for loadbalancer and round robin --- network/loadbalancer_test.go | 39 ++++++++++++++++ network/round-robin_test.go | 91 ++++++++++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+) create mode 100644 network/loadbalancer_test.go create mode 100644 network/round-robin_test.go diff --git a/network/loadbalancer_test.go b/network/loadbalancer_test.go new file mode 100644 index 00000000..a496c811 --- /dev/null +++ b/network/loadbalancer_test.go @@ -0,0 +1,39 @@ +package network + +import ( + "testing" + + gerr "github.com/gatewayd-io/gatewayd/errors" +) + +func TestNewLoadBalancerStrategy(t *testing.T) { + serverValid := &Server{ + LoadbalancerStrategyName: RoundRobinStrategy, + Proxies: []IProxy{MockProxy{}}, + } + + // Test case 1: Valid strategy name + strategy, err := NewLoadBalancerStrategy(serverValid) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + _, ok := strategy.(*RoundRobin) + if !ok { + t.Errorf("Expected strategy to be of type RoundRobin") + } + + // Test case 1: InValid strategy name + serverInvalid := &Server{ + LoadbalancerStrategyName: "InvalidStrategy", + Proxies: []IProxy{MockProxy{}}, + } + + strategy, err = NewLoadBalancerStrategy(serverInvalid) + if err != gerr.ErrLoadBalancerStrategyNotFound { + t.Errorf("Expected ErrLoadBalancerStrategyNotFound, got %v", err) + } + if strategy != nil { + t.Errorf("Expected strategy to be nil for invalid strategy name") + } +} diff --git a/network/round-robin_test.go b/network/round-robin_test.go new file mode 100644 index 00000000..a6260287 --- /dev/null +++ b/network/round-robin_test.go @@ -0,0 +1,91 @@ +package network + +import ( + "sync" + "testing" + + gerr "github.com/gatewayd-io/gatewayd/errors" +) + +// MockProxy implements IProxy interface for testing +type MockProxy struct { + name string +} + +func (m MockProxy) Connect(conn *ConnWrapper) *gerr.GatewayDError { return nil } +func (m MockProxy) Disconnect(conn *ConnWrapper) *gerr.GatewayDError { return nil } +func (m MockProxy) PassThroughToServer(conn *ConnWrapper, stack *Stack) *gerr.GatewayDError { + return nil +} +func (m MockProxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.GatewayDError { + return nil +} +func (m MockProxy) IsHealthy(cl *Client) (*Client, *gerr.GatewayDError) { return nil, nil } +func (m MockProxy) IsExhausted() bool { return false } +func (m MockProxy) Shutdown() {} +func (m MockProxy) AvailableConnectionsString() []string { return nil } +func (m MockProxy) BusyConnectionsString() []string { return nil } + +func (m MockProxy) GetName() string { + return m.name +} + +func TestNewRoundRobin(t *testing.T) { + proxies := []IProxy{ + MockProxy{name: "proxy1"}, + MockProxy{name: "proxy2"}, + MockProxy{name: "proxy3"}, + } + server := &Server{Proxies: proxies} + rr := NewRoundRobin(server) + + if len(rr.proxies) != len(proxies) { + t.Errorf("expected %d proxies, got %d", len(proxies), len(rr.proxies)) + } +} + +func TestRoundRobin_GetNextProxy(t *testing.T) { + proxies := []IProxy{ + MockProxy{name: "proxy1"}, + MockProxy{name: "proxy2"}, + MockProxy{name: "proxy3"}, + } + server := &Server{Proxies: proxies} + rr := NewRoundRobin(server) + + expectedOrder := []string{"proxy2", "proxy3", "proxy1", "proxy2", "proxy3"} + + for i, expected := range expectedOrder { + proxy := rr.GetNextProxy() + if proxy.(MockProxy).GetName() != expected { + t.Errorf("test %d: expected proxy name %s, got %s", i, expected, proxy.(MockProxy).GetName()) + } + } +} + +func TestRoundRobin_ConcurrentAccess(t *testing.T) { + proxies := []IProxy{ + MockProxy{name: "proxy1"}, + MockProxy{name: "proxy2"}, + MockProxy{name: "proxy3"}, + } + server := &Server{Proxies: proxies} + rr := NewRoundRobin(server) + + var wg sync.WaitGroup + numGoroutines := 100 + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + _ = rr.GetNextProxy() + }() + } + + wg.Wait() + nextIndex := rr.next.Load() + if nextIndex != uint32(numGoroutines) { + t.Errorf("expected next index to be %d, got %d", numGoroutines, nextIndex) + } +}