diff --git a/error.go b/error.go new file mode 100644 index 0000000..a0bec8b --- /dev/null +++ b/error.go @@ -0,0 +1,8 @@ +package dbmigrator + +import "github.com/pkg/errors" + +var ( + ErrMigrationAlreadyExists = errors.New("migration already exists") + ErrAppliedMigrationNotFound = errors.New("applied migration not found") +) diff --git a/internal/dal/repository/adapter.go b/internal/dal/repository/adapter.go index 01f1637..9243bc8 100644 --- a/internal/dal/repository/adapter.go +++ b/internal/dal/repository/adapter.go @@ -20,6 +20,7 @@ type adapter interface { DropMigrationHistoryTable(ctx context.Context) error CreateMigrationHistoryTable(ctx context.Context) error MigrationsCount(ctx context.Context) (int, error) + ExistsMigration(ctx context.Context, version string) (bool, error) TableNameWithSchema() string ForceSafely() bool } @@ -64,6 +65,10 @@ func (r *Repository) MigrationsCount(ctx context.Context) (int, error) { return r.adapter.MigrationsCount(ctx) } +func (r *Repository) ExistsMigration(ctx context.Context, version string) (bool, error) { + return r.adapter.ExistsMigration(ctx, version) +} + func (r *Repository) TableNameWithSchema() string { return r.adapter.TableNameWithSchema() } diff --git a/internal/dal/repository/clickhouse.go b/internal/dal/repository/clickhouse.go index 6a3f249..ce03cdc 100644 --- a/internal/dal/repository/clickhouse.go +++ b/internal/dal/repository/clickhouse.go @@ -209,6 +209,25 @@ func (c *Clickhouse) MigrationsCount(ctx context.Context) (int, error) { return count, nil } +func (c *Clickhouse) ExistsMigration(ctx context.Context, version string) (bool, error) { + q := fmt.Sprintf(`SELECT 1 FROM %s WHERE version = ? AND is_deleted = 0`, c.TableNameWithSchema()) + rows, err := c.conn.QueryContext(ctx, q, version) + if err != nil { + return false, err + } + var exists int + if rows.Next() { + if err := rows.Scan(&exists); err != nil { + return false, c.dbError(err, q) + } + } + if err := rows.Err(); err != nil { + return false, c.dbError(err, q) + } + + return exists == 1, nil +} + func (c *Clickhouse) TableNameWithSchema() string { return c.options.SchemaName + "." + c.options.TableName } diff --git a/internal/dal/repository/mysql.go b/internal/dal/repository/mysql.go index d60daf4..91510f7 100644 --- a/internal/dal/repository/mysql.go +++ b/internal/dal/repository/mysql.go @@ -179,9 +179,29 @@ func (m *MySQL) MigrationsCount(ctx context.Context) (int, error) { if err := rows.Err(); err != nil { return 0, m.dbError(err, q) } + return count, nil } +func (m *MySQL) ExistsMigration(ctx context.Context, version string) (bool, error) { + q := fmt.Sprintf(`SELECT 1 FROM %s WHERE version = ?`, m.TableNameWithSchema()) + rows, err := m.conn.QueryContext(ctx, q, version) + if err != nil { + return false, err + } + var exists int + if rows.Next() { + if err := rows.Scan(&exists); err != nil { + return false, m.dbError(err, q) + } + } + if err := rows.Err(); err != nil { + return false, m.dbError(err, q) + } + + return exists == 1, nil +} + func (m *MySQL) TableNameWithSchema() string { return m.options.SchemaName + "." + m.options.TableName } diff --git a/internal/dal/repository/postgres.go b/internal/dal/repository/postgres.go index b74139a..7e7577a 100644 --- a/internal/dal/repository/postgres.go +++ b/internal/dal/repository/postgres.go @@ -172,6 +172,7 @@ func (p *Postgres) DropMigrationHistoryTable(ctx context.Context) error { if _, err := p.conn.ExecContext(ctx, q); err != nil { return errors.Wrap(p.dbError(err, q), "drop migration history table") } + return nil } @@ -191,9 +192,29 @@ func (p *Postgres) MigrationsCount(ctx context.Context) (int, error) { if err := rows.Err(); err != nil { return 0, p.dbError(err, q) } + return count, nil } +func (p *Postgres) ExistsMigration(ctx context.Context, version string) (bool, error) { + q := fmt.Sprintf(`SELECT EXISTS(SELECT 1 FROM %s WHERE version = $1)`, p.TableNameWithSchema()) + rows, err := p.conn.QueryContext(ctx, q, version) + if err != nil { + return false, err + } + var exists bool + if rows.Next() { + if err := rows.Scan(&exists); err != nil { + return false, p.dbError(err, q) + } + } + if err := rows.Err(); err != nil { + return false, p.dbError(err, q) + } + + return exists, nil +} + func (p *Postgres) TableNameWithSchema() string { return p.options.SchemaName + "." + p.options.TableName } diff --git a/internal/migrator/db_service.go b/internal/migrator/db_service.go index 1aaac33..14175fb 100644 --- a/internal/migrator/db_service.go +++ b/internal/migrator/db_service.go @@ -9,6 +9,8 @@ package migrator import ( + "net/url" + "github.com/raoptimus/db-migrator.go/internal/action" "github.com/raoptimus/db-migrator.go/internal/builder" "github.com/raoptimus/db-migrator.go/internal/dal/connection" @@ -21,12 +23,12 @@ import ( type ( DBService struct { - options *Options - fileNameBuilder FileNameBuilder - migrationServiceFunc func() (*service.Migration, error) + options *Options + fileNameBuilder FileNameBuilder - conn *connection.Connection - repo *repository.Repository + conn *connection.Connection + repo *repository.Repository + service *service.Migration } Options struct { DSN string @@ -42,12 +44,10 @@ type ( func New(options *Options) *DBService { fb := builder.NewFileName(iohelp.StdFile, options.Directory) - dbs := &DBService{ + return &DBService{ options: options, fileNameBuilder: fb, } - dbs.migrationServiceFunc = dbs.migrationService - return dbs } func (s *DBService) Create() *action.Create { @@ -61,7 +61,7 @@ func (s *DBService) Create() *action.Create { } func (s *DBService) Upgrade() (*action.Upgrade, error) { - serv, err := s.migrationServiceFunc() + serv, err := s.MigrationService() if err != nil { return nil, err } @@ -75,56 +75,71 @@ func (s *DBService) Upgrade() (*action.Upgrade, error) { } func (s *DBService) Downgrade() (*action.Downgrade, error) { - serv, err := s.migrationServiceFunc() + serv, err := s.MigrationService() if err != nil { return nil, err } + return action.NewDowngrade(serv, s.fileNameBuilder, s.options.Interactive), nil } func (s *DBService) To() (*action.To, error) { - serv, err := s.migrationServiceFunc() + serv, err := s.MigrationService() if err != nil { return nil, err } + return action.NewTo(serv, s.fileNameBuilder, s.options.Interactive), nil } func (s *DBService) History() (*action.History, error) { - serv, err := s.migrationServiceFunc() + serv, err := s.MigrationService() if err != nil { return nil, err } + return action.NewHistory(serv), nil } func (s *DBService) HistoryNew() (*action.HistoryNew, error) { - serv, err := s.migrationServiceFunc() + serv, err := s.MigrationService() if err != nil { return nil, err } + return action.NewHistoryNew(serv), nil } func (s *DBService) Redo() (*action.Redo, error) { - serv, err := s.migrationServiceFunc() + serv, err := s.MigrationService() if err != nil { return nil, err } + return action.NewRedo(serv, s.fileNameBuilder, s.options.Interactive), nil } -func (s *DBService) migrationService() (*service.Migration, error) { +func (s *DBService) MigrationService() (*service.Migration, error) { + if s.service != nil { + return s.service, nil + } + + var err error + if s.conn == nil { - conn, err := connection.New(s.options.DSN) + s.conn, err = connection.New(s.options.DSN) if err != nil { return nil, err } - s.conn = conn + } + + dsn, err := url.Parse(s.options.DSN) + if err != nil { + return nil, err } if s.repo == nil { - repo, err := repository.New( + s.repo, err = repository.New( s.conn, &repository.Options{ TableName: s.options.TableName, @@ -135,17 +150,23 @@ func (s *DBService) migrationService() (*service.Migration, error) { if err != nil { return nil, err } - s.repo = repo } - return service.NewMigration( + pass, _ := dsn.User.Password() + + s.service = service.NewMigration( &service.Options{ MaxSQLOutputLength: s.options.MaxSQLOutputLength, Directory: s.options.Directory, Compact: s.options.Compact, + + Username: dsn.User.Username(), + Password: pass, }, console.Std, iohelp.StdFile, s.repo, - ), nil + ) + + return s.service, nil } diff --git a/internal/service/dependencies.go b/internal/service/dependencies.go index 54b1008..ffc938e 100644 --- a/internal/service/dependencies.go +++ b/internal/service/dependencies.go @@ -43,6 +43,8 @@ type File interface { //go:generate mockery --name=Repository --outpkg=mockservice --output=./mockservice type Repository interface { + // ExistsMigration return true if version of migration exists + ExistsMigration(ctx context.Context, version string) (bool, error) // Migrations returns applied migrations history. Migrations(ctx context.Context, limit int) (entity.Migrations, error) // HasMigrationHistoryTable returns true if migration history table exists. diff --git a/internal/service/migration.go b/internal/service/migration.go index 7ac29ab..fc95f19 100644 --- a/internal/service/migration.go +++ b/internal/service/migration.go @@ -155,8 +155,7 @@ func (m *Migration) ApplySQL( ctx context.Context, safely bool, version, - upSQL, - downSQL string, + upSQL string, ) error { if version == baseMigration { return ErrMigrationVersionReserved @@ -176,6 +175,7 @@ func (m *Migration) ApplySQL( } // todo: save downSQL m.console.Successf("*** applied %s (time: %.3fs)\n", version, elapsedTime.Seconds()) + return nil } @@ -183,7 +183,6 @@ func (m *Migration) RevertSQL( ctx context.Context, safely bool, version, - upSQL, downSQL string, ) error { if version == baseMigration { @@ -203,6 +202,7 @@ func (m *Migration) RevertSQL( return err } m.console.Warnf("*** reverted %s (time: %.3fs)\n", version, elapsedTime.Seconds()) + return nil } @@ -253,8 +253,8 @@ func (m *Migration) RevertFile(ctx context.Context, entity *entity.Migration, fi if err := m.repo.RemoveMigration(ctx, entity.Version); err != nil { return err } - m.console.Warnf("*** reverted %s (time: %.3fs)\n", - entity.Version, elapsedTime.Seconds()) + m.console.Warnf("*** reverted %s (time: %.3fs)\n", entity.Version, elapsedTime.Seconds()) + return nil } @@ -292,15 +292,23 @@ func (m *Migration) EndCommand(start time.Time) { } } +func (m *Migration) Exists(ctx context.Context, version string) (bool, error) { + return m.repo.ExistsMigration(ctx, version) +} + func (m *Migration) apply(ctx context.Context, scanner *sqlio.Scanner, safely bool) error { processScanFunc := func(ctx context.Context) error { - var q string + var sql string for scanner.Scan() { - q = scanner.SQL() - if q == "" { + sql = scanner.SQL() + if sql == "" { continue } - if err := m.ExecQuery(ctx, q); err != nil { + + sql = strings.ReplaceAll(sql, "{username}", m.options.Username) + sql = strings.ReplaceAll(sql, "{password}", m.options.Password) + + if err := m.ExecQuery(ctx, sql); err != nil { return err } } @@ -330,5 +338,6 @@ func (m *Migration) scannerByFile(fileName string) (*sqlio.Scanner, error) { if err != nil { return nil, errors.Wrapf(err, "migration file %s does not read", fileName) } + return sqlio.NewScanner(f), nil } diff --git a/internal/service/options.go b/internal/service/options.go index d0c2e23..390bb20 100644 --- a/internal/service/options.go +++ b/internal/service/options.go @@ -12,4 +12,7 @@ type Options struct { MaxSQLOutputLength int Directory string Compact bool + + Username string + Password string } diff --git a/service.go b/service.go new file mode 100644 index 0000000..4b8570c --- /dev/null +++ b/service.go @@ -0,0 +1,77 @@ +package dbmigrator + +import ( + "context" + + "github.com/raoptimus/db-migrator.go/internal/migrator" +) + +type ( + Options struct { + DSN string + // table name to history of migrations + TableName string + // cluster name to clickhouse + ClusterName string + // is replicated used to clickhouse? + Replicated bool + } + DBService struct { + dbs *migrator.DBService + opts *Options + } +) + +func NewDBService(opts *Options) *DBService { + return &DBService{ + dbs: migrator.New(&migrator.Options{ + DSN: opts.DSN, + Directory: "", + TableName: opts.TableName, + ClusterName: opts.ClusterName, + Replicated: opts.Replicated, + Compact: true, + Interactive: true, + MaxSQLOutputLength: 0, + }), + opts: opts, + } +} + +// Upgrade apply changes to db. apply specific version of migration. +func (d *DBService) Upgrade(ctx context.Context, version, sql string, safety bool) error { + ms, err := d.dbs.MigrationService() + if err != nil { + return err + } + + exists, err := ms.Exists(ctx, version) + if err != nil { + return err + } + + if exists { + return ErrMigrationAlreadyExists + } + + return ms.ApplySQL(ctx, safety, version, sql) +} + +// Downgrade revert changes to db. revert specific version of migration. +func (d *DBService) Downgrade(ctx context.Context, version, sql string, safety bool) error { + ms, err := d.dbs.MigrationService() + if err != nil { + return err + } + + exists, err := ms.Exists(ctx, version) + if err != nil { + return err + } + + if !exists { + return ErrAppliedMigrationNotFound + } + + return ms.RevertSQL(ctx, safety, version, sql) +}