diff --git a/base/commands/migration/migration_start.go b/base/commands/migration/migration_start.go index b714641e..58a6f216 100644 --- a/base/commands/migration/migration_start.go +++ b/base/commands/migration/migration_start.go @@ -65,10 +65,12 @@ In order to cancel the migration, use the 'cancel' command. return err } defer func() { - maybePrintWarnings(ctx, ec, sts.ci, mID) - finalizeErr := finalizeMigration(ctx, ec, sts.ci, mID, ec.Props().GetString(flagOutputDir)) - if err == nil { - err = finalizeErr + if sts.ci != nil { + maybePrintWarnings(ctx, ec, sts.ci, mID) + finalizeErr := finalizeMigration(ctx, ec, sts.ci, mID, ec.Props().GetString(flagOutputDir)) + if err == nil { + err = finalizeErr + } } }() sp := stage.NewFixedProvider(sts.Build(ctx, ec)...) diff --git a/base/commands/migration/migration_status.go b/base/commands/migration/migration_status.go index e1e092b9..53f77b61 100644 --- a/base/commands/migration/migration_status.go +++ b/base/commands/migration/migration_status.go @@ -27,19 +27,20 @@ func (s StatusCmd) Exec(ctx context.Context, ec plug.ExecContext) (err error) { ec.PrintlnUnnecessary("") ec.PrintlnUnnecessary(banner) sts := NewStatusStages() - sp := stage.NewFixedProvider(sts.Build(ctx, ec)...) - mID, err := stage.Execute(ctx, ec, any(nil), sp) + mID, err := stage.Execute(ctx, ec, "", stage.NewFixedProvider(sts.Build(ctx, ec)...)) if err != nil { return err } defer func() { - maybePrintWarnings(ctx, ec, sts.ci, mID.(string)) - finalizeErr := finalizeMigration(ctx, ec, sts.ci, mID.(string), ec.Props().GetString(flagOutputDir)) - if err == nil { - err = finalizeErr + if sts.ci != nil { + maybePrintWarnings(ctx, ec, sts.ci, mID) + finalizeErr := finalizeMigration(ctx, ec, sts.ci, mID, ec.Props().GetString(flagOutputDir)) + if err == nil { + err = finalizeErr + } } }() - mStages, err := createMigrationStages(ctx, ec, sts.ci, mID.(string)) + mStages, err := createMigrationStages(ctx, ec, sts.ci, mID) if err != nil { return err } diff --git a/base/commands/migration/status_stages.go b/base/commands/migration/status_stages.go index 252608f8..cc0bc667 100644 --- a/base/commands/migration/status_stages.go +++ b/base/commands/migration/status_stages.go @@ -18,8 +18,8 @@ func NewStatusStages() *StatusStages { return &StatusStages{} } -func (st *StatusStages) Build(ctx context.Context, ec plug.ExecContext) []stage.Stage[any] { - return []stage.Stage[any]{ +func (st *StatusStages) Build(ctx context.Context, ec plug.ExecContext) []stage.Stage[string] { + return []stage.Stage[string]{ { ProgressMsg: "Connecting to the migration cluster", SuccessMsg: "Connected to the migration cluster", @@ -35,22 +35,22 @@ func (st *StatusStages) Build(ctx context.Context, ec plug.ExecContext) []stage. } } -func (st *StatusStages) connectStage(ec plug.ExecContext) func(context.Context, stage.Statuser[any]) (any, error) { - return func(ctx context.Context, status stage.Statuser[any]) (any, error) { +func (st *StatusStages) connectStage(ec plug.ExecContext) func(context.Context, stage.Statuser[string]) (string, error) { + return func(ctx context.Context, status stage.Statuser[string]) (string, error) { var err error st.ci, err = ec.ClientInternal(ctx) if err != nil { - return nil, err + return "", err } - return nil, nil + return "", nil } } -func (st *StatusStages) findMigrationInProgress(ec plug.ExecContext) func(context.Context, stage.Statuser[any]) (any, error) { - return func(ctx context.Context, status stage.Statuser[any]) (any, error) { +func (st *StatusStages) findMigrationInProgress(ec plug.ExecContext) func(context.Context, stage.Statuser[string]) (string, error) { + return func(ctx context.Context, status stage.Statuser[string]) (string, error) { m, err := findMigrationInProgress(ctx, st.ci) if err != nil { - return nil, err + return "", err } return m.MigrationID, err }