diff --git a/dialect.go b/dialect.go index abda163ea..0dbacd11e 100644 --- a/dialect.go +++ b/dialect.go @@ -8,6 +8,7 @@ import ( // SQLDialect abstracts the details of specific SQL dialects // for goose's few SQL specific statements type SQLDialect interface { + versionTableExistsSQL() string // sql string to check if the version table exists createVersionTableSQL() string // sql string to create the db version table insertVersionSQL() string // sql string to insert the initial version table row deleteVersionSQL() string // sql string to delete version @@ -55,6 +56,11 @@ func SetDialect(d string) error { // PostgresDialect struct. type PostgresDialect struct{} +func (pg PostgresDialect) versionTableExistsSQL() string { + return fmt.Sprintf(`SELECT 1 FROM information_schema.tables WHERE table_name = '%s';`, + TableName()) +} + func (pg PostgresDialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id serial NOT NULL, @@ -78,7 +84,7 @@ func (pg PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { return rows, err } -func (m PostgresDialect) migrationSQL() string { +func (pg PostgresDialect) migrationSQL() string { return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName()) } @@ -93,6 +99,11 @@ func (pg PostgresDialect) deleteVersionSQL() string { // MySQLDialect struct. type MySQLDialect struct{} +func (m MySQLDialect) versionTableExistsSQL() string { + return fmt.Sprintf(`SELECT 1 FROM information_schema.tables WHERE table_name = '%s';`, + TableName()) +} + func (m MySQLDialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id serial NOT NULL, @@ -131,6 +142,11 @@ func (m MySQLDialect) deleteVersionSQL() string { // SqlServerDialect struct. type SqlServerDialect struct{} +func (m SqlServerDialect) versionTableExistsSQL() string { + return fmt.Sprintf(`SELECT 1 FROM information_schema.tables WHERE table_name = '%s';`, + TableName()) +} + func (m SqlServerDialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id INT NOT NULL IDENTITY(1,1) PRIMARY KEY, @@ -181,6 +197,11 @@ func (m SqlServerDialect) deleteVersionSQL() string { // Sqlite3Dialect struct. type Sqlite3Dialect struct{} +func (m Sqlite3Dialect) versionTableExistsSQL() string { + return fmt.Sprintf(`SELECT 1 FROM sqlite_master WHERE type='table' AND name='%s';`, + TableName()) +} + func (m Sqlite3Dialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -218,6 +239,11 @@ func (m Sqlite3Dialect) deleteVersionSQL() string { // RedshiftDialect struct. type RedshiftDialect struct{} +func (rs RedshiftDialect) versionTableExistsSQL() string { + return fmt.Sprintf(`SELECT 1 FROM information_schema.tables WHERE table_name = '%s';`, + TableName()) +} + func (rs RedshiftDialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id integer NOT NULL identity(1, 1), @@ -241,7 +267,7 @@ func (rs RedshiftDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { return rows, err } -func (m RedshiftDialect) migrationSQL() string { +func (rs RedshiftDialect) migrationSQL() string { return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName()) } @@ -256,6 +282,11 @@ func (rs RedshiftDialect) deleteVersionSQL() string { // TiDBDialect struct. type TiDBDialect struct{} +func (m TiDBDialect) versionTableExistsSQL() string { + return fmt.Sprintf(`SELECT 1 FROM information_schema.tables WHERE table_name = '%s';`, + TableName()) +} + func (m TiDBDialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT UNIQUE, @@ -294,6 +325,11 @@ func (m TiDBDialect) deleteVersionSQL() string { // ClickHouseDialect struct. type ClickHouseDialect struct{} +func (m ClickHouseDialect) versionTableExistsSQL() string { + return fmt.Sprintf(`SELECT 1 FROM information_schema.tables WHERE table_name = '%s';`, + TableName()) +} + func (m ClickHouseDialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s ( version_id Int64, @@ -332,6 +368,11 @@ func (m ClickHouseDialect) deleteVersionSQL() string { // VerticaDialect struct. type VerticaDialect struct{} +func (v VerticaDialect) versionTableExistsSQL() string { + return fmt.Sprintf(`SELECT 1 FROM v_catalog.tables WHERE table_name ILIKE '%s';`, + TableName()) +} + func (v VerticaDialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id identity(1,1) NOT NULL, @@ -355,7 +396,7 @@ func (v VerticaDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { return rows, err } -func (m VerticaDialect) migrationSQL() string { +func (v VerticaDialect) migrationSQL() string { return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName()) } diff --git a/migrate.go b/migrate.go index c92f7a8a2..7c6b368ad 100644 --- a/migrate.go +++ b/migrate.go @@ -17,6 +17,9 @@ var ( ErrNoCurrentVersion = errors.New("no current version found") // ErrNoNextVersion when the next migration version is not found. ErrNoNextVersion = errors.New("no next version found") + // ErrVersionTableMissing when the goose version table is missing. + ErrVersionTableMissing = errors.New("version table doesn't exists") + // MaxVersion is the maximum allowed version. MaxVersion int64 = math.MaxInt64 @@ -289,9 +292,9 @@ func versionFilter(v, current, target int64) bool { // EnsureDBVersion retrieves the current version for this DB. // Create and initialize the DB version table if it doesn't exist. func EnsureDBVersion(db *sql.DB) (int64, error) { - rows, err := GetDialect().dbVersionQuery(db) + rows, err := dbVersionQuery(db) if err != nil { - return 0, createVersionTable(db) + return 0, createVersionTableIfMissing(db, err) } defer rows.Close() @@ -335,6 +338,34 @@ func EnsureDBVersion(db *sql.DB) (int64, error) { return 0, ErrNoNextVersion } +// dbVersionQuery query to check if the versions table exists and if so, query the versions table. +// If the versions table does not exist, return an ErrVersionTableMissing error. +func dbVersionQuery(db *sql.DB) (*sql.Rows, error) { + d := GetDialect() + + rows, err := db.Query(d.versionTableExistsSQL()) + if err != nil { + return nil, err + } + defer rows.Close() + + if !rows.Next() { + return nil, ErrVersionTableMissing + } + + return d.dbVersionQuery(db) +} + +// createVersionTableIfMissing creates the version table if the error is ErrVersionTableMissing. +// Otherwise, returns the original error. +func createVersionTableIfMissing(db *sql.DB, err error) error { + if !errors.Is(err, ErrVersionTableMissing) { + return fmt.Errorf("failed to query db versions: %w", err) + } + + return createVersionTable(db) +} + // Create the db version table // and insert the initial 0 value into it func createVersionTable(db *sql.DB) error { diff --git a/reset.go b/reset.go index 258841fad..65ae043cd 100644 --- a/reset.go +++ b/reset.go @@ -39,7 +39,7 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error { } func dbMigrationsStatus(db *sql.DB) (map[int64]bool, error) { - rows, err := GetDialect().dbVersionQuery(db) + rows, err := dbVersionQuery(db) if err != nil { return map[int64]bool{}, nil } diff --git a/up.go b/up.go index 1d668e38c..600511d85 100644 --- a/up.go +++ b/up.go @@ -223,9 +223,9 @@ func UpByOne(db *sql.DB, dir string, opts ...OptionsFunc) error { // listAllDBVersions returns a list of all migrations, ordered ascending. // TODO(mf): fairly cheap, but a nice-to-have is pagination support. func listAllDBVersions(db *sql.DB) (Migrations, error) { - rows, err := GetDialect().dbVersionQuery(db) + rows, err := dbVersionQuery(db) if err != nil { - return nil, createVersionTable(db) + return nil, createVersionTableIfMissing(db, err) } var all Migrations for rows.Next() {