diff --git a/CHANGELOG.md b/CHANGELOG.md index 31e52c295..d8f09c1ac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,15 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +- Store implementations can **optionally** implement the `TableExists` method to provide optimized + table existence checks (#860) + - Default postgres Store implementation updated to use `pg_tables` system catalog, more to follow + - Backward compatible change - existing implementations will continue to work without modification + +```go +TableExists(ctx context.Context, db database.DBTxConn) (bool, error) +``` + ## [v3.23.0] - Add `WithLogger` to `NewProvider` to allow custom loggers (#833) diff --git a/database/dialect.go b/database/dialect.go index ba2da5cab..9f138f560 100644 --- a/database/dialect.go +++ b/database/dialect.go @@ -53,13 +53,13 @@ func NewStore(dialect Dialect, tablename string) (Store, error) { } return &store{ tablename: tablename, - querier: querier, + querier: dialectquery.NewQueryController(querier), }, nil } type store struct { tablename string - querier dialectquery.Querier + querier *dialectquery.QueryController } var _ Store = (*store)(nil) @@ -147,3 +147,26 @@ func (s *store) ListMigrations( } return migrations, nil } + +// +// +// +// Additional methods that are not part of the core Store interface, but are extended by the +// [controller.StoreController] type. +// +// +// + +func (s *store) TableExists(ctx context.Context, db DBTxConn) (bool, error) { + q := s.querier.TableExists(s.tablename) + if q == "" { + return false, errors.ErrUnsupported + } + var exists bool + // Note, we do not pass the table name as an argument to the query, as the query should be + // pre-defined by the dialect. + if err := db.QueryRowContext(ctx, q).Scan(&exists); err != nil { + return false, fmt.Errorf("failed to check if table exists: %w", err) + } + return exists, nil +} diff --git a/database/store_extended.go b/database/store_extended.go new file mode 100644 index 000000000..e3aae4d9a --- /dev/null +++ b/database/store_extended.go @@ -0,0 +1,33 @@ +package database + +import "context" + +// StoreExtender is an extension of the Store interface that provides optional optimizations and +// database-specific features. While not required by the core goose package, implementing these +// methods can improve performance and functionality for specific databases. +// +// IMPORTANT: This interface may be expanded in future versions. Implementors MUST be prepared to +// update their implementations when new methods are added, either by implementing the new +// functionality or returning [errors.ErrUnsupported]. +// +// The goose package handles these extended capabilities through a [controller.StoreController], +// which automatically uses optimized methods when available while falling back to default behavior +// when they're not implemented. +// +// Example usage to verify implementation: +// +// var _ StoreExtender = (*CustomStoreExtended)(nil) +// +// In short, it's exported to allows implementors to have a compile-time check that they are +// implementing the interface correctly. +type StoreExtender interface { + Store + + // TableExists checks if the migrations table exists in the database. Implementing this method + // allows goose to optimize table existence checks by using database-specific system catalogs + // (e.g., pg_tables for PostgreSQL, sqlite_master for SQLite) instead of generic SQL queries. + // + // Return [errors.ErrUnsupported] if the database does not provide an efficient way to check + // table existence. + TableExists(ctx context.Context, db DBTxConn) (bool, error) +} diff --git a/internal/controller/controller.go b/internal/controller/controller.go new file mode 100644 index 000000000..1d26e55b2 --- /dev/null +++ b/internal/controller/controller.go @@ -0,0 +1,37 @@ +package controller + +import ( + "context" + "errors" + + "github.com/pressly/goose/v3/database" +) + +// A StoreController is used by the goose package to interact with a database. This type is a +// wrapper around the Store interface, but can be extended to include additional (optional) methods +// that are not part of the core Store interface. +type StoreController struct{ database.Store } + +var _ database.StoreExtender = (*StoreController)(nil) + +// NewStoreController returns a new StoreController that wraps the given Store. +// +// If the Store implements the following optional methods, the StoreController will call them as +// appropriate: +// +// - TableExists(context.Context, DBTxConn) (bool, error) +// +// If the Store does not implement a method, it will either return a [errors.ErrUnsupported] error +// or fall back to the default behavior. +func NewStoreController(store database.Store) *StoreController { + return &StoreController{store} +} + +func (c *StoreController) TableExists(ctx context.Context, db database.DBTxConn) (bool, error) { + if t, ok := c.Store.(interface { + TableExists(ctx context.Context, db database.DBTxConn) (bool, error) + }); ok { + return t.TableExists(ctx, db) + } + return false, errors.ErrUnsupported +} diff --git a/internal/dialect/dialectquery/dialectquery.go b/internal/dialect/dialectquery/dialectquery.go index 5e10e46e4..49ce96f47 100644 --- a/internal/dialect/dialectquery/dialectquery.go +++ b/internal/dialect/dialectquery/dialectquery.go @@ -25,3 +25,25 @@ type Querier interface { // table. Returns a nullable int64 value. GetLatestVersion(tableName string) string } + +var _ Querier = (*QueryController)(nil) + +type QueryController struct{ Querier } + +// NewQueryController returns a new QueryController that wraps the given Querier. +func NewQueryController(querier Querier) *QueryController { + return &QueryController{Querier: querier} +} + +// Optional methods + +// TableExists returns the SQL query string to check if the version table exists. If the Querier +// does not implement this method, it will return an empty string. +// +// Returns a boolean value. +func (c *QueryController) TableExists(tableName string) string { + if t, ok := c.Querier.(interface{ TableExists(string) string }); ok { + return t.TableExists(tableName) + } + return "" +} diff --git a/internal/dialect/dialectquery/postgres.go b/internal/dialect/dialectquery/postgres.go index 2def6c6ca..facbc2ff7 100644 --- a/internal/dialect/dialectquery/postgres.go +++ b/internal/dialect/dialectquery/postgres.go @@ -1,6 +1,16 @@ package dialectquery -import "fmt" +import ( + "fmt" + "strings" +) + +const ( + // defaultSchemaName is the default schema name for Postgres. + // + // https://www.postgresql.org/docs/current/ddl-schemas.html#DDL-SCHEMAS-PUBLIC + defaultSchemaName = "public" +) type Postgres struct{} @@ -40,3 +50,17 @@ func (p *Postgres) GetLatestVersion(tableName string) string { q := `SELECT max(version_id) FROM %s` return fmt.Sprintf(q, tableName) } + +func (p *Postgres) TableExists(tableName string) string { + schemaName, tableName := parseTableIdentifier(tableName) + q := `SELECT EXISTS ( SELECT FROM pg_tables WHERE schemaname = '%s' AND tablename = '%s' )` + return fmt.Sprintf(q, schemaName, tableName) +} + +func parseTableIdentifier(name string) (schema, table string) { + schema, table, found := strings.Cut(name, ".") + if !found { + return defaultSchemaName, name + } + return schema, table +} diff --git a/provider.go b/provider.go index ec3b8749a..a4163a667 100644 --- a/provider.go +++ b/provider.go @@ -12,6 +12,7 @@ import ( "sync" "github.com/pressly/goose/v3/database" + "github.com/pressly/goose/v3/internal/controller" "github.com/pressly/goose/v3/internal/gooseutil" "github.com/pressly/goose/v3/internal/sqlparser" "go.uber.org/multierr" @@ -24,7 +25,7 @@ type Provider struct { mu sync.Mutex db *sql.DB - store database.Store + store *controller.StoreController versionTableOnce sync.Once fsys fs.FS @@ -141,7 +142,7 @@ func newProvider( db: db, fsys: fsys, cfg: cfg, - store: store, + store: controller.NewStoreController(store), migrations: migrations, }, nil } diff --git a/provider_run.go b/provider_run.go index 66c23c67d..c083d3967 100644 --- a/provider_run.go +++ b/provider_run.go @@ -296,25 +296,29 @@ func (p *Provider) tryEnsureVersionTable(ctx context.Context, conn *sql.Conn) er b := retry.NewConstant(1 * time.Second) b = retry.WithMaxRetries(3, b) return retry.Do(ctx, b, func(ctx context.Context) error { - if e, ok := p.store.(interface { - TableExists(context.Context, database.DBTxConn, string) (bool, error) - }); ok { - exists, err := e.TableExists(ctx, conn, p.store.Tablename()) - if err != nil { - return fmt.Errorf("failed to check if version table exists: %w", err) - } - if exists { - return nil - } - } else { - // This chicken-and-egg behavior is the fallback for all existing implementations of the - // Store interface. We check if the version table exists by querying for the initial - // version, but the table may not exist yet. It's important this runs outside of a - // transaction to avoid failing the transaction. + exists, err := p.store.TableExists(ctx, conn) + if err == nil && exists { + return nil + } else if err != nil && errors.Is(err, errors.ErrUnsupported) { + // Fallback strategy for checking table existence: + // + // When direct table existence checks aren't supported, we attempt to query the initial + // migration (version 0). This approach has two implications: + // + // 1. If the table exists, the query succeeds and confirms existence + // 2. If the table doesn't exist, the query fails and generates an error log + // + // Note: This check must occur outside any transaction, as a failed query would + // otherwise cause the entire transaction to roll back. The error logs generated by this + // approach are expected and can be safely ignored. if res, err := p.store.GetMigration(ctx, conn, 0); err == nil && res != nil { return nil } + // Fallthrough to create the table. + } else if err != nil { + return fmt.Errorf("failed to check if version table exists: %w", err) } + if err := beginTx(ctx, conn, func(tx *sql.Tx) error { if err := p.store.CreateVersionTable(ctx, tx); err != nil { return err diff --git a/provider_run_test.go b/provider_run_test.go index 4598e9510..aedfa7fa9 100644 --- a/provider_run_test.go +++ b/provider_run_test.go @@ -745,15 +745,17 @@ func TestGoMigrationPanic(t *testing.T) { func TestCustomStoreTableExists(t *testing.T) { t.Parallel() - + db := newDB(t) store, err := database.NewStore(database.DialectSQLite3, goose.DefaultTablename) require.NoError(t, err) - p, err := goose.NewProvider("", newDB(t), newFsys(), - goose.WithStore(&customStoreSQLite3{store}), - ) - require.NoError(t, err) - _, err = p.Up(context.Background()) - require.NoError(t, err) + for i := 0; i < 2; i++ { + p, err := goose.NewProvider("", db, newFsys(), + goose.WithStore(&customStoreSQLite3{store}), + ) + require.NoError(t, err) + _, err = p.Up(context.Background()) + require.NoError(t, err) + } } func TestProviderApply(t *testing.T) { @@ -842,14 +844,14 @@ func TestPending(t *testing.T) { }) } -type customStoreSQLite3 struct { - database.Store -} +var _ database.StoreExtender = (*customStoreSQLite3)(nil) + +type customStoreSQLite3 struct{ database.Store } -func (s *customStoreSQLite3) TableExists(ctx context.Context, db database.DBTxConn, name string) (bool, error) { - q := `SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type='table' AND name=$1) AS table_exists` +func (s *customStoreSQLite3) TableExists(ctx context.Context, db database.DBTxConn) (bool, error) { + q := `SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type='table' AND name=?) AS table_exists` var exists bool - if err := db.QueryRowContext(ctx, q, name).Scan(&exists); err != nil { + if err := db.QueryRowContext(ctx, q, s.Tablename()).Scan(&exists); err != nil { return false, err } return exists, nil