diff --git a/goLoadBalancer b/goLoadBalancer
index 24a99d2..acedf00 100755
Binary files a/goLoadBalancer and b/goLoadBalancer differ
diff --git a/loadbalancer.go b/loadbalancer.go
index 7ec6d37..0f196e6 100644
--- a/loadbalancer.go
+++ b/loadbalancer.go
@@ -3,123 +3,125 @@ package main
 import (
 	"log"
 	"net/http"
-	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/jonboulle/clockwork"
 )
 
-type LoadBalancer struct {
-	Nodes []*Node
-	Next  int
-	Mutex sync.Mutex
-	Clock clockwork.Clock
+// TODO: refactor to separate NodeManager file
+type NodeManager interface {
+	GetNextNode() *Node
+	isRateLimitExceeded(node *Node, bodyLen int) bool
+	ResetLimits()
+	CheckHealth()
 }
 
-func NewLoadBalancer(clock clockwork.Clock) *LoadBalancer {
-	return &LoadBalancer{
-		Nodes: []*Node{},
-		Next:  0,
-		Mutex: sync.Mutex{},
-		Clock: clock,
-	}
+type SafeNodeManager struct {
+	nodes []*Node
+	next  uint32
+	clock clockwork.Clock
 }
 
-func (lb *LoadBalancer) AddNode(node *Node) {
-	lb.Nodes = append(lb.Nodes, node)
+type LoadBalancer struct {
+	manager NodeManager
 }
 
-func (lb *LoadBalancer) GetNextNode() *Node {
-	lb.Mutex.Lock()
-	defer lb.Mutex.Unlock()
+func NewSafeNodeManager(nodes []*Node, clock clockwork.Clock) *SafeNodeManager {
+	return &SafeNodeManager{
+		nodes: nodes,
+		next:  0,
+		clock: clock,
+	}
+}
 
-	node := lb.Nodes[lb.Next]
-	lb.Next = (lb.Next + 1) % len(lb.Nodes)
+func (m *SafeNodeManager) GetNextNode() *Node {
+	idx := (atomic.AddUint32(&m.next, 1) - 1) % uint32(len(m.nodes))
+	node := m.nodes[idx]
 
-	if node.Healthy {
-		return node
+	if atomic.LoadUint32(&node.Healthy) == 1 {
+		return node
 	}
 
-	log.Printf("Skipping unhealthy node %d (%s)",
-		node.ID, node.URL)
+	log.Printf("Skipping unhealthy node %d (%s)", node.ID, node.URL)
 
 	return nil
 }
 
-func isRateLimitExceeded(node *Node, bodyLen int) (bool) {
-	if node.ReqCount >= node.ReqLimit || node.BodyCount+bodyLen > node.BodyLimit {
-		return true
-	}
-	return false
-}
-
-func (lb *LoadBalancer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	bodyLen := int(r.ContentLength)
-
-	for i := 0; i < len(lb.Nodes); i++ {
-		node := lb.GetNextNode()
-
-		if node == nil {
-			continue
-		}
-
-		node.Mutex.Lock()
-		defer node.Mutex.Unlock()
-
-		if isRateLimitExceeded(node, bodyLen) {
+func (m *SafeNodeManager) isRateLimitExceeded(node *Node, bodyLen int) bool {
+	if atomic.LoadUint32(&node.ReqCount) >= node.ReqLimit ||
+		atomic.LoadUint64(&node.BodyCount)+uint64(bodyLen) > node.BodyLimit {
 		log.Printf("Rate limit hit for node %d (%s) - RPM: %d/%d, BodyLimit: %d, RequestBody: %d\n",
 			node.ID, node.URL, node.ReqCount, node.ReqLimit, node.BodyLimit, bodyLen)
-
-			continue
-		}
-
-		node.ReqCount++
-		node.BodyCount += bodyLen
-
-		log.Printf("Forwarding request to node %d (%s) - RPM: %d/%d, BPM: %d/%d\n",
-			node.ID, node.URL, node.ReqCount, node.ReqLimit, node.BodyCount, node.BodyLimit)
-
-		node.ReverseProxy.ServeHTTP(w, r)
-		return
+		return true
 	}
 
-	log.Println("No available node")
-	http.Error(w, "No available node", http.StatusServiceUnavailable)
+	atomic.AddUint32(&node.ReqCount, 1)
+	atomic.AddUint64(&node.BodyCount, uint64(bodyLen))
+
+	return false
 }
 
-func (lb *LoadBalancer) StartPeriodicTasks() {
-	lb.Clock.AfterFunc(1*time.Minute, lb.resetLimits)
-	lb.Clock.AfterFunc(30*time.Second, lb.checkHealth)
+func (m *SafeNodeManager) ResetLimits() {
+	for _, node := range m.nodes {
+		atomic.StoreUint32(&node.ReqCount, 0)
+		atomic.StoreUint64(&node.BodyCount, 0)
+	}
+	m.clock.AfterFunc(1*time.Minute, m.ResetLimits)
 }
 
-func (lb *LoadBalancer) resetLimits() {
-	for _, node := range lb.Nodes {
-		node.Mutex.Lock()
-		node.ReqCount = 0
-		node.BodyCount = 0
-		node.Mutex.Unlock()
+func (m *SafeNodeManager) CheckHealth() {
+	for _, node := range m.nodes {
+		go func(n *Node) {
+			resp, err := http.Get(n.URL + "/health")
+
+			if err != nil || resp.StatusCode != http.StatusOK {
+				atomic.StoreUint32(&n.Healthy, 0)
+				log.Printf("Node %d (%s) is unhealthy\n", n.ID, n.URL)
+			} else {
+				atomic.StoreUint32(&n.Healthy, 1)
+				log.Printf("Node %d (%s) is healthy\n", n.ID, n.URL)
+			}
+
+			if resp != nil {
+				resp.Body.Close()
+			}
+		}(node)
 	}
-	lb.Clock.AfterFunc(1*time.Minute, lb.resetLimits)
+	m.clock.AfterFunc(30*time.Second, m.CheckHealth)
 }
 
-func (lb *LoadBalancer) checkHealth() {
-	for _, node := range lb.Nodes {
-		go func(n *Node) {
-			resp, err := http.Get(n.URL + "/health")
-			n.Mutex.Lock()
+func NewLoadBalancer(manager NodeManager) *LoadBalancer {
+	return &LoadBalancer{
+		manager: manager,
+	}
+}
 
-			if err != nil || resp.StatusCode != http.StatusOK {
-				n.Healthy = false
-				log.Printf("Node %d (%s) is unhealthy\n", node.ID, node.URL)
+func (lb *LoadBalancer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	bodyLen := int(r.ContentLength)
+	nodeLen := len(lb.manager.(*SafeNodeManager).nodes)
 
-			} else {
-				n.Healthy = true
-				log.Printf("Node %d (%s) is healthy\n", node.ID, node.URL)
+	var node *Node
+	for i := 0; i < nodeLen; i++ {
+		node = lb.manager.GetNextNode()
+		if node == nil {
+			continue
 		}
-			n.Mutex.Unlock()
-			if resp != nil {
-				resp.Body.Close()
+
+		if lb.manager.isRateLimitExceeded(node, bodyLen) {
+			continue
 		}
-		}(node)
+
+		log.Printf("Forwarding request to node %d (%s) - RPM: %d/%d, BPM: %d/%d\n",
+			node.ID, node.URL, node.ReqCount, node.ReqLimit, node.BodyCount, node.BodyLimit)
+
+		node.ReverseProxy.ServeHTTP(w, r)
+		return
 	}
-	lb.Clock.AfterFunc(30*time.Second, lb.checkHealth)
+
+	http.Error(w, "No available node", http.StatusServiceUnavailable)
+}
+
+func (m *SafeNodeManager) StartPeriodicTasks() {
+	m.clock.AfterFunc(1*time.Minute, m.ResetLimits)
+	m.clock.AfterFunc(30*time.Second, m.CheckHealth)
 }
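A caveat on the new rate-limit accounting: http.Request.ContentLength is -1 when the body length is unknown (for example, chunked uploads), and uint64(bodyLen) then exceeds every BodyLimit, so such requests end up rejected by all nodes. A minimal guard in ServeHTTP, sketched here under the assumption that unknown lengths simply should not count against the body budget:

	bodyLen := int(r.ContentLength)
	if bodyLen < 0 {
		// ContentLength is -1 when the body size is unknown (e.g. chunked
		// transfer encoding); count it as 0 instead of a huge uint64.
		bodyLen = 0
	}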
diff --git a/main.go b/main.go
index 3369345..d580c48 100644
--- a/main.go
+++ b/main.go
@@ -10,21 +10,21 @@ import (
 )
 
 func main() {
-	clock := clockwork.NewRealClock()
-	lb := NewLoadBalancer(clock)
-
-	nodes := []NodeParams{
-		{ID: 1, URL: "http://localhost:8081", ReqLimit: 2, BodyLimit: 123},
-		{ID: 2, URL: "http://localhost:8082", ReqLimit: 5, BodyLimit: 2 * 1024 * 1024},
-		{ID: 3, URL: "http://localhost:8083", ReqLimit: 7, BodyLimit: 1 * 1024 * 1024},
+	nodeParams := []NodeParams{
+		{ID: 1, URL: "http://localhost:8081", ReqLimit: 2, BodyLimit: 123},
+		{ID: 2, URL: "http://localhost:8082", ReqLimit: 5, BodyLimit: 2 * 1024 * 1024},
+		{ID: 3, URL: "http://localhost:8083", ReqLimit: 7, BodyLimit: 1 * 1024 * 1024},
 	}
 
-	for _, nodeParams := range nodes {
-		go startBackendServer(nodeParams.ID, nodeParams.URL)
-		lb.AddNode(NewNode(nodeParams))
+	nodes := make([]*Node, len(nodeParams))
+	for i, params := range nodeParams {
+		go startBackendServer(params.ID, params.URL)
+		nodes[i] = NewNode(params)
 	}
 
-	lb.StartPeriodicTasks()
+	nodeManager := NewSafeNodeManager(nodes, clockwork.NewRealClock())
+	nodeManager.StartPeriodicTasks()
+	lb := NewLoadBalancer(nodeManager)
 
 	time.Sleep(1 * time.Second)
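main now reaches the node pool only through the NodeManager value handed to NewLoadBalancer, yet ServeHTTP still type-asserts back to *SafeNodeManager to learn how many nodes there are. A sketch of one way to keep that boundary clean, using a hypothetical NodeCount method that is not part of this patch:

	type NodeManager interface {
		GetNextNode() *Node
		isRateLimitExceeded(node *Node, bodyLen int) bool
		ResetLimits()
		CheckHealth()
		NodeCount() int
	}

	// NodeCount reports how many nodes the manager rotates across.
	func (m *SafeNodeManager) NodeCount() int {
		return len(m.nodes)
	}

ServeHTTP could then call lb.manager.NodeCount() instead of reaching into the concrete type.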
"http://localhost:8081", ReqLimit: 2, BodyLimit: 76}, {ID: 2, URL: "http://localhost:8082", ReqLimit: 3, BodyLimit: 2 * 1024 * 1024}, @@ -23,13 +25,13 @@ var ( } ) -func startTestServer(port int) *http.Server { +func startTestServer(id int) *http.Server { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - w.Write([]byte("OK")) + w.Write([]byte(fmt.Sprintf("node: %d", id))) }) server := &http.Server{ - Addr: fmt.Sprintf(":%d", port), + Addr: fmt.Sprintf(":%d", 8080 + id), Handler: handler, } go server.ListenAndServe() @@ -37,11 +39,16 @@ func startTestServer(port int) *http.Server { } func setup() { - for _, nodeParams := range nodeParams { - server := startTestServer(8080 + nodeParams.ID) - servers = append(servers, server) - lb.AddNode(NewNode(nodeParams)) + nodes := make([]*Node, len(nodeParams)) + for i, nodeParams := range nodeParams { + startTestServer(nodeParams.ID) + nodes[i] = NewNode(nodeParams) } + + nodeManager = NewSafeNodeManager(nodes, clockwork.NewFakeClock()) + nodeManager.StartPeriodicTasks() + lb := NewLoadBalancer(nodeManager) + lbServer = httptest.NewServer(lb) time.Sleep(1 * time.Second) } @@ -54,20 +61,13 @@ func teardown() { } } -func resetRateLimits(lb *LoadBalancer) { - for _, node := range lb.Nodes { - node.Mutex.Lock() - node.ReqCount = 0 - node.BodyCount = 0 - node.Mutex.Unlock() - } +func resetRateLimits(m *SafeNodeManager) { + m.ResetLimits() } -func resetHealthChecks(lb *LoadBalancer) { - for _, node := range lb.Nodes { - node.Mutex.Lock() - node.Healthy = true - node.Mutex.Unlock() +func resetHealthChecks(m *SafeNodeManager) { + for _, node := range m.nodes { + node.Healthy = 1 } } @@ -80,25 +80,28 @@ func TestLoadBalancer(t *testing.T) { t.Run("NodeHealthHandling", testAllNodesUnhealthy) t.Run("AllNodesBusy", testAllNodesBusy) t.Run("RateLimitReset", testRateLimitReset) + t.Run("RateLimitExceededByBPM", testRateLimitExceededByBPM) } // Should work by round robin func testRoundRobin(t *testing.T) { - next := lb.Next + next := 0 - for i := 0; i <= len(lb.Nodes) + 1; i++ { - node := lb.Nodes[next] + for i := 0; i <= len(nodeManager.nodes) + 1; i++ { + node := nodeManager.nodes[next] - if node.ID != next+1 { + if node.ID != next + 1 { t.Errorf("Round robin not working correctly, expected %d, got %d", next+1, node.ID) } - next = (next + 1) % len(lb.Nodes) + next = (next + 1) % len(nodeManager.nodes) } } // Should hit rate limit by requests per minute for each node func testRateLimitExceededByRPM(t *testing.T) { - resetRateLimits(lb) + resetRateLimits(nodeManager) + nodes := nodeManager.nodes + var wg sync.WaitGroup for i := 0; i < 8; i++ { wg.Add(1) @@ -114,23 +117,57 @@ func testRateLimitExceededByRPM(t *testing.T) { } wg.Wait() - if lb.Nodes[0].ReqCount < lb.Nodes[0].ReqLimit { - t.Errorf("Rate Limit should be hit; RPM %d/%d", lb.Nodes[0].ReqCount, lb.Nodes[0].ReqLimit) + if nodes[0].ReqCount < nodes[0].ReqLimit { + t.Errorf("Rate Limit should be hit; RPM %d/%d", nodes[0].ReqCount, nodes[0].ReqLimit) } - t.Logf("Node2 ReqCount:%d, ReqLimit:%d", lb.Nodes[1].ReqCount, lb.Nodes[1].ReqLimit) - t.Logf("Node3 ReqCount:%d, ReqLimit:%d", lb.Nodes[2].ReqCount, lb.Nodes[2].ReqLimit) + t.Logf("Node2 ReqCount:%d, ReqLimit:%d", nodes[1].ReqCount, nodes[1].ReqLimit) + t.Logf("Node3 ReqCount:%d, ReqLimit:%d", nodes[2].ReqCount, nodes[2].ReqLimit) } -// TODO: Should hit rate limit by request body size per minute for each node +// Should hit rate limit by request body size per minute for each node func 
diff --git a/node.go b/node.go
index aeb9126..334e2a2 100644
--- a/node.go
+++ b/node.go
@@ -3,25 +3,23 @@ package main
 import (
 	"net/http/httputil"
 	"net/url"
-	"sync"
 )
 
 type NodeParams struct {
 	ID        int
 	URL       string
-	ReqLimit  int
-	BodyLimit int
+	ReqLimit  uint32
+	BodyLimit uint64
 }
 
 type Node struct {
 	ID           int
 	URL          string
-	ReqLimit     int
-	BodyLimit    int
-	ReqCount     int
-	BodyCount    int
-	Healthy      bool
-	Mutex        sync.Mutex
+	ReqLimit     uint32
+	BodyLimit    uint64
+	ReqCount     uint32
+	BodyCount    uint64
+	Healthy      uint32
 	ReverseProxy *httputil.ReverseProxy
 }
 
@@ -34,15 +32,8 @@ func NewNode(params NodeParams) *Node {
 		URL:          params.URL,
 		ReqLimit:     params.ReqLimit,
 		BodyLimit:    params.BodyLimit,
-		Healthy:      true,
+		Healthy:      1,
 		ReverseProxy: rp,
 	}
 	return node
 }
-
-func (n *Node) ResetLimits() {
-	n.Mutex.Lock()
-	defer n.Mutex.Unlock()
-	n.ReqCount = 0
-	n.BodyCount = 0
-}
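With Healthy now a plain uint32, reads and writes of it are scattered across atomic.LoadUint32/StoreUint32 call sites. A possible follow-up, sketched only (node.go would also need to import sync/atomic), is to keep that access behind small helpers on Node:

	// SetHealthy records the node's health with an atomic store.
	func (n *Node) SetHealthy(ok bool) {
		var v uint32
		if ok {
			v = 1
		}
		atomic.StoreUint32(&n.Healthy, v)
	}

	// IsHealthy reports whether the node is currently marked healthy.
	func (n *Node) IsHealthy() bool {
		return atomic.LoadUint32(&n.Healthy) == 1
	}

GetNextNode and CheckHealth could then call n.IsHealthy() and n.SetHealthy(false) instead of touching the field directly.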