Skip to content

Commit

Permalink
#42 supporting varaiable placeholders into sql migrations (#43)
Browse files Browse the repository at this point in the history
* #42 supporting varaiable placeholders into sql migrations
  • Loading branch information
raoptimus authored Dec 4, 2024
1 parent 737e3f0 commit e9e823f
Show file tree
Hide file tree
Showing 13 changed files with 361 additions and 30 deletions.
8 changes: 8 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package dbmigrator

import "github.com/pkg/errors"

var (
ErrMigrationAlreadyExists = errors.New("migration already exists")
ErrAppliedMigrationNotFound = errors.New("applied migration not found")
)
5 changes: 5 additions & 0 deletions internal/dal/repository/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type adapter interface {
DropMigrationHistoryTable(ctx context.Context) error
CreateMigrationHistoryTable(ctx context.Context) error
MigrationsCount(ctx context.Context) (int, error)
ExistsMigration(ctx context.Context, version string) (bool, error)
TableNameWithSchema() string
ForceSafely() bool
}
Expand Down Expand Up @@ -64,6 +65,10 @@ func (r *Repository) MigrationsCount(ctx context.Context) (int, error) {
return r.adapter.MigrationsCount(ctx)
}

func (r *Repository) ExistsMigration(ctx context.Context, version string) (bool, error) {
return r.adapter.ExistsMigration(ctx, version)
}

func (r *Repository) TableNameWithSchema() string {
return r.adapter.TableNameWithSchema()
}
Expand Down
19 changes: 19 additions & 0 deletions internal/dal/repository/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,25 @@ func (c *Clickhouse) MigrationsCount(ctx context.Context) (int, error) {
return count, nil
}

func (c *Clickhouse) ExistsMigration(ctx context.Context, version string) (bool, error) {
q := fmt.Sprintf(`SELECT 1 FROM %s WHERE version = ? AND is_deleted = 0`, c.TableNameWithSchema())
rows, err := c.conn.QueryContext(ctx, q, version)
if err != nil {
return false, err
}
var exists int
if rows.Next() {
if err := rows.Scan(&exists); err != nil {
return false, c.dbError(err, q)
}
}
if err := rows.Err(); err != nil {
return false, c.dbError(err, q)
}

return exists == 1, nil
}

func (c *Clickhouse) TableNameWithSchema() string {
return c.options.SchemaName + "." + c.options.TableName
}
Expand Down
20 changes: 20 additions & 0 deletions internal/dal/repository/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,29 @@ func (m *MySQL) MigrationsCount(ctx context.Context) (int, error) {
if err := rows.Err(); err != nil {
return 0, m.dbError(err, q)
}

return count, nil
}

func (m *MySQL) ExistsMigration(ctx context.Context, version string) (bool, error) {
q := fmt.Sprintf(`SELECT 1 FROM %s WHERE version = ?`, m.TableNameWithSchema())
rows, err := m.conn.QueryContext(ctx, q, version)
if err != nil {
return false, err
}
var exists int
if rows.Next() {
if err := rows.Scan(&exists); err != nil {
return false, m.dbError(err, q)
}
}
if err := rows.Err(); err != nil {
return false, m.dbError(err, q)
}

return exists == 1, nil
}

func (m *MySQL) TableNameWithSchema() string {
return m.options.SchemaName + "." + m.options.TableName
}
Expand Down
21 changes: 21 additions & 0 deletions internal/dal/repository/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ func (p *Postgres) DropMigrationHistoryTable(ctx context.Context) error {
if _, err := p.conn.ExecContext(ctx, q); err != nil {
return errors.Wrap(p.dbError(err, q), "drop migration history table")
}

return nil
}

Expand All @@ -191,9 +192,29 @@ func (p *Postgres) MigrationsCount(ctx context.Context) (int, error) {
if err := rows.Err(); err != nil {
return 0, p.dbError(err, q)
}

return count, nil
}

func (p *Postgres) ExistsMigration(ctx context.Context, version string) (bool, error) {
q := fmt.Sprintf(`SELECT EXISTS(SELECT 1 FROM %s WHERE version = $1)`, p.TableNameWithSchema())
rows, err := p.conn.QueryContext(ctx, q, version)
if err != nil {
return false, err
}
var exists bool
if rows.Next() {
if err := rows.Scan(&exists); err != nil {
return false, p.dbError(err, q)
}
}
if err := rows.Err(); err != nil {
return false, p.dbError(err, q)
}

return exists, nil
}

func (p *Postgres) TableNameWithSchema() string {
return p.options.SchemaName + "." + p.options.TableName
}
Expand Down
65 changes: 44 additions & 21 deletions internal/migrator/db_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
package migrator

import (
"net/url"
"strings"

"github.com/raoptimus/db-migrator.go/internal/action"
"github.com/raoptimus/db-migrator.go/internal/builder"
"github.com/raoptimus/db-migrator.go/internal/dal/connection"
Expand All @@ -21,12 +24,12 @@ import (

type (
DBService struct {
options *Options
fileNameBuilder FileNameBuilder
migrationServiceFunc func() (*service.Migration, error)
options *Options
fileNameBuilder FileNameBuilder

conn *connection.Connection
repo *repository.Repository
conn *connection.Connection
repo *repository.Repository
service *service.Migration
}
Options struct {
DSN string
Expand All @@ -42,12 +45,10 @@ type (

func New(options *Options) *DBService {
fb := builder.NewFileName(iohelp.StdFile, options.Directory)
dbs := &DBService{
return &DBService{
options: options,
fileNameBuilder: fb,
}
dbs.migrationServiceFunc = dbs.migrationService
return dbs
}

func (s *DBService) Create() *action.Create {
Expand All @@ -61,7 +62,7 @@ func (s *DBService) Create() *action.Create {
}

func (s *DBService) Upgrade() (*action.Upgrade, error) {
serv, err := s.migrationServiceFunc()
serv, err := s.MigrationService()
if err != nil {
return nil, err
}
Expand All @@ -75,56 +76,72 @@ func (s *DBService) Upgrade() (*action.Upgrade, error) {
}

func (s *DBService) Downgrade() (*action.Downgrade, error) {
serv, err := s.migrationServiceFunc()
serv, err := s.MigrationService()
if err != nil {
return nil, err
}

return action.NewDowngrade(serv, s.fileNameBuilder, s.options.Interactive), nil
}

func (s *DBService) To() (*action.To, error) {
serv, err := s.migrationServiceFunc()
serv, err := s.MigrationService()
if err != nil {
return nil, err
}

return action.NewTo(serv, s.fileNameBuilder, s.options.Interactive), nil
}

func (s *DBService) History() (*action.History, error) {
serv, err := s.migrationServiceFunc()
serv, err := s.MigrationService()
if err != nil {
return nil, err
}

return action.NewHistory(serv), nil
}

func (s *DBService) HistoryNew() (*action.HistoryNew, error) {
serv, err := s.migrationServiceFunc()
serv, err := s.MigrationService()
if err != nil {
return nil, err
}

return action.NewHistoryNew(serv), nil
}

func (s *DBService) Redo() (*action.Redo, error) {
serv, err := s.migrationServiceFunc()
serv, err := s.MigrationService()
if err != nil {
return nil, err
}

return action.NewRedo(serv, s.fileNameBuilder, s.options.Interactive), nil
}

func (s *DBService) migrationService() (*service.Migration, error) {
func (s *DBService) MigrationService() (*service.Migration, error) {
if s.service != nil {
return s.service, nil
}

var err error

if s.conn == nil {
conn, err := connection.New(s.options.DSN)
s.conn, err = connection.New(s.options.DSN)
if err != nil {
return nil, err
}
s.conn = conn
}

udsn, _, _ := strings.Cut(s.options.DSN, "@")
dsn, err := url.Parse(udsn + "@")
if err != nil {
return nil, err
}

if s.repo == nil {
repo, err := repository.New(
s.repo, err = repository.New(
s.conn,
&repository.Options{
TableName: s.options.TableName,
Expand All @@ -135,17 +152,23 @@ func (s *DBService) migrationService() (*service.Migration, error) {
if err != nil {
return nil, err
}
s.repo = repo
}

return service.NewMigration(
pass, _ := dsn.User.Password()

s.service = service.NewMigration(
&service.Options{
MaxSQLOutputLength: s.options.MaxSQLOutputLength,
Directory: s.options.Directory,
Compact: s.options.Compact,

Username: dsn.User.Username(),
Password: pass,
},
console.Std,
iohelp.StdFile,
s.repo,
), nil
)

return s.service, nil
}
2 changes: 2 additions & 0 deletions internal/service/dependencies.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ type File interface {

//go:generate mockery --name=Repository --outpkg=mockservice --output=./mockservice
type Repository interface {
// ExistsMigration returns true if version of migration exists
ExistsMigration(ctx context.Context, version string) (bool, error)
// Migrations returns applied migrations history.
Migrations(ctx context.Context, limit int) (entity.Migrations, error)
// HasMigrationHistoryTable returns true if migration history table exists.
Expand Down
31 changes: 22 additions & 9 deletions internal/service/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/pkg/errors"
_ "github.com/raoptimus/db-migrator.go/internal/console"
"github.com/raoptimus/db-migrator.go/internal/dal/entity"
"github.com/raoptimus/db-migrator.go/internal/validator"
"github.com/raoptimus/db-migrator.go/pkg/sqlio"
)

Expand Down Expand Up @@ -122,6 +123,9 @@ func (m *Migration) NewMigrations(ctx context.Context) (entity.Migrations, error

for _, file := range files {
baseFilename = filepath.Base(file)
if err := validator.ValidateFileName(baseFilename); err != nil {
return nil, errors.Wrap(err, baseFilename)
}
groups := regexpFileName.FindStringSubmatch(baseFilename)
if len(groups) != regexpFileNameGroupCount {
return nil, fmt.Errorf("file name %s is invalid", baseFilename)
Expand Down Expand Up @@ -155,8 +159,7 @@ func (m *Migration) ApplySQL(
ctx context.Context,
safely bool,
version,
upSQL,
downSQL string,
upSQL string,
) error {
if version == baseMigration {
return ErrMigrationVersionReserved
Expand All @@ -176,14 +179,14 @@ func (m *Migration) ApplySQL(
}
// todo: save downSQL
m.console.Successf("*** applied %s (time: %.3fs)\n", version, elapsedTime.Seconds())

return nil
}

func (m *Migration) RevertSQL(
ctx context.Context,
safely bool,
version,
upSQL,
downSQL string,
) error {
if version == baseMigration {
Expand All @@ -203,6 +206,7 @@ func (m *Migration) RevertSQL(
return err
}
m.console.Warnf("*** reverted %s (time: %.3fs)\n", version, elapsedTime.Seconds())

return nil
}

Expand Down Expand Up @@ -253,8 +257,8 @@ func (m *Migration) RevertFile(ctx context.Context, entity *entity.Migration, fi
if err := m.repo.RemoveMigration(ctx, entity.Version); err != nil {
return err
}
m.console.Warnf("*** reverted %s (time: %.3fs)\n",
entity.Version, elapsedTime.Seconds())
m.console.Warnf("*** reverted %s (time: %.3fs)\n", entity.Version, elapsedTime.Seconds())

return nil
}

Expand Down Expand Up @@ -292,15 +296,23 @@ func (m *Migration) EndCommand(start time.Time) {
}
}

func (m *Migration) Exists(ctx context.Context, version string) (bool, error) {
return m.repo.ExistsMigration(ctx, version)
}

func (m *Migration) apply(ctx context.Context, scanner *sqlio.Scanner, safely bool) error {
processScanFunc := func(ctx context.Context) error {
var q string
var sql string
for scanner.Scan() {
q = scanner.SQL()
if q == "" {
sql = scanner.SQL()
if sql == "" {
continue
}
if err := m.ExecQuery(ctx, q); err != nil {

sql = strings.ReplaceAll(sql, "{username}", m.options.Username)
sql = strings.ReplaceAll(sql, "{password}", m.options.Password)

if err := m.ExecQuery(ctx, sql); err != nil {
return err
}
}
Expand Down Expand Up @@ -330,5 +342,6 @@ func (m *Migration) scannerByFile(fileName string) (*sqlio.Scanner, error) {
if err != nil {
return nil, errors.Wrapf(err, "migration file %s does not read", fileName)
}

return sqlio.NewScanner(f), nil
}
Loading

0 comments on commit e9e823f

Please sign in to comment.