diff --git a/internal/graphapi/tools_test.go b/internal/graphapi/tools_test.go index c49384318..ad7f8cb87 100644 --- a/internal/graphapi/tools_test.go +++ b/internal/graphapi/tools_test.go @@ -2,35 +2,26 @@ package graphapi_test import ( "context" - "log" "net/http" "net/http/httptest" "os" - "path/filepath" - "strings" "testing" - "entgo.io/ent/dialect" "github.com/99designs/gqlgen/graphql/handler" "github.com/labstack/echo/v4" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" - "github.com/testcontainers/testcontainers-go/modules/postgres" "go.uber.org/zap" "go.infratographer.com/permissions-api/pkg/permissions" "go.infratographer.com/x/echojwtx" "go.infratographer.com/x/echox" - "go.infratographer.com/x/events" - "go.infratographer.com/x/goosex" - "go.infratographer.com/x/testing/eventtools" - "go.infratographer.com/load-balancer-api/db" ent "go.infratographer.com/load-balancer-api/internal/ent/generated" "go.infratographer.com/load-balancer-api/internal/graphapi" "go.infratographer.com/load-balancer-api/internal/graphclient" "go.infratographer.com/load-balancer-api/internal/manualhooks" - "go.infratographer.com/load-balancer-api/x/testcontainersx" + "go.infratographer.com/load-balancer-api/internal/testutils" ) const ( @@ -39,115 +30,26 @@ const ( lbPrefix = "loadbal" ) -var ( - TestDBURI = os.Getenv("LOADBALANCERAPI_TESTDB_URI") - EntClient *ent.Client - DBContainer *testcontainersx.DBContainer -) +var EntClient *ent.Client func TestMain(m *testing.M) { - // setup the database if needed - setupDB() - // run the tests - code := m.Run() - // teardown the database - teardownDB() - // return the test response code - os.Exit(code) -} - -func parseDBURI(ctx context.Context) (string, string, *testcontainersx.DBContainer) { - switch { - // if you don't pass in a database we default to an in memory sqlite - case TestDBURI == "": - return dialect.SQLite, "file:ent?mode=memory&cache=shared&_fk=1", nil - case strings.HasPrefix(TestDBURI, "sqlite://"): - return dialect.SQLite, strings.TrimPrefix(TestDBURI, "sqlite://"), nil - case strings.HasPrefix(TestDBURI, "postgres://"), strings.HasPrefix(TestDBURI, "postgresql://"): - return dialect.Postgres, TestDBURI, nil - case strings.HasPrefix(TestDBURI, "docker://"): - dbImage := strings.TrimPrefix(TestDBURI, "docker://") - - switch { - case strings.HasPrefix(dbImage, "cockroach"), strings.HasPrefix(dbImage, "cockroachdb"), strings.HasPrefix(dbImage, "crdb"): - cntr, err := testcontainersx.NewCockroachDB(ctx, dbImage) - errPanic("error starting db test container", err) - - return dialect.Postgres, cntr.URI, cntr - case strings.HasPrefix(dbImage, "postgres"): - cntr, err := testcontainersx.NewPostgresDB(ctx, dbImage, - postgres.WithInitScripts(filepath.Join("testdata", "postgres_init.sh")), - ) - errPanic("error starting db test container", err) - - return dialect.Postgres, cntr.URI, cntr - default: - panic("invalid testcontainer URI, uri: " + TestDBURI) - } - - default: - panic("invalid DB URI, uri: " + TestDBURI) - } -} - -func setupDB() { - // don't setup the datastore if we already have one - if EntClient != nil { - return - } - - ctx := context.Background() + // setup the database + testutils.SetupDB() - dia, uri, cntr := parseDBURI(ctx) + // assign package variables + EntClient = testutils.EntClient - nats, err := eventtools.NewNatsServer() - if err != nil { - errPanic("failed to start nats server", err) - } - - conn, err := events.NewConnection(nats.Config) - if err != nil { - errPanic("failed to create events publisher", err) - } + // setup the resolver hooks + manualhooks.PubsubHooks(EntClient) - c, err := ent.Open(dia, uri, ent.Debug(), ent.EventsPublisher(conn)) - if err != nil { - errPanic("failed terminating test db container after failing to connect to the db", cntr.Container.Terminate(ctx)) - errPanic("failed opening connection to database:", err) - } - - switch dia { - case dialect.SQLite: - // Run automatic migrations for SQLite - errPanic("failed creating db scema", c.Schema.Create(ctx)) - case dialect.Postgres: - log.Println("Running database migrations") - goosex.MigrateUp(uri, db.Migrations) - } - - // TODO: fix generated pubsubhooks - // pubsubhooks.PubsubHooks(c) - manualhooks.PubsubHooks(c) - - EntClient = c -} - -func teardownDB() { - ctx := context.Background() - - if EntClient != nil { - errPanic("teardown failed to close database connection", EntClient.Close()) - } + // run the tests + code := m.Run() - if DBContainer != nil { - errPanic("teardown failed to terminate test db container", DBContainer.Container.Terminate(ctx)) - } -} + // teardown the database + testutils.TeardownDB() -func errPanic(msg string, err error) { - if err != nil { - log.Panicf("%s err: %s", msg, err.Error()) - } + // return the test response code + os.Exit(code) } type graphClient struct { diff --git a/internal/manualhooks/hooks_test.go b/internal/manualhooks/hooks_test.go new file mode 100644 index 000000000..bd334291c --- /dev/null +++ b/internal/manualhooks/hooks_test.go @@ -0,0 +1,356 @@ +package manualhooks_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.infratographer.com/x/events" + "go.infratographer.com/x/gidx" + + "go.infratographer.com/load-balancer-api/internal/manualhooks" + "go.infratographer.com/load-balancer-api/internal/testutils" +) + +const ( + ownerPrefix = "testown" + locationPrefix = "testloc" + defaultTimeout = 2 * time.Second +) + +var ( + createEventType = string(events.CreateChangeType) + updateEventType = string(events.UpdateChangeType) + deleteEventType = string(events.DeleteChangeType) +) + +func TestMain(m *testing.M) { + // setup the database + testutils.SetupDB() + + // run the tests + code := m.Run() + + // teardown the database + testutils.TeardownDB() + + // return the test response code + os.Exit(code) +} + +func Test_LoadbalancerCreateHook(t *testing.T) { + // Arrange + ctx := testutils.MockPermissions(context.Background()) + + changesChannel, err := testutils.EventsConn.SubscribeChanges(ctx, "create.load-balancer") + require.NoError(t, err, "failed to subscribe to changes") + + testutils.EntClient.LoadBalancer.Use(manualhooks.LoadBalancerHooks()...) + + // Act + lb := (&testutils.LoadBalancerBuilder{}).MustNew(ctx) + + msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) + + // Assert + expectedAdditionalSubjectIDs := []gidx.PrefixedID{lb.ID, lb.OwnerID, lb.LocationID} + actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs + + assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) + assert.Equal(t, lb.ID, msg.Message().SubjectID) + assert.Equal(t, createEventType, msg.Message().EventType) +} + +func Test_LoadbalancerUpdateHook(t *testing.T) { + // Arrange + ctx := testutils.MockPermissions(context.Background()) + + changesChannel, err := testutils.EventsConn.SubscribeChanges(ctx, "update.load-balancer") + require.NoError(t, err, "failed to subscribe to changes") + + lb := (&testutils.LoadBalancerBuilder{}).MustNew(ctx) + + testutils.EntClient.LoadBalancer.Use(manualhooks.LoadBalancerHooks()...) + + // Act + testutils.EntClient.LoadBalancer.UpdateOne(lb).SetName(("other-lb-name")).ExecX(ctx) + + msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) + + // Assert + expectedAdditionalSubjectIDs := []gidx.PrefixedID{lb.ID, lb.OwnerID, lb.LocationID} + actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs + + assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) + assert.Equal(t, lb.ID, msg.Message().SubjectID) + assert.Equal(t, updateEventType, msg.Message().EventType) +} + +func Test_LoadbalancerDeleteHook(t *testing.T) { + // Arrange + ctx := testutils.MockPermissions(context.Background()) + + changesChannel, err := testutils.EventsConn.SubscribeChanges(ctx, "delete.load-balancer") + require.NoError(t, err, "failed to subscribe to changes") + + lb := (&testutils.LoadBalancerBuilder{}).MustNew(ctx) + + testutils.EntClient.LoadBalancer.Use(manualhooks.LoadBalancerHooks()...) + + // Act + testutils.EntClient.LoadBalancer.DeleteOneID(lb.ID).ExecX(ctx) + + msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) + + // Assert + expectedAdditionalSubjectIDs := []gidx.PrefixedID{lb.OwnerID, lb.LocationID} + actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs + + assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) + assert.Equal(t, lb.ID, msg.Message().SubjectID) + assert.Equal(t, deleteEventType, msg.Message().EventType) +} + +func Test_OriginCreateHook(t *testing.T) { + // Arrange + ctx := testutils.MockPermissions(context.Background()) + + changesChannel, err := testutils.EventsConn.SubscribeChanges(ctx, "create.load-balancer-origin") + require.NoError(t, err, "failed to subscribe to changes") + + lb := (&testutils.LoadBalancerBuilder{}).MustNew(ctx) + pool := (&testutils.PoolBuilder{}).MustNew(ctx) + (&testutils.PortBuilder{PoolIDs: []gidx.PrefixedID{pool.ID}, LoadBalancerID: lb.ID}).MustNew(ctx) + + testutils.EntClient.Origin.Use(manualhooks.OriginHooks()...) + + // Act + origin := (&testutils.OriginBuilder{PoolID: pool.ID}).MustNew(ctx) + + msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) + + // Assert + expectedAdditionalSubjectIDs := []gidx.PrefixedID{pool.ID, pool.OwnerID, lb.ID, lb.LocationID} + actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs + + assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) + assert.Equal(t, origin.ID, msg.Message().SubjectID) + assert.Equal(t, createEventType, msg.Message().EventType) +} + +func Test_OriginUpdateHook(t *testing.T) { + // Arrange + ctx := testutils.MockPermissions(context.Background()) + + changesChannel, err := testutils.EventsConn.SubscribeChanges(ctx, "update.load-balancer-origin") + require.NoError(t, err, "failed to subscribe to changes") + + lb := (&testutils.LoadBalancerBuilder{}).MustNew(ctx) + pool := (&testutils.PoolBuilder{}).MustNew(ctx) + (&testutils.PortBuilder{PoolIDs: []gidx.PrefixedID{pool.ID}, LoadBalancerID: lb.ID}).MustNew(ctx) + origin := (&testutils.OriginBuilder{PoolID: pool.ID}).MustNew(ctx) + + testutils.EntClient.Origin.Use(manualhooks.OriginHooks()...) + + // Act + testutils.EntClient.Origin.UpdateOne(origin).SetName("other-origin-name").ExecX(ctx) + + msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) + + // Assert + expectedAdditionalSubjectIDs := []gidx.PrefixedID{pool.ID, pool.OwnerID, lb.ID, lb.LocationID} + actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs + + assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) + assert.Equal(t, origin.ID, msg.Message().SubjectID) + assert.Equal(t, updateEventType, msg.Message().EventType) +} + +func Test_OriginDeleteHook(t *testing.T) { + // Arrange + ctx := testutils.MockPermissions(context.Background()) + + changesChannel, err := testutils.EventsConn.SubscribeChanges(ctx, "delete.load-balancer-origin") + require.NoError(t, err, "failed to subscribe to changes") + + lb := (&testutils.LoadBalancerBuilder{}).MustNew(ctx) + pool := (&testutils.PoolBuilder{}).MustNew(ctx) + (&testutils.PortBuilder{PoolIDs: []gidx.PrefixedID{pool.ID}, LoadBalancerID: lb.ID}).MustNew(ctx) + origin := (&testutils.OriginBuilder{PoolID: pool.ID}).MustNew(ctx) + + testutils.EntClient.Origin.Use(manualhooks.OriginHooks()...) + + // Act + testutils.EntClient.Origin.DeleteOne(origin).ExecX(ctx) + + msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) + + // Assert + expectedAdditionalSubjectIDs := []gidx.PrefixedID{pool.ID, pool.OwnerID, lb.ID, lb.LocationID} + actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs + + assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) + assert.Equal(t, origin.ID, msg.Message().SubjectID) + assert.Equal(t, deleteEventType, msg.Message().EventType) +} + +func Test_PoolCreateHook(t *testing.T) { + // Arrange + ctx := testutils.MockPermissions(context.Background()) + + changesChannel, err := testutils.EventsConn.SubscribeChanges(ctx, "create.load-balancer-pool") + require.NoError(t, err, "failed to subscribe to changes") + + testutils.EntClient.Pool.Use(manualhooks.PoolHooks()...) + + // Act + pool := (&testutils.PoolBuilder{}).MustNew(ctx) + + msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) + + // Assert + expectedAdditionalSubjectIDs := []gidx.PrefixedID{pool.OwnerID} + actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs + + assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) + assert.Equal(t, pool.ID, msg.Message().SubjectID) + assert.Equal(t, createEventType, msg.Message().EventType) +} + +func Test_PoolUpdateHook(t *testing.T) { + // Arrange + ctx := testutils.MockPermissions(context.Background()) + + changesChannel, err := testutils.EventsConn.SubscribeChanges(ctx, "update.load-balancer-pool") + require.NoError(t, err, "failed to subscribe to changes") + + lb := (&testutils.LoadBalancerBuilder{}).MustNew(ctx) + pool := (&testutils.PoolBuilder{}).MustNew(ctx) + port := (&testutils.PortBuilder{PoolIDs: []gidx.PrefixedID{pool.ID}, LoadBalancerID: lb.ID}).MustNew(ctx) + origin := (&testutils.OriginBuilder{PoolID: pool.ID}).MustNew(ctx) + + testutils.EntClient.Pool.Use(manualhooks.PoolHooks()...) + + // Act + testutils.EntClient.Pool.UpdateOne(pool).SetName("other-pool-name").ExecX(ctx) + + msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) + + // Assert + expectedAdditionalSubjectIDs := []gidx.PrefixedID{pool.ID, pool.OwnerID, lb.ID, lb.LocationID, origin.ID, port.ID} + actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs + + assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) + assert.Equal(t, pool.ID, msg.Message().SubjectID) + assert.Equal(t, updateEventType, msg.Message().EventType) +} + +func Test_PoolDeleteHook(t *testing.T) { + // Arrange + ctx := testutils.MockPermissions(context.Background()) + + changesChannel, err := testutils.EventsConn.SubscribeChanges(ctx, "delete.load-balancer-pool") + require.NoError(t, err, "failed to subscribe to changes") + + lb := (&testutils.LoadBalancerBuilder{}).MustNew(ctx) + pool := (&testutils.PoolBuilder{}).MustNew(ctx) + (&testutils.PortBuilder{PoolIDs: []gidx.PrefixedID{pool.ID}, LoadBalancerID: lb.ID}).MustNew(ctx) + + testutils.EntClient.Pool.Use(manualhooks.PoolHooks()...) + + // Act + testutils.EntClient.Pool.DeleteOne(pool).ExecX(ctx) + + msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) + + // Assert + expectedAdditionalSubjectIDs := []gidx.PrefixedID{pool.OwnerID, lb.ID, lb.LocationID} + actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs + + assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) + assert.Equal(t, pool.ID, msg.Message().SubjectID) + assert.Equal(t, deleteEventType, msg.Message().EventType) +} + +func Test_PortCreateHook(t *testing.T) { + // Arrange + ctx := testutils.MockPermissions(context.Background()) + + changesChannel, err := testutils.EventsConn.SubscribeChanges(ctx, "create.load-balancer-port") + require.NoError(t, err, "failed to subscribe to changes") + + lb := (&testutils.LoadBalancerBuilder{}).MustNew(ctx) + pool := (&testutils.PoolBuilder{}).MustNew(ctx) + + testutils.EntClient.Port.Use(manualhooks.PortHooks()...) + + // Act + port := (&testutils.PortBuilder{PoolIDs: []gidx.PrefixedID{pool.ID}, LoadBalancerID: lb.ID}).MustNew(ctx) + + msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) + + // Assert + expectedAdditionalSubjectIDs := []gidx.PrefixedID{pool.ID, pool.OwnerID, lb.ID, lb.LocationID, lb.ProviderID, lb.OwnerID} + actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs + + assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) + assert.Equal(t, port.ID, msg.Message().SubjectID) + assert.Equal(t, createEventType, msg.Message().EventType) +} + +func Test_PortUpdateHook(t *testing.T) { + // Arrange + ctx := testutils.MockPermissions(context.Background()) + + changesChannel, err := testutils.EventsConn.SubscribeChanges(ctx, "update.load-balancer-port") + require.NoError(t, err, "failed to subscribe to changes") + + lb := (&testutils.LoadBalancerBuilder{}).MustNew(ctx) + pool := (&testutils.PoolBuilder{}).MustNew(ctx) + port := (&testutils.PortBuilder{PoolIDs: []gidx.PrefixedID{pool.ID}, LoadBalancerID: lb.ID}).MustNew(ctx) + + testutils.EntClient.Port.Use(manualhooks.PortHooks()...) + + // Act + testutils.EntClient.Port.UpdateOne(port).SetName("other-port-name").ExecX(ctx) + + msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) + + // Assert + expectedAdditionalSubjectIDs := []gidx.PrefixedID{pool.ID, pool.OwnerID, lb.ID, lb.LocationID, lb.ProviderID, lb.OwnerID} + actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs + + assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) + assert.Equal(t, port.ID, msg.Message().SubjectID) + assert.Equal(t, updateEventType, msg.Message().EventType) +} + +func Test_PortDeleteHook(t *testing.T) { + // Arrange + ctx := testutils.MockPermissions(context.Background()) + + changesChannel, err := testutils.EventsConn.SubscribeChanges(ctx, "delete.load-balancer-port") + require.NoError(t, err, "failed to subscribe to changes") + + lb := (&testutils.LoadBalancerBuilder{}).MustNew(ctx) + pool := (&testutils.PoolBuilder{}).MustNew(ctx) + port := (&testutils.PortBuilder{PoolIDs: []gidx.PrefixedID{pool.ID}, LoadBalancerID: lb.ID}).MustNew(ctx) + + testutils.EntClient.Port.Use(manualhooks.PortHooks()...) + + // Act + testutils.EntClient.Port.DeleteOne(port).ExecX(ctx) + + msg := testutils.ChannelReceiveWithTimeout[events.Message[events.ChangeMessage]](t, changesChannel, defaultTimeout) + + // Assert + expectedAdditionalSubjectIDs := []gidx.PrefixedID{lb.OwnerID, lb.ID, lb.LocationID, lb.ProviderID} + actualAdditionalSubjectIDs := msg.Message().AdditionalSubjectIDs + + assert.ElementsMatch(t, expectedAdditionalSubjectIDs, actualAdditionalSubjectIDs) + assert.Equal(t, port.ID, msg.Message().SubjectID) + assert.Equal(t, deleteEventType, msg.Message().EventType) +} diff --git a/internal/testutils/db_setup.go b/internal/testutils/db_setup.go new file mode 100644 index 000000000..64631d039 --- /dev/null +++ b/internal/testutils/db_setup.go @@ -0,0 +1,115 @@ +package testutils + +import ( + "context" + "log" + "os" + "path/filepath" + "runtime" + "strings" + + "entgo.io/ent/dialect" + _ "github.com/lib/pq" // used by the ent client using ParseDBURI return values + _ "github.com/mattn/go-sqlite3" // used by the ent client using ParseDBURI return values + "github.com/testcontainers/testcontainers-go/modules/postgres" + "go.infratographer.com/x/events" + "go.infratographer.com/x/goosex" + "go.infratographer.com/x/testing/eventtools" + + "go.infratographer.com/load-balancer-api/db" + ent "go.infratographer.com/load-balancer-api/internal/ent/generated" + "go.infratographer.com/load-balancer-api/x/testcontainersx" +) + +var ( + testDBURI = os.Getenv("LOADBALANCERAPI_TESTDB_URI") + EventsConn events.Connection // EventsConn exported if needed for subscribers + EntClient *ent.Client // EntClient to use as ent client + DBContainer *testcontainersx.DBContainer // DBContainer to use through entire test suite +) + +// SetupDB sets up in-memory nats server/conn, database and ent client to interact with db +func SetupDB() { + ctx := context.Background() + + // NATS setup + nats, err := eventtools.NewNatsServer() + IfErrPanic("failed to start nats server", err) + + conn, err := events.NewConnection(nats.Config) + IfErrPanic("failed to create events connection", err) + + // DB and EntClient setup + dia, uri, cntr := ParseDBURI(ctx) + + c, err := ent.Open(dia, uri, ent.Debug(), ent.EventsPublisher(conn)) + if err != nil { + log.Println(err) + IfErrPanic("failed terminating test db container after failing to connect to the db", cntr.Container.Terminate(ctx)) + IfErrPanic("failed opening connection to database:", err) + } + + switch dia { + case dialect.SQLite: + // Run automatic migrations for SQLite + IfErrPanic("failed creating db schema", c.Schema.Create(ctx)) + case dialect.Postgres: + log.Println("Running database migrations") + goosex.MigrateUp(uri, db.Migrations) + } + + EventsConn = conn + EntClient = c + DBContainer = cntr +} + +// TeardownDB used for clean up test setup +func TeardownDB() { + ctx := context.Background() + + if EntClient != nil { + IfErrPanic("teardown failed to close database connection", EntClient.Close()) + } + + if DBContainer != nil && DBContainer.Container.IsRunning() { + IfErrPanic("teardown failed to terminate test db container", DBContainer.Container.Terminate(ctx)) + } +} + +// ParseDBURI parses the kind of query language from TESTDB_URI env var and initializes DBContainer as required +func ParseDBURI(ctx context.Context) (string, string, *testcontainersx.DBContainer) { + switch { + // if you don't pass in a database we default to an in memory sqlite + case testDBURI == "": + return dialect.SQLite, "file:ent?mode=memory&cache=shared&_fk=1", nil + case strings.HasPrefix(testDBURI, "sqlite://"): + return dialect.SQLite, strings.TrimPrefix(testDBURI, "sqlite://"), nil + case strings.HasPrefix(testDBURI, "postgres://"), strings.HasPrefix(testDBURI, "postgresql://"): + return dialect.Postgres, testDBURI, nil + case strings.HasPrefix(testDBURI, "docker://"): + dbImage := strings.TrimPrefix(testDBURI, "docker://") + + switch { + case strings.HasPrefix(dbImage, "cockroach"), strings.HasPrefix(dbImage, "cockroachdb"), strings.HasPrefix(dbImage, "crdb"): + cntr, err := testcontainersx.NewCockroachDB(ctx, dbImage) + IfErrPanic("error starting db test container", err) + + return dialect.Postgres, cntr.URI, cntr + case strings.HasPrefix(dbImage, "postgres"): + _, b, _, _ := runtime.Caller(0) + initScriptPath := filepath.Join(filepath.Dir(b), "testdata", "postgres_init.sh") + + cntr, err := testcontainersx.NewPostgresDB(ctx, dbImage, + postgres.WithInitScripts(initScriptPath), + ) + IfErrPanic("error starting db test container", err) + + return dialect.Postgres, cntr.URI, cntr + default: + panic("invalid testcontainer URI, uri: " + testDBURI) + } + + default: + panic("invalid DB URI, uri: " + testDBURI) + } +} diff --git a/internal/testutils/model_builders.go b/internal/testutils/model_builders.go new file mode 100644 index 000000000..2841a3263 --- /dev/null +++ b/internal/testutils/model_builders.go @@ -0,0 +1,148 @@ +package testutils + +import ( + "context" + + "github.com/brianvoe/gofakeit/v6" + "go.infratographer.com/x/gidx" + + ent "go.infratographer.com/load-balancer-api/internal/ent/generated" + "go.infratographer.com/load-balancer-api/internal/ent/generated/pool" +) + +const ( + ownerPrefix = "testown" + locationPrefix = "testloc" + lbPrefix = "loadbal" + minPortNum = 1 + maxPortNum = 65535 +) + +// ProviderBuilder is a provider-like struct for use in generating a provider using the ent client +type ProviderBuilder struct { + Name string + OwnerID gidx.PrefixedID +} + +// MustNew creates a provider from the receiver +func (p *ProviderBuilder) MustNew(ctx context.Context) *ent.Provider { + if p.Name == "" { + p.Name = gofakeit.JobTitle() + } + + if p.OwnerID == "" { + p.OwnerID = gidx.MustNewID(ownerPrefix) + } + + return EntClient.Provider.Create().SetName(p.Name).SetOwnerID(p.OwnerID).SaveX(ctx) +} + +// LoadBalancerBuilder is a loadbalancer-like struct for use in generating a loadbalancer using the ent client +type LoadBalancerBuilder struct { + Name string + OwnerID gidx.PrefixedID + LocationID gidx.PrefixedID + Provider *ent.Provider +} + +// MustNew creates a loadbalancer from the receiver +func (b *LoadBalancerBuilder) MustNew(ctx context.Context) *ent.LoadBalancer { + if b.Provider == nil { + pb := &ProviderBuilder{OwnerID: b.OwnerID} + b.Provider = pb.MustNew(ctx) + } + + if b.Name == "" { + b.Name = gofakeit.AppName() + } + + if b.OwnerID == "" { + b.OwnerID = b.Provider.OwnerID + } + + if b.LocationID == "" { + b.LocationID = gidx.MustNewID(locationPrefix) + } + + return EntClient.LoadBalancer.Create().SetName(b.Name).SetOwnerID(b.OwnerID).SetLocationID(b.LocationID).SetProvider(b.Provider).SaveX(ctx) +} + +// PortBuilder is a port-like struct for use in generating a port using the ent client +type PortBuilder struct { + Name string + LoadBalancerID gidx.PrefixedID + Number int + PoolIDs []gidx.PrefixedID +} + +// MustNew creates a port from the receiver +func (p *PortBuilder) MustNew(ctx context.Context) *ent.Port { + if p.Name == "" { + p.Name = gofakeit.AppName() + } + + if p.LoadBalancerID == "" { + p.LoadBalancerID = gidx.MustNewID(lbPrefix) + } + + if p.Number == 0 { + p.Number = gofakeit.Number(minPortNum, maxPortNum) + } + + return EntClient.Port.Create().SetName(p.Name).SetLoadBalancerID(p.LoadBalancerID).SetNumber(p.Number).AddPoolIDs(p.PoolIDs...).SaveX(ctx) +} + +// PoolBuilder is a pool-like struct for use in generating a pool using the ent client +type PoolBuilder struct { + Name string + OwnerID gidx.PrefixedID + Protocol pool.Protocol +} + +// MustNew creates a pool from the receiver +func (p *PoolBuilder) MustNew(ctx context.Context) *ent.Pool { + if p.Name == "" { + p.Name = gofakeit.AppName() + } + + if p.OwnerID == "" { + p.OwnerID = gidx.MustNewID(ownerPrefix) + } + + if p.Protocol == "" { + p.Protocol = pool.Protocol(gofakeit.RandomString([]string{"tcp", "udp"})) + } + + return EntClient.Pool.Create().SetName(p.Name).SetOwnerID(p.OwnerID).SetProtocol(p.Protocol).SaveX(ctx) +} + +// OriginBuilder is an origin-like struct for use in generating an origin using the ent client +type OriginBuilder struct { + Name string + Target string + PortNumber int + Active bool + PoolID gidx.PrefixedID +} + +// MustNew creates an origin from the receiver +func (o *OriginBuilder) MustNew(ctx context.Context) *ent.Origin { + if o.Name == "" { + o.Name = gofakeit.AppName() + } + + if o.Target == "" { + o.Target = gofakeit.IPv4Address() + } + + if o.PortNumber == 0 { + o.PortNumber = gofakeit.Number(minPortNum, maxPortNum) + } + + if o.PoolID == "" { + pb := &PoolBuilder{} + o.PoolID = pb.MustNew(ctx).ID + } + + return EntClient.Origin.Create().SetName(o.Name).SetTarget(o.Target).SetPortNumber(o.PortNumber).SetActive(o.Active).SetPoolID(o.PoolID).SaveX(ctx) +} diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go new file mode 100644 index 000000000..216cf8bf4 --- /dev/null +++ b/internal/testutils/test_utils.go @@ -0,0 +1,43 @@ +// Package testutils provides some utilities that may be useful for testing +package testutils + +import ( + "context" + "log" + "testing" + "time" + + "github.com/stretchr/testify/mock" + + "go.infratographer.com/permissions-api/pkg/permissions/mockpermissions" +) + +// MockPermissions creates a context from the given context with mocks for permission-api methods +func MockPermissions(ctx context.Context) context.Context { + // mock permissions + perms := new(mockpermissions.MockPermissions) + perms.On("CreateAuthRelationships", mock.Anything, mock.Anything, mock.Anything).Return(nil) + perms.On("DeleteAuthRelationships", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + ctx = perms.ContextWithHandler(ctx) + + return ctx +} + +// IfErrPanic conditionally panics on err with msg +func IfErrPanic(msg string, err error) { + if err != nil { + log.Panicf("%s err: %s", msg, err.Error()) + } +} + +// ChannelReceiveWithTimeout returns the next message from channel chan or panics if it timesout before +func ChannelReceiveWithTimeout[T any](t *testing.T, channel <-chan T, timeout time.Duration) (msg T) { + select { + case msg = <-channel: + case <-time.After(timeout): + t.Fatal("timed out waiting to receive from channel") + } + + return +} diff --git a/internal/graphapi/testdata/postgres_init.sh b/internal/testutils/testdata/postgres_init.sh similarity index 100% rename from internal/graphapi/testdata/postgres_init.sh rename to internal/testutils/testdata/postgres_init.sh