diff --git a/agenta-backend/agenta_backend/migrations/postgres/utils.py b/agenta-backend/agenta_backend/migrations/postgres/utils.py index a327e32f55..4395845347 100644 --- a/agenta-backend/agenta_backend/migrations/postgres/utils.py +++ b/agenta-backend/agenta_backend/migrations/postgres/utils.py @@ -5,16 +5,13 @@ import click import asyncpg - -from sqlalchemy import inspect, text, Engine -from sqlalchemy.exc import ProgrammingError -from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine - from alembic import command +from sqlalchemy import Engine from alembic.config import Config +from sqlalchemy import inspect, text from alembic.script import ScriptDirectory - -from agenta_backend.utils.common import isCloudEE, isCloudDev +from sqlalchemy.exc import ProgrammingError +from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine # Initializer logger @@ -56,15 +53,15 @@ def is_initial_setup(engine) -> bool: return not all_tables_exist -async def get_applied_migrations(engine: AsyncEngine): +async def get_current_migration_head_from_db(engine: AsyncEngine): """ - Checks the alembic_version table to get all the migrations that has been applied. + Checks the alembic_version table to get the current migration head that has been applied. Args: engine (Engine): The engine that connects to an sqlalchemy pool Returns: - a list of strings + the current migration head (where 'head' is the revision stored in the migration script) """ async with engine.connect() as connection: @@ -76,32 +73,37 @@ async def get_applied_migrations(engine: AsyncEngine): # to make Alembic start tracking the migration changes. # -------------------------------------------------------------------------------------- # This effect (the exception raising) happens for both users (first-time and returning) - return ["alembic_version"] + return "alembic_version" - applied_migrations = [row[0] for row in result.fetchall()] - return applied_migrations + migration_heads = [row[0] for row in result.fetchall()] + assert ( + len(migration_heads) == 1 + ), "There can only be one migration head stored in the database." + return migration_heads[0] -async def get_pending_migrations(): +async def get_pending_migration_head(): """ - Gets the migrations that have not been applied. + Gets the migration head that have not been applied. Returns: - the number of pending migrations + the pending migration head """ engine = create_async_engine(url=os.environ["POSTGRES_URI"]) try: - applied_migrations = await get_applied_migrations(engine=engine) - migration_files = [script.revision for script in script.walk_revisions()] - pending_migrations = [m for m in migration_files if m not in applied_migrations] - - if "alembic_version" in applied_migrations: - pending_migrations.append("alembic_version") + current_migration_script_head = script.get_current_head() + migration_head_from_db = await get_current_migration_head_from_db(engine=engine) + + pending_migration_head = [] + if current_migration_script_head != migration_head_from_db: + pending_migration_head.append(current_migration_script_head) + if "alembic_version" == migration_head_from_db: + pending_migration_head.append("alembic_version") finally: await engine.dispose() - return pending_migrations + return pending_migration_head def run_alembic_migration(): @@ -110,9 +112,9 @@ def run_alembic_migration(): """ try: - pending_migrations = asyncio.run(get_pending_migrations()) + pending_migration_head = asyncio.run(get_pending_migration_head()) APPLY_AUTO_MIGRATIONS = os.environ.get("AGENTA_AUTO_MIGRATIONS") - FIRST_TIME_USER = True if "alembic_version" in pending_migrations else False + FIRST_TIME_USER = True if "alembic_version" in pending_migration_head else False if FIRST_TIME_USER or APPLY_AUTO_MIGRATIONS == "true": command.upgrade(alembic_cfg, "head") @@ -134,7 +136,7 @@ def run_alembic_migration(): except Exception as e: click.echo( click.style( - f"\nAn ERROR occured while applying migration: {traceback.format_exc()}\nThe container will now exit.", + f"\nAn ERROR occurred while applying migration: {traceback.format_exc()}\nThe container will now exit.", fg="red", ), color=True, @@ -147,11 +149,11 @@ async def check_for_new_migrations(): Checks for new migrations and notify the user. """ - pending_migrations = await get_pending_migrations() - if len(pending_migrations) >= 1: + pending_migration_head = await get_pending_migration_head() + if len(pending_migration_head) >= 1 and isinstance(pending_migration_head[0], str): click.echo( click.style( - f"\nWe have detected that there are pending database migrations {pending_migrations} that need to be applied to keep the application up to date. To ensure the application functions correctly with the latest updates, please follow the guide here => https://docs.agenta.ai/self-host/migration/applying-schema-migration\n", + f"\nWe have detected that there are pending database migrations {pending_migration_head} that need to be applied to keep the application up to date. To ensure the application functions correctly with the latest updates, please follow the guide here => https://docs.agenta.ai/self-host/migration/applying-schema-migration\n", fg="yellow", ), color=True,