Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check if version table exist before query #463

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
// SQLDialect abstracts the details of specific SQL dialects
// for goose's few SQL specific statements
type SQLDialect interface {
versionTableExistsSQL() string // sql string to check if the version table exists
createVersionTableSQL() string // sql string to create the db version table
insertVersionSQL() string // sql string to insert the initial version table row
deleteVersionSQL() string // sql string to delete version
Expand Down Expand Up @@ -55,6 +56,11 @@ func SetDialect(d string) error {
// PostgresDialect struct.
type PostgresDialect struct{}

func (pg PostgresDialect) versionTableExistsSQL() string {
return fmt.Sprintf(`SELECT 1 FROM information_schema.tables WHERE table_name = '%s';`,
TableName())
}

func (pg PostgresDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id serial NOT NULL,
Expand All @@ -78,7 +84,7 @@ func (pg PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
return rows, err
}

func (m PostgresDialect) migrationSQL() string {
func (pg PostgresDialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName())
}

Expand All @@ -93,6 +99,11 @@ func (pg PostgresDialect) deleteVersionSQL() string {
// MySQLDialect struct.
type MySQLDialect struct{}

func (m MySQLDialect) versionTableExistsSQL() string {
return fmt.Sprintf(`SELECT 1 FROM information_schema.tables WHERE table_name = '%s';`,
TableName())
}

func (m MySQLDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id serial NOT NULL,
Expand Down Expand Up @@ -131,6 +142,11 @@ func (m MySQLDialect) deleteVersionSQL() string {
// SqlServerDialect struct.
type SqlServerDialect struct{}

func (m SqlServerDialect) versionTableExistsSQL() string {
return fmt.Sprintf(`SELECT 1 FROM information_schema.tables WHERE table_name = '%s';`,
TableName())
}

func (m SqlServerDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id INT NOT NULL IDENTITY(1,1) PRIMARY KEY,
Expand Down Expand Up @@ -181,6 +197,11 @@ func (m SqlServerDialect) deleteVersionSQL() string {
// Sqlite3Dialect struct.
type Sqlite3Dialect struct{}

func (m Sqlite3Dialect) versionTableExistsSQL() string {
return fmt.Sprintf(`SELECT 1 FROM sqlite_master WHERE type='table' AND name='%s';`,
TableName())
}

func (m Sqlite3Dialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id INTEGER PRIMARY KEY AUTOINCREMENT,
Expand Down Expand Up @@ -218,6 +239,11 @@ func (m Sqlite3Dialect) deleteVersionSQL() string {
// RedshiftDialect struct.
type RedshiftDialect struct{}

func (rs RedshiftDialect) versionTableExistsSQL() string {
return fmt.Sprintf(`SELECT 1 FROM information_schema.tables WHERE table_name = '%s';`,
TableName())
}

func (rs RedshiftDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id integer NOT NULL identity(1, 1),
Expand All @@ -241,7 +267,7 @@ func (rs RedshiftDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
return rows, err
}

func (m RedshiftDialect) migrationSQL() string {
func (rs RedshiftDialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName())
}

Expand All @@ -256,6 +282,11 @@ func (rs RedshiftDialect) deleteVersionSQL() string {
// TiDBDialect struct.
type TiDBDialect struct{}

func (m TiDBDialect) versionTableExistsSQL() string {
return fmt.Sprintf(`SELECT 1 FROM information_schema.tables WHERE table_name = '%s';`,
TableName())
}

func (m TiDBDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT UNIQUE,
Expand Down Expand Up @@ -294,6 +325,11 @@ func (m TiDBDialect) deleteVersionSQL() string {
// ClickHouseDialect struct.
type ClickHouseDialect struct{}

func (m ClickHouseDialect) versionTableExistsSQL() string {
return fmt.Sprintf(`SELECT 1 FROM information_schema.tables WHERE table_name = '%s';`,
TableName())
}

func (m ClickHouseDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
version_id Int64,
Expand Down Expand Up @@ -332,6 +368,11 @@ func (m ClickHouseDialect) deleteVersionSQL() string {
// VerticaDialect struct.
type VerticaDialect struct{}

func (v VerticaDialect) versionTableExistsSQL() string {
return fmt.Sprintf(`SELECT 1 FROM v_catalog.tables WHERE table_name ILIKE '%s';`,
TableName())
}

func (v VerticaDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id identity(1,1) NOT NULL,
Expand All @@ -355,7 +396,7 @@ func (v VerticaDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
return rows, err
}

func (m VerticaDialect) migrationSQL() string {
func (v VerticaDialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName())
}

Expand Down
35 changes: 33 additions & 2 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ var (
ErrNoCurrentVersion = errors.New("no current version found")
// ErrNoNextVersion when the next migration version is not found.
ErrNoNextVersion = errors.New("no next version found")
// ErrVersionTableMissing when the goose version table is missing.
ErrVersionTableMissing = errors.New("version table doesn't exists")

// MaxVersion is the maximum allowed version.
MaxVersion int64 = math.MaxInt64

Expand Down Expand Up @@ -289,9 +292,9 @@ func versionFilter(v, current, target int64) bool {
// EnsureDBVersion retrieves the current version for this DB.
// Create and initialize the DB version table if it doesn't exist.
func EnsureDBVersion(db *sql.DB) (int64, error) {
rows, err := GetDialect().dbVersionQuery(db)
rows, err := dbVersionQuery(db)
if err != nil {
return 0, createVersionTable(db)
return 0, createVersionTableIfMissing(db, err)
}
defer rows.Close()

Expand Down Expand Up @@ -335,6 +338,34 @@ func EnsureDBVersion(db *sql.DB) (int64, error) {
return 0, ErrNoNextVersion
}

// dbVersionQuery query to check if the versions table exists and if so, query the versions table.
// If the versions table does not exist, return an ErrVersionTableMissing error.
func dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
d := GetDialect()

rows, err := db.Query(d.versionTableExistsSQL())
if err != nil {
return nil, err
}
defer rows.Close()

if !rows.Next() {
return nil, ErrVersionTableMissing
}

return d.dbVersionQuery(db)
}

// createVersionTableIfMissing creates the version table if the error is ErrVersionTableMissing.
// Otherwise, returns the original error.
func createVersionTableIfMissing(db *sql.DB, err error) error {
if !errors.Is(err, ErrVersionTableMissing) {
return fmt.Errorf("failed to query db versions: %w", err)
}

return createVersionTable(db)
}

// Create the db version table
// and insert the initial 0 value into it
func createVersionTable(db *sql.DB) error {
Expand Down
2 changes: 1 addition & 1 deletion reset.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error {
}

func dbMigrationsStatus(db *sql.DB) (map[int64]bool, error) {
rows, err := GetDialect().dbVersionQuery(db)
rows, err := dbVersionQuery(db)
if err != nil {
return map[int64]bool{}, nil
}
Expand Down
4 changes: 2 additions & 2 deletions up.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,9 @@ func UpByOne(db *sql.DB, dir string, opts ...OptionsFunc) error {
// listAllDBVersions returns a list of all migrations, ordered ascending.
// TODO(mf): fairly cheap, but a nice-to-have is pagination support.
func listAllDBVersions(db *sql.DB) (Migrations, error) {
rows, err := GetDialect().dbVersionQuery(db)
rows, err := dbVersionQuery(db)
if err != nil {
return nil, createVersionTable(db)
return nil, createVersionTableIfMissing(db, err)
}
var all Migrations
for rows.Next() {
Expand Down