From 214816fc795a0f56ae34c4b76ca63a7e8d95e2ee Mon Sep 17 00:00:00 2001 From: Nick Stogner Date: Mon, 9 Dec 2024 21:46:47 -0500 Subject: [PATCH] Update lb behavior tests to include both strategies --- internal/loadbalancer/load_balancer_test.go | 56 ++++++++++++--------- 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/internal/loadbalancer/load_balancer_test.go b/internal/loadbalancer/load_balancer_test.go index 64d499a0..d158d526 100644 --- a/internal/loadbalancer/load_balancer_test.go +++ b/internal/loadbalancer/load_balancer_test.go @@ -11,7 +11,7 @@ import ( "github.com/substratusai/kubeai/internal/apiutils" ) -func TestAwaitBestHost(t *testing.T) { +func TestAwaitBestHostBehavior(t *testing.T) { const ( myModel = "my-model" myAdapter = "my-adapter" @@ -61,31 +61,37 @@ func TestAwaitBestHost(t *testing.T) { } for name, spec := range testCases { - t.Run(name, func(t *testing.T) { - manager := &LoadBalancer{ - groups: make(map[string]*group, 1), - } + // Behavior in these tests should be the same for both strategies. + for _, strategy := range []v1.LoadBalancingStrategy{ + v1.LeastLoadStrategy, + v1.PrefixHashStrategy, + } { + t.Run(name+" with "+string(strategy)+" strategy", func(t *testing.T) { + manager := &LoadBalancer{ + groups: make(map[string]*group, 1), + } - manager.getEndpoints(myModel).reconcileEndpoints(spec.endpoints) + manager.getEndpoints(myModel).reconcileEndpoints(spec.endpoints) - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() - gotAddr, gotFunc, gotErr := manager.AwaitBestAddress(ctx, &apiutils.Request{ - Model: spec.model, - Adapter: spec.adapter, - LoadBalancing: v1.LoadBalancing{ - Strategy: v1.LeastLoadStrategy, - }, + gotAddr, gotFunc, gotErr := manager.AwaitBestAddress(ctx, &apiutils.Request{ + Model: spec.model, + Adapter: spec.adapter, + LoadBalancing: v1.LoadBalancing{ + Strategy: strategy, + }, + }) + if spec.expErr != nil { + require.ErrorIs(t, spec.expErr, gotErr) + return + } + require.NoError(t, gotErr) + gotFunc() + assert.Equal(t, spec.expAddr, gotAddr) }) - if spec.expErr != nil { - require.ErrorIs(t, spec.expErr, gotErr) - return - } - require.NoError(t, gotErr) - gotFunc() - assert.Equal(t, spec.expAddr, gotAddr) - }) + } } } @@ -136,7 +142,7 @@ func TestLoadBalancingStrategies(t *testing.T) { steps []testStep }{ { - name: "2 models, 2 pods each, least load strategy", + name: "least load strategy", modelEndpoints: map[string]map[string]endpoint{ modelA: { podA1Name: {address: podA1Addr, adapters: map[string]struct{}{adapterA1: {}}}, @@ -223,7 +229,7 @@ func TestLoadBalancingStrategies(t *testing.T) { }, }, { - name: "1 model, 2 pods, each with 10 requests to start, prefix hash strategy", + name: "prefix hash strategy", modelEndpoints: map[string]map[string]endpoint{ modelA: { podA1Name: {address: podA1Addr}, @@ -248,7 +254,7 @@ func TestLoadBalancingStrategies(t *testing.T) { }, steps: []testStep{ { - name: "first request to model-a", + name: "first request to model-a, preferring pod-a-1, each pod has 10 in-flight requests", model: modelA, prefix: podA1Hash, requestCount: 1,