Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: postgres migration table existence check #860

Merged
merged 8 commits into from
Nov 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions database/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ const (
)

// NewStore returns a new [Store] implementation for the given dialect.
func NewStore(dialect Dialect, tablename string) (Store, error) {
if tablename == "" {
func NewStore(dialect Dialect, tableName string) (Store, error) {
if tableName == "" {
return nil, errors.New("table name must not be empty")
}
if dialect == "" {
Expand All @@ -52,40 +52,40 @@ func NewStore(dialect Dialect, tablename string) (Store, error) {
return nil, fmt.Errorf("unknown dialect: %q", dialect)
}
return &store{
tablename: tablename,
querier: querier,
tableName: tableName,
querier: dialectquery.NewQueryController(querier),
}, nil
}

type store struct {
tablename string
querier dialectquery.Querier
tableName string
querier *dialectquery.QueryController
}

var _ Store = (*store)(nil)

func (s *store) Tablename() string {
return s.tablename
return s.tableName
}

func (s *store) CreateVersionTable(ctx context.Context, db DBTxConn) error {
q := s.querier.CreateTable(s.tablename)
q := s.querier.CreateTable(s.tableName)
if _, err := db.ExecContext(ctx, q); err != nil {
return fmt.Errorf("failed to create version table %q: %w", s.tablename, err)
return fmt.Errorf("failed to create version table %q: %w", s.tableName, err)
}
return nil
}

func (s *store) Insert(ctx context.Context, db DBTxConn, req InsertRequest) error {
q := s.querier.InsertVersion(s.tablename)
q := s.querier.InsertVersion(s.tableName)
if _, err := db.ExecContext(ctx, q, req.Version, true); err != nil {
return fmt.Errorf("failed to insert version %d: %w", req.Version, err)
}
return nil
}

func (s *store) Delete(ctx context.Context, db DBTxConn, version int64) error {
q := s.querier.DeleteVersion(s.tablename)
q := s.querier.DeleteVersion(s.tableName)
if _, err := db.ExecContext(ctx, q, version); err != nil {
return fmt.Errorf("failed to delete version %d: %w", version, err)
}
Expand All @@ -97,7 +97,7 @@ func (s *store) GetMigration(
db DBTxConn,
version int64,
) (*GetMigrationResult, error) {
q := s.querier.GetMigrationByVersion(s.tablename)
q := s.querier.GetMigrationByVersion(s.tableName)
var result GetMigrationResult
if err := db.QueryRowContext(ctx, q, version).Scan(
&result.Timestamp,
Expand All @@ -112,7 +112,7 @@ func (s *store) GetMigration(
}

func (s *store) GetLatestVersion(ctx context.Context, db DBTxConn) (int64, error) {
q := s.querier.GetLatestVersion(s.tablename)
q := s.querier.GetLatestVersion(s.tableName)
var version sql.NullInt64
if err := db.QueryRowContext(ctx, q).Scan(&version); err != nil {
return -1, fmt.Errorf("failed to get latest version: %w", err)
Expand All @@ -127,7 +127,7 @@ func (s *store) ListMigrations(
ctx context.Context,
db DBTxConn,
) ([]*ListMigrationsResult, error) {
q := s.querier.ListMigrations(s.tablename)
q := s.querier.ListMigrations(s.tableName)
rows, err := db.QueryContext(ctx, q)
if err != nil {
return nil, fmt.Errorf("failed to list migrations: %w", err)
Expand All @@ -147,3 +147,26 @@ func (s *store) ListMigrations(
}
return migrations, nil
}

//
//
//
// Additional methods that are not part of the core Store interface, but are extended by the
// [controller.StoreController] type.
//
//
//

func (s *store) TableExists(ctx context.Context, db DBTxConn) (bool, error) {
q := s.querier.TableExists(s.tableName)
if q == "" {
return false, errors.ErrUnsupported
}
var exists bool
// Note, we do not pass the table name as an argument to the query, as the query should be
// pre-defined by the dialect.
if err := db.QueryRowContext(ctx, q).Scan(&exists); err != nil {
return false, fmt.Errorf("failed to check if table exists: %w", err)
}
return exists, nil
}
40 changes: 40 additions & 0 deletions internal/controller/controller.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package controller

import (
"context"
"errors"

"github.com/pressly/goose/v3/database"
)

// A StoreController is used by the goose package to interact with a database. This type is a
// wrapper around the Store interface, but can be extended to include additional (optional) methods
// that are not part of the core Store interface.
type StoreController struct{ database.Store }

var _ database.Store = (*StoreController)(nil)

// NewStoreController returns a new StoreController that wraps the given Store.
//
// If the Store implements the following optional methods, the StoreController will call them as
// appropriate:
//
// - TableExists(context.Context, DBTxConn) (bool, error)
//
// If the Store does not implement a method, it will either return a [errors.ErrUnsupported] error
// or fall back to the default behavior.
func NewStoreController(store database.Store) *StoreController {
return &StoreController{store}
}

// TableExists is an optional method that checks if the version table exists in the database. It is
// recommended to implement this method if the database supports it, as it can be used to optimize
// certain operations.
func (c *StoreController) TableExists(ctx context.Context, db database.DBTxConn) (bool, error) {
if t, ok := c.Store.(interface {
TableExists(ctx context.Context, db database.DBTxConn) (bool, error)
}); ok {
return t.TableExists(ctx, db)
}
return false, errors.ErrUnsupported
}
22 changes: 22 additions & 0 deletions internal/dialect/dialectquery/dialectquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,25 @@ type Querier interface {
// table. Returns a nullable int64 value.
GetLatestVersion(tableName string) string
}

var _ Querier = (*QueryController)(nil)

type QueryController struct{ Querier }

// NewQueryController returns a new QueryController that wraps the given Querier.
func NewQueryController(querier Querier) *QueryController {
return &QueryController{Querier: querier}
}

// Optional methods

// TableExists returns the SQL query string to check if the version table exists. If the Querier
// does not implement this method, it will return an empty string.
//
// Returns a boolean value.
func (c *QueryController) TableExists(tableName string) string {
if t, ok := c.Querier.(interface{ TableExists(string) string }); ok {
return t.TableExists(tableName)
}
return ""
}
26 changes: 25 additions & 1 deletion internal/dialect/dialectquery/postgres.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
package dialectquery

import "fmt"
import (
"fmt"
"strings"
)

const (
// defaultSchemaName is the default schema name for Postgres.
//
// https://www.postgresql.org/docs/current/ddl-schemas.html#DDL-SCHEMAS-PUBLIC
defaultSchemaName = "public"
)

type Postgres struct{}

Expand Down Expand Up @@ -40,3 +50,17 @@ func (p *Postgres) GetLatestVersion(tableName string) string {
q := `SELECT max(version_id) FROM %s`
return fmt.Sprintf(q, tableName)
}

func (p *Postgres) TableExists(tableName string) string {
schemaName, tableName := parseTableIdentifier(tableName)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should handle cases where the table name contains the schema name; otherwise, assumed public.

q := `SELECT EXISTS ( SELECT FROM pg_tables WHERE schemaname = '%s' AND tablename = '%s' )`
return fmt.Sprintf(q, schemaName, tableName)
}

func parseTableIdentifier(name string) (schema, table string) {
schema, table, found := strings.Cut(name, ".")
if !found {
return defaultSchemaName, name
}
return schema, table
}
5 changes: 3 additions & 2 deletions provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"sync"

"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/controller"
"github.com/pressly/goose/v3/internal/gooseutil"
"github.com/pressly/goose/v3/internal/sqlparser"
"go.uber.org/multierr"
Expand All @@ -24,7 +25,7 @@ type Provider struct {
mu sync.Mutex

db *sql.DB
store database.Store
store *controller.StoreController
versionTableOnce sync.Once

fsys fs.FS
Expand Down Expand Up @@ -141,7 +142,7 @@ func newProvider(
db: db,
fsys: fsys,
cfg: cfg,
store: store,
store: controller.NewStoreController(store),
migrations: migrations,
}, nil
}
Expand Down
34 changes: 19 additions & 15 deletions provider_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,25 +296,29 @@ func (p *Provider) tryEnsureVersionTable(ctx context.Context, conn *sql.Conn) er
b := retry.NewConstant(1 * time.Second)
b = retry.WithMaxRetries(3, b)
return retry.Do(ctx, b, func(ctx context.Context) error {
if e, ok := p.store.(interface {
TableExists(context.Context, database.DBTxConn, string) (bool, error)
Copy link
Collaborator Author

@mfridman mfridman Nov 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The whole point of this PR was to make this a "well-defined" method part of the extended Store interface.

By having this inlined, makes it difficult to understand where else in the codebase there might be such hidden interfaces.

But with the concept of the "controller" there's a single place where users (and us) can look to discover these.

}); ok {
exists, err := e.TableExists(ctx, conn, p.store.Tablename())
if err != nil {
return fmt.Errorf("failed to check if version table exists: %w", err)
}
if exists {
return nil
}
} else {
// This chicken-and-egg behavior is the fallback for all existing implementations of the
// Store interface. We check if the version table exists by querying for the initial
// version, but the table may not exist yet. It's important this runs outside of a
// transaction to avoid failing the transaction.
exists, err := p.store.TableExists(ctx, conn)
if err == nil && exists {
return nil
} else if err != nil && errors.Is(err, errors.ErrUnsupported) {
// Fallback strategy for checking table existence:
//
// When direct table existence checks aren't supported, we attempt to query the initial
// migration (version 0). This approach has two implications:
//
// 1. If the table exists, the query succeeds and confirms existence
// 2. If the table doesn't exist, the query fails and generates an error log
//
// Note: This check must occur outside any transaction, as a failed query would
// otherwise cause the entire transaction to roll back. The error logs generated by this
// approach are expected and can be safely ignored.
if res, err := p.store.GetMigration(ctx, conn, 0); err == nil && res != nil {
return nil
}
// Fallthrough to create the table.
} else if err != nil {
return fmt.Errorf("failed to check if version table exists: %w", err)
}

if err := beginTx(ctx, conn, func(tx *sql.Tx) error {
if err := p.store.CreateVersionTable(ctx, tx); err != nil {
return err
Expand Down