From b8cdab171d1fda053ab18348342eb30c27e84662 Mon Sep 17 00:00:00 2001 From: Matt Siwiec Date: Fri, 3 Nov 2023 12:00:50 -0600 Subject: [PATCH] Hooks improvement to use ent eager loading (#261) * additional hooks unit tests Signed-off-by: Matt Siwiec * audit loadbalancer manual hook db hits and ensure providerID included in additionalSubjects Signed-off-by: Matt Siwiec * audit origin manual hook db hits Signed-off-by: Matt Siwiec * audit pool manual hook db hits Signed-off-by: Matt Siwiec * audit port manual hook db hits Signed-off-by: Matt Siwiec * bump test event msg channel timeout to 5s; align Signed-off-by: Matt Siwiec * test with one retry Signed-off-by: Matt Siwiec * tweek flake Signed-off-by: Matt Siwiec --------- Signed-off-by: Matt Siwiec --- .github/workflows/test-go.yml | 5 +- internal/manualhooks/hooks.go | 270 +++++++++++------------------ internal/manualhooks/hooks_test.go | 106 ++++++++++- internal/testutils/db_setup.go | 7 + 4 files changed, 207 insertions(+), 181 deletions(-) diff --git a/.github/workflows/test-go.yml b/.github/workflows/test-go.yml index 2dbedfc45..731d74fad 100644 --- a/.github/workflows/test-go.yml +++ b/.github/workflows/test-go.yml @@ -40,5 +40,8 @@ jobs: - name: Install atlas for db migrations on ${{ matrix.ci-database }} run: go install ariga.io/atlas/cmd/atlas@latest + # with one retry - name: Run go tests for ${{ matrix.ci-database }} - run: LOADBALANCERAPI_TESTDB_URI="${{ matrix.env-database-uri }}" go test -race -coverprofile=coverage.txt -covermode=atomic -tags testtools ./... + run: | + LOADBALANCERAPI_TESTDB_URI="${{ matrix.env-database-uri }}" go test -race -coverprofile=coverage.txt -covermode=atomic -tags testtools ./... || \ + LOADBALANCERAPI_TESTDB_URI="${{ matrix.env-database-uri }}" go test -race -coverprofile=coverage.txt -covermode=atomic -tags testtools ./... diff --git a/internal/manualhooks/hooks.go b/internal/manualhooks/hooks.go index 5af82ff50..e770b094f 100644 --- a/internal/manualhooks/hooks.go +++ b/internal/manualhooks/hooks.go @@ -196,24 +196,21 @@ func LoadBalancerHooks() []ent.Hook { return retValue, err } - addSubjPortIDs, err := m.Client().Port.Query().Where(port.HasLoadBalancerWith(loadbalancer.IDEQ(objID))).IDs(ctx) + // Ensure we have additional relevant subjects in the msg + lb, err := m.Client().LoadBalancer.Query().WithPorts().Where(loadbalancer.IDEQ(objID)).Only(ctx) if err == nil { - for _, portID := range addSubjPortIDs { - if !slices.Contains(msg.AdditionalSubjectIDs, portID) { - msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, portID) - } + if !slices.Contains(msg.AdditionalSubjectIDs, lb.LocationID) { + msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.LocationID) } - } - lbs := getLoadBalancerIDs(ctx, objID, msg.AdditionalSubjectIDs) - for _, lb := range lbs { - lb, err := m.Client().LoadBalancer.Get(ctx, lb) - if err != nil { - return nil, fmt.Errorf("failed to get loadbalancer to lookup location %s", lb) + if !slices.Contains(msg.AdditionalSubjectIDs, lb.ProviderID) { + msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.ProviderID) } - if !slices.Contains(msg.AdditionalSubjectIDs, lb.LocationID) { - msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.LocationID) + for _, p := range lb.Edges.Ports { + if !slices.Contains(msg.AdditionalSubjectIDs, p.ID) { + msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, p.ID) + } } } @@ -251,18 +248,8 @@ func LoadBalancerHooks() []ent.Hook { } additionalSubjects = append(additionalSubjects, dbObj.OwnerID) - - lbs := getLoadBalancerIDs(ctx, objID, additionalSubjects) - for _, lb := range lbs { - lb, err := m.Client().LoadBalancer.Get(ctx, lb) - if err != nil { - return nil, fmt.Errorf("failed to get loadbalancer to lookup location %s", lb) - } - - if !slices.Contains(additionalSubjects, lb.LocationID) { - additionalSubjects = append(additionalSubjects, lb.LocationID) - } - } + additionalSubjects = append(additionalSubjects, dbObj.LocationID) + additionalSubjects = append(additionalSubjects, dbObj.ProviderID) // we have all the info we need, now complete the mutation before we process the event retValue, err := next.Mutate(ctx, m) @@ -491,37 +478,31 @@ func OriginHooks() []ent.Hook { return retValue, err } - addSubjPools, err := m.Client().Pool.Query().Where(pool.HasOriginsWith(origin.IDEQ(objID))).All(ctx) + // Ensure we have additional relevant subjects in the msg + addSubjPorts, err := m.Client().Port.Query().WithPools().WithLoadBalancer().Where(port.HasPoolsWith(pool.HasOriginsWith(origin.IDEQ(objID)))).All(ctx) if err == nil { - for _, pool := range addSubjPools { - if !slices.Contains(msg.AdditionalSubjectIDs, pool.ID) && objID != pool.ID { - msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, pool.ID) + for _, port := range addSubjPorts { + if !slices.Contains(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.LocationID) { + msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.LocationID) } - if !slices.Contains(msg.AdditionalSubjectIDs, pool.OwnerID) { - msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, pool.OwnerID) + if !slices.Contains(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.ProviderID) { + msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.ProviderID) } - } - } - addSubjPorts, err := m.Client().Port.Query().Where(port.HasPoolsWith(pool.HasOriginsWith(origin.IDEQ(objID)))).All(ctx) - if err == nil { - for _, port := range addSubjPorts { if !slices.Contains(msg.AdditionalSubjectIDs, port.LoadBalancerID) { msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.LoadBalancerID) } - } - } - lbs := getLoadBalancerIDs(ctx, objID, msg.AdditionalSubjectIDs) - for _, lb := range lbs { - lb, err := m.Client().LoadBalancer.Get(ctx, lb) - if err != nil { - return nil, fmt.Errorf("failed to get loadbalancer to lookup location %s", lb) - } + for _, pool := range port.Edges.Pools { + if !slices.Contains(msg.AdditionalSubjectIDs, pool.ID) { + msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, pool.ID) + } - if !slices.Contains(msg.AdditionalSubjectIDs, lb.LocationID) { - msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.LocationID) + if !slices.Contains(msg.AdditionalSubjectIDs, pool.OwnerID) { + msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, pool.OwnerID) + } + } } } @@ -560,25 +541,31 @@ func OriginHooks() []ent.Hook { additionalSubjects = append(additionalSubjects, dbObj.PoolID) - addSubjPools, err := m.Client().Pool.Query().Where(pool.HasOriginsWith(origin.IDEQ(objID))).All(ctx) + // Ensure we have additional relevant subjects in the msg + addSubjPorts, err := m.Client().Port.Query().WithPools().WithLoadBalancer().Where(port.HasPoolsWith(pool.HasOriginsWith(origin.IDEQ(objID)))).All(ctx) if err == nil { - for _, pool := range addSubjPools { - if !slices.Contains(additionalSubjects, pool.ID) && objID != pool.ID { - additionalSubjects = append(additionalSubjects, pool.ID) - } - - if !slices.Contains(additionalSubjects, pool.OwnerID) { - additionalSubjects = append(additionalSubjects, pool.OwnerID) + for _, port := range addSubjPorts { + for _, pool := range port.Edges.Pools { + if !slices.Contains(additionalSubjects, pool.ID) { + additionalSubjects = append(additionalSubjects, pool.ID) + } + + if !slices.Contains(additionalSubjects, pool.OwnerID) { + additionalSubjects = append(additionalSubjects, pool.OwnerID) + } } - } - } - addSubjPorts, err := m.Client().Port.Query().Where(port.HasPoolsWith(pool.HasOriginsWith(origin.IDEQ(objID)))).All(ctx) - if err == nil { - for _, port := range addSubjPorts { if !slices.Contains(additionalSubjects, port.LoadBalancerID) { additionalSubjects = append(additionalSubjects, port.LoadBalancerID) } + + if !slices.Contains(additionalSubjects, port.Edges.LoadBalancer.LocationID) { + additionalSubjects = append(additionalSubjects, port.Edges.LoadBalancer.LocationID) + } + + if !slices.Contains(additionalSubjects, port.Edges.LoadBalancer.ProviderID) { + additionalSubjects = append(additionalSubjects, port.Edges.LoadBalancer.ProviderID) + } } } @@ -599,18 +586,6 @@ func OriginHooks() []ent.Hook { } } - lbs := getLoadBalancerIDs(ctx, objID, additionalSubjects) - for _, lb := range lbs { - lb, err := m.Client().LoadBalancer.Get(ctx, lb) - if err != nil { - return nil, fmt.Errorf("failed to get loadbalancer to lookup location %s", lb) - } - - if !slices.Contains(additionalSubjects, lb.LocationID) { - additionalSubjects = append(additionalSubjects, lb.LocationID) - } - } - msg := events.ChangeMessage{ EventType: eventType(m.Op()), SubjectID: objID, @@ -782,9 +757,20 @@ func PoolHooks() []ent.Hook { return retValue, err } - addSubjPorts, err := m.Client().Port.Query().Where(port.HasPoolsWith(pool.IDEQ(objID))).All(ctx) + // Ensure we have additional relevant subjects in the msg + addSubjPorts, err := m.Client().Port.Query().WithLoadBalancer().WithPools(func(q *generated.PoolQuery) { + q.WithOrigins() + }).Where(port.HasPoolsWith(pool.IDEQ(objID))).All(ctx) if err == nil { for _, port := range addSubjPorts { + for _, pool := range port.Edges.Pools { + for _, origin := range pool.Edges.Origins { + if !slices.Contains(msg.AdditionalSubjectIDs, origin.ID) { + msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, origin.ID) + } + } + } + if !slices.Contains(msg.AdditionalSubjectIDs, port.ID) && objID != port.ID { msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.ID) } @@ -792,34 +778,17 @@ func PoolHooks() []ent.Hook { if !slices.Contains(msg.AdditionalSubjectIDs, port.LoadBalancerID) { msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.LoadBalancerID) } - } - } - addSubjOrigins, err := m.Client().Origin.Query().Where(origin.HasPoolWith(pool.IDEQ(objID))).All(ctx) - if err == nil { - for _, origin := range addSubjOrigins { - if !slices.Contains(msg.AdditionalSubjectIDs, origin.ID) && objID != origin.ID { - msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, origin.ID) + if !slices.Contains(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.LocationID) { + msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.LocationID) } - if !slices.Contains(msg.AdditionalSubjectIDs, origin.PoolID) { - msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, origin.PoolID) + if !slices.Contains(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.ProviderID) { + msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.ProviderID) } } } - lbs := getLoadBalancerIDs(ctx, objID, msg.AdditionalSubjectIDs) - for _, lb := range lbs { - lb, err := m.Client().LoadBalancer.Get(ctx, lb) - if err != nil { - return nil, fmt.Errorf("failed to get loadbalancer to lookup location %s", lb) - } - - if !slices.Contains(msg.AdditionalSubjectIDs, lb.LocationID) { - msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.LocationID) - } - } - if len(relationships) != 0 { if err := permissions.CreateAuthRelationships(ctx, "load-balancer-pool", objID, relationships...); err != nil { return nil, fmt.Errorf("relationship request failed with error: %w", err) @@ -855,9 +824,18 @@ func PoolHooks() []ent.Hook { additionalSubjects = append(additionalSubjects, dbObj.OwnerID) - addSubjPorts, err := m.Client().Port.Query().Where(port.HasPoolsWith(pool.IDEQ(objID))).All(ctx) + // Ensure we have additional relevant subjects in the msg + addSubjPorts, err := m.Client().Port.Query().WithLoadBalancer().Where(port.HasPoolsWith(pool.IDEQ(objID))).All(ctx) if err == nil { for _, port := range addSubjPorts { + if !slices.Contains(additionalSubjects, port.Edges.LoadBalancer.LocationID) { + additionalSubjects = append(additionalSubjects, port.Edges.LoadBalancer.LocationID) + } + + if !slices.Contains(additionalSubjects, port.Edges.LoadBalancer.ProviderID) { + additionalSubjects = append(additionalSubjects, port.Edges.LoadBalancer.ProviderID) + } + if !slices.Contains(additionalSubjects, port.LoadBalancerID) { additionalSubjects = append(additionalSubjects, port.LoadBalancerID) } @@ -869,18 +847,6 @@ func PoolHooks() []ent.Hook { SubjectID: dbObj.OwnerID, }) - lbs := getLoadBalancerIDs(ctx, objID, additionalSubjects) - for _, lb := range lbs { - lb, err := m.Client().LoadBalancer.Get(ctx, lb) - if err != nil { - return nil, fmt.Errorf("failed to get loadbalancer to lookup location %s", lb) - } - - if !slices.Contains(additionalSubjects, lb.LocationID) { - additionalSubjects = append(additionalSubjects, lb.LocationID) - } - } - // we have all the info we need, now complete the mutation before we process the event retValue, err := next.Mutate(ctx, m) if err != nil { @@ -1050,9 +1016,30 @@ func PortHooks() []ent.Hook { return retValue, err } - addSubjPools, err := m.Client().Pool.Query().Where(pool.HasPortsWith(port.IDEQ(objID))).All(ctx) + // Ensure we have additional relevant subjects in the event msg + addSubjPools, err := m.Client().Pool.Query().WithPorts(func(q *generated.PortQuery) { + q.WithLoadBalancer() + }).Where(pool.HasPortsWith(port.IDEQ(objID))).All(ctx) if err == nil { for _, pool := range addSubjPools { + for _, port := range pool.Edges.Ports { + if !slices.Contains(msg.AdditionalSubjectIDs, port.LoadBalancerID) { + msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.LoadBalancerID) + } + + if !slices.Contains(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.LocationID) { + msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.LocationID) + } + + if !slices.Contains(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.ProviderID) { + msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.ProviderID) + } + + if !slices.Contains(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.OwnerID) { + msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.OwnerID) + } + } + if !slices.Contains(msg.AdditionalSubjectIDs, pool.ID) && objID != pool.ID { msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, pool.ID) } @@ -1062,38 +1049,6 @@ func PortHooks() []ent.Hook { } } } - addSubjLoadBalancers, err := m.Client().LoadBalancer.Query().Where(loadbalancer.HasPortsWith(port.IDEQ(objID))).All(ctx) - if err == nil { - for _, lb := range addSubjLoadBalancers { - if !slices.Contains(msg.AdditionalSubjectIDs, lb.ID) { - msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.ID) - } - - if !slices.Contains(msg.AdditionalSubjectIDs, lb.LocationID) { - msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.LocationID) - } - - if !slices.Contains(msg.AdditionalSubjectIDs, lb.OwnerID) { - msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.OwnerID) - } - - if !slices.Contains(msg.AdditionalSubjectIDs, lb.ProviderID) { - msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.ProviderID) - } - } - } - - lbs := getLoadBalancerIDs(ctx, objID, msg.AdditionalSubjectIDs) - for _, lb := range lbs { - lb, err := m.Client().LoadBalancer.Get(ctx, lb) - if err != nil { - return nil, fmt.Errorf("failed to get loadbalancer to lookup location %s", lb) - } - - if !slices.Contains(msg.AdditionalSubjectIDs, lb.LocationID) { - msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.LocationID) - } - } if len(relationships) != 0 { if err := permissions.CreateAuthRelationships(ctx, "load-balancer-port", objID, relationships...); err != nil { @@ -1123,12 +1078,16 @@ func PortHooks() []ent.Hook { return nil, fmt.Errorf("object doesn't have an id %s", objID) } - dbObj, err := m.Client().Port.Get(ctx, objID) + dbObj, err := m.Client().Port.Query().WithLoadBalancer().Where(port.IDEQ(objID)).Only(ctx) if err != nil { return nil, fmt.Errorf("failed to load object to get values for event, err %w", err) } + // Ensure we have additional relevant subjects in the event msg additionalSubjects = append(additionalSubjects, dbObj.LoadBalancerID) + additionalSubjects = append(additionalSubjects, dbObj.Edges.LoadBalancer.LocationID) + additionalSubjects = append(additionalSubjects, dbObj.Edges.LoadBalancer.OwnerID) + additionalSubjects = append(additionalSubjects, dbObj.Edges.LoadBalancer.ProviderID) // we have all the info we need, now complete the mutation before we process the event retValue, err := next.Mutate(ctx, m) @@ -1142,37 +1101,6 @@ func PortHooks() []ent.Hook { } } - addSubjLoadBalancer, err := m.Client().LoadBalancer.Get(ctx, dbObj.LoadBalancerID) - if err == nil { - if !slices.Contains(additionalSubjects, addSubjLoadBalancer.ID) { - additionalSubjects = append(additionalSubjects, addSubjLoadBalancer.ID) - } - - if !slices.Contains(additionalSubjects, addSubjLoadBalancer.LocationID) { - additionalSubjects = append(additionalSubjects, addSubjLoadBalancer.LocationID) - } - - if !slices.Contains(additionalSubjects, addSubjLoadBalancer.OwnerID) { - additionalSubjects = append(additionalSubjects, addSubjLoadBalancer.OwnerID) - } - - if !slices.Contains(additionalSubjects, addSubjLoadBalancer.ProviderID) { - additionalSubjects = append(additionalSubjects, addSubjLoadBalancer.ProviderID) - } - } - - lbs := getLoadBalancerIDs(ctx, objID, additionalSubjects) - for _, lb := range lbs { - lb, err := m.Client().LoadBalancer.Get(ctx, lb) - if err != nil { - return nil, fmt.Errorf("failed to get loadbalancer to lookup location %s", lb) - } - - if !slices.Contains(additionalSubjects, lb.LocationID) { - additionalSubjects = append(additionalSubjects, lb.LocationID) - } - } - msg := events.ChangeMessage{ EventType: eventType(m.Op()), SubjectID: objID, diff --git a/internal/manualhooks/hooks_test.go b/internal/manualhooks/hooks_test.go index bd334291c..4b6a9020b 100644 --- a/internal/manualhooks/hooks_test.go +++ b/internal/manualhooks/hooks_test.go @@ -18,7 +18,7 @@ import ( const ( ownerPrefix = "testown" locationPrefix = "testloc" - defaultTimeout = 2 * time.Second + defaultTimeout = 5 * time.Second ) var ( @@ -56,7 +56,7 @@ func Test_LoadbalancerCreateHook(t *testing.T) { msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) // Assert - expectedAdditionalSubjectIDs := []gidx.PrefixedID{lb.ID, lb.OwnerID, lb.LocationID} + expectedAdditionalSubjectIDs := []gidx.PrefixedID{lb.ID, lb.OwnerID, lb.LocationID, lb.ProviderID} actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) @@ -81,7 +81,7 @@ func Test_LoadbalancerUpdateHook(t *testing.T) { msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) // Assert - expectedAdditionalSubjectIDs := []gidx.PrefixedID{lb.ID, lb.OwnerID, lb.LocationID} + expectedAdditionalSubjectIDs := []gidx.PrefixedID{lb.ID, lb.OwnerID, lb.LocationID, lb.ProviderID} actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) @@ -106,7 +106,7 @@ func Test_LoadbalancerDeleteHook(t *testing.T) { msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) // Assert - expectedAdditionalSubjectIDs := []gidx.PrefixedID{lb.OwnerID, lb.LocationID} + expectedAdditionalSubjectIDs := []gidx.PrefixedID{lb.OwnerID, lb.LocationID, lb.ProviderID} actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) @@ -133,7 +133,7 @@ func Test_OriginCreateHook(t *testing.T) { msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) // Assert - expectedAdditionalSubjectIDs := []gidx.PrefixedID{pool.ID, pool.OwnerID, lb.ID, lb.LocationID} + expectedAdditionalSubjectIDs := []gidx.PrefixedID{pool.ID, pool.OwnerID, lb.ID, lb.LocationID, lb.ProviderID} actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) @@ -161,7 +161,7 @@ func Test_OriginUpdateHook(t *testing.T) { msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) // Assert - expectedAdditionalSubjectIDs := []gidx.PrefixedID{pool.ID, pool.OwnerID, lb.ID, lb.LocationID} + expectedAdditionalSubjectIDs := []gidx.PrefixedID{pool.ID, pool.OwnerID, lb.ID, lb.LocationID, lb.ProviderID} actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) @@ -189,7 +189,7 @@ func Test_OriginDeleteHook(t *testing.T) { msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) // Assert - expectedAdditionalSubjectIDs := []gidx.PrefixedID{pool.ID, pool.OwnerID, lb.ID, lb.LocationID} + expectedAdditionalSubjectIDs := []gidx.PrefixedID{pool.ID, pool.OwnerID, lb.ID, lb.LocationID, lb.ProviderID} actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) @@ -240,7 +240,7 @@ func Test_PoolUpdateHook(t *testing.T) { msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) // Assert - expectedAdditionalSubjectIDs := []gidx.PrefixedID{pool.ID, pool.OwnerID, lb.ID, lb.LocationID, origin.ID, port.ID} + expectedAdditionalSubjectIDs := []gidx.PrefixedID{pool.OwnerID, lb.ID, lb.LocationID, lb.ProviderID, origin.ID, port.ID} actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) @@ -267,7 +267,7 @@ func Test_PoolDeleteHook(t *testing.T) { msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) // Assert - expectedAdditionalSubjectIDs := []gidx.PrefixedID{pool.OwnerID, lb.ID, lb.LocationID} + expectedAdditionalSubjectIDs := []gidx.PrefixedID{pool.OwnerID, lb.ID, lb.LocationID, lb.ProviderID} actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) @@ -354,3 +354,91 @@ func Test_PortDeleteHook(t *testing.T) { assert.Equal(t, port.ID, msg.Message().SubjectID) assert.Equal(t, deleteEventType, msg.Message().EventType) } + +func Test_MultipleLoadbalancersSharedPoolAddOrigin(t *testing.T) { + // Scenario: 2 loadbalancers in different locations, with the same owner, share a pool. + // An origin is added to the shared pool. + // Assert the owner, loadbalancers, pool, and locations are all included in the additionalSubject ID list + + // Arrange + ctx := testutils.MockPermissions(context.Background()) + + changesChannel, err := testutils.EventsConn.SubscribeChanges(ctx, "create.load-balancer-origin") + require.NoError(t, err, "failed to subscribe to changes") + + // create 2 loadbalancers with a shared pool of origins + prov := (&testutils.ProviderBuilder{}).MustNew(ctx) + lb1 := (&testutils.LoadBalancerBuilder{OwnerID: "tnttent-testing", Provider: prov}).MustNew(ctx) + lb2 := (&testutils.LoadBalancerBuilder{OwnerID: "tnttent-testing", Provider: prov}).MustNew(ctx) + pool := (&testutils.PoolBuilder{OwnerID: "tnttent-testing"}).MustNew(ctx) + _ = (&testutils.PortBuilder{PoolIDs: []gidx.PrefixedID{pool.ID}, LoadBalancerID: lb1.ID}).MustNew(ctx) + _ = (&testutils.PortBuilder{PoolIDs: []gidx.PrefixedID{pool.ID}, LoadBalancerID: lb2.ID}).MustNew(ctx) + + testutils.EntClient.Origin.Use(manualhooks.OriginHooks()...) + + // Act - add another origin to the pool + ogn := (&testutils.OriginBuilder{PoolID: pool.ID}).MustNew(ctx) + + msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) + + // Assert + expectedAdditionalSubjectIDs := []gidx.PrefixedID{ + prov.ID, + lb1.OwnerID, + lb1.ID, + lb2.ID, + lb1.LocationID, + lb2.LocationID, + pool.ID, + } + actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs + + assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) + assert.Equal(t, ogn.ID, msg.Message().SubjectID) + assert.Equal(t, createEventType, msg.Message().EventType) +} + +func Test_MultipleLoadbalancersSharedPoolDeleteOrigin(t *testing.T) { + // Scenario: 2 loadbalancers in different locations, with the same owner, share a pool. + // An origin is removed from the shared pool. + // Assert the owner, loadbalancers, pool, and locations are all included in the additionalSubject ID list + + // Arrange + ctx := testutils.MockPermissions(context.Background()) + + changesChannel, err := testutils.EventsConn.SubscribeChanges(ctx, "delete.load-balancer-origin") + require.NoError(t, err, "failed to subscribe to changes") + + // create 2 loadbalancers with a shared pool of origins + prov := (&testutils.ProviderBuilder{}).MustNew(ctx) + lb1 := (&testutils.LoadBalancerBuilder{OwnerID: "tnttent-testing", Provider: prov}).MustNew(ctx) + lb2 := (&testutils.LoadBalancerBuilder{OwnerID: "tnttent-testing", Provider: prov}).MustNew(ctx) + pool := (&testutils.PoolBuilder{OwnerID: "tnttent-testing"}).MustNew(ctx) + _ = (&testutils.PortBuilder{PoolIDs: []gidx.PrefixedID{pool.ID}, LoadBalancerID: lb1.ID}).MustNew(ctx) + _ = (&testutils.PortBuilder{PoolIDs: []gidx.PrefixedID{pool.ID}, LoadBalancerID: lb2.ID}).MustNew(ctx) + _ = (&testutils.OriginBuilder{PoolID: pool.ID}).MustNew(ctx) + ogn2 := (&testutils.OriginBuilder{PoolID: pool.ID}).MustNew(ctx) + + testutils.EntClient.Origin.Use(manualhooks.OriginHooks()...) + + // Act - update the pool to remove an origin + testutils.EntClient.Origin.DeleteOne(ogn2).ExecX(ctx) + + msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) + + // Assert + expectedAdditionalSubjectIDs := []gidx.PrefixedID{ + prov.ID, + lb1.OwnerID, + lb1.ID, + lb2.ID, + lb1.LocationID, + lb2.LocationID, + pool.ID, + } + actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs + + assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) + assert.Equal(t, ogn2.ID, msg.Message().SubjectID) + assert.Equal(t, deleteEventType, msg.Message().EventType) +} diff --git a/internal/testutils/db_setup.go b/internal/testutils/db_setup.go index 64631d039..feeb9cffb 100644 --- a/internal/testutils/db_setup.go +++ b/internal/testutils/db_setup.go @@ -23,6 +23,7 @@ import ( var ( testDBURI = os.Getenv("LOADBALANCERAPI_TESTDB_URI") + NATSConn *eventtools.TestNats // NATSConn exported if needed for subscribers EventsConn events.Connection // EventsConn exported if needed for subscribers EntClient *ent.Client // EntClient to use as ent client DBContainer *testcontainersx.DBContainer // DBContainer to use through entire test suite @@ -37,6 +38,7 @@ func SetupDB() { IfErrPanic("failed to start nats server", err) conn, err := events.NewConnection(nats.Config) + IfErrPanic("failed to create events connection", err) // DB and EntClient setup @@ -61,6 +63,7 @@ func SetupDB() { EventsConn = conn EntClient = c DBContainer = cntr + NATSConn = nats } // TeardownDB used for clean up test setup @@ -74,6 +77,10 @@ func TeardownDB() { if DBContainer != nil && DBContainer.Container.IsRunning() { IfErrPanic("teardown failed to terminate test db container", DBContainer.Container.Terminate(ctx)) } + + _ = EventsConn.Shutdown(ctx) + + NATSConn.Close() } // ParseDBURI parses the kind of query language from TESTDB_URI env var and initializes DBContainer as required