-
Notifications
You must be signed in to change notification settings - Fork 233
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2090 from Agenta-AI/refactor/project-structure
[feature] Projects Structure - Checkpoint 2
- Loading branch information
Showing
46 changed files
with
1,838 additions
and
978 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
73 changes: 73 additions & 0 deletions
73
agenta-backend/agenta_backend/migrations/postgres/data_migrations/applications.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import os | ||
import uuid | ||
import traceback | ||
from typing import Optional | ||
|
||
|
||
import click | ||
from sqlalchemy.future import select | ||
from sqlalchemy import create_engine, delete | ||
from sqlalchemy.orm import sessionmaker, Session | ||
|
||
from agenta_backend.models.deprecated_models import ( | ||
DeprecatedEvaluatorConfigDB, | ||
DeprecatedAppDB, | ||
) | ||
|
||
|
||
BATCH_SIZE = 1000 | ||
|
||
|
||
def get_app_db(session: Session, app_id: str) -> Optional[DeprecatedAppDB]: | ||
query = session.execute(select(DeprecatedAppDB).filter_by(id=uuid.UUID(app_id))) | ||
return query.scalars().first() | ||
|
||
|
||
def update_evaluators_with_app_name(): | ||
engine = create_engine(os.getenv("POSTGRES_URI")) | ||
sync_session = sessionmaker(engine, expire_on_commit=False) | ||
|
||
with sync_session() as session: | ||
try: | ||
offset = 0 | ||
while True: | ||
records = ( | ||
session.execute( | ||
select(DeprecatedEvaluatorConfigDB) | ||
.filter(DeprecatedEvaluatorConfigDB.app_id.isnot(None)) | ||
.offset(offset) | ||
.limit(BATCH_SIZE) | ||
) | ||
.scalars() | ||
.all() | ||
) | ||
if not records: | ||
break | ||
|
||
# Update records with app_name as prefix | ||
for record in records: | ||
evaluator_config_app = get_app_db( | ||
session=session, app_id=str(record.app_id) | ||
) | ||
if record.app_id is not None and evaluator_config_app is not None: | ||
record.name = f"{record.name} ({evaluator_config_app.app_name})" | ||
|
||
session.commit() | ||
offset += BATCH_SIZE | ||
|
||
# Delete deprecated evaluator configs with app_id as None | ||
session.execute( | ||
delete(DeprecatedEvaluatorConfigDB).where( | ||
DeprecatedEvaluatorConfigDB.app_id.is_(None) | ||
) | ||
) | ||
session.commit() | ||
except Exception as e: | ||
session.rollback() | ||
click.echo( | ||
click.style( | ||
f"ERROR updating evaluator config names: {traceback.format_exc()}", | ||
fg="red", | ||
) | ||
) | ||
raise e |
188 changes: 188 additions & 0 deletions
188
agenta-backend/agenta_backend/migrations/postgres/data_migrations/projects.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
import os | ||
import traceback | ||
from typing import Sequence | ||
|
||
|
||
import click | ||
from sqlalchemy.future import select | ||
from sqlalchemy import create_engine | ||
from sqlalchemy.orm import sessionmaker, Session | ||
|
||
from agenta_backend.models.db_models import ( | ||
ProjectDB, | ||
AppDB, | ||
AppVariantDB, | ||
AppVariantRevisionsDB, | ||
VariantBaseDB, | ||
DeploymentDB, | ||
ImageDB, | ||
AppEnvironmentDB, | ||
AppEnvironmentRevisionDB, | ||
EvaluationScenarioDB, | ||
EvaluationDB, | ||
EvaluatorConfigDB, | ||
HumanEvaluationDB, | ||
HumanEvaluationScenarioDB, | ||
TestSetDB, | ||
) | ||
|
||
|
||
BATCH_SIZE = 1000 | ||
MODELS = [ | ||
AppDB, | ||
AppVariantDB, | ||
AppVariantRevisionsDB, | ||
VariantBaseDB, | ||
DeploymentDB, | ||
ImageDB, | ||
AppEnvironmentDB, | ||
AppEnvironmentRevisionDB, | ||
EvaluationScenarioDB, | ||
EvaluationDB, | ||
EvaluatorConfigDB, | ||
HumanEvaluationDB, | ||
HumanEvaluationScenarioDB, | ||
TestSetDB, | ||
] | ||
|
||
|
||
def get_default_projects(session): | ||
query = session.execute(select(ProjectDB).filter_by(is_default=True)) | ||
return query.scalars().all() | ||
|
||
|
||
def check_for_multiple_default_projects(session: Session) -> Sequence[ProjectDB]: | ||
default_projects = get_default_projects(session) | ||
if len(default_projects) > 1: | ||
raise ValueError( | ||
"Multiple default projects found. Please ensure only one exists." | ||
) | ||
return default_projects | ||
|
||
|
||
def create_default_project(): | ||
PROJECT_NAME = "Default Project" | ||
engine = create_engine(os.getenv("POSTGRES_URI")) | ||
sync_session = sessionmaker(engine, expire_on_commit=False) | ||
|
||
with sync_session() as session: | ||
try: | ||
default_projects = check_for_multiple_default_projects(session) | ||
if len(default_projects) == 0: | ||
new_project = ProjectDB(project_name=PROJECT_NAME, is_default=True) | ||
session.add(new_project) | ||
session.commit() | ||
|
||
except Exception as e: | ||
session.rollback() | ||
click.echo( | ||
click.style( | ||
f"ERROR creating default project: {traceback.format_exc()}", | ||
fg="red", | ||
) | ||
) | ||
raise e | ||
|
||
|
||
def remove_default_project(): | ||
engine = create_engine(os.getenv("POSTGRES_URI")) | ||
sync_session = sessionmaker(engine, expire_on_commit=False) | ||
|
||
with sync_session() as session: | ||
try: | ||
default_projects = check_for_multiple_default_projects(session) | ||
if len(default_projects) == 0: | ||
click.echo( | ||
click.style("No default project found to remove.", fg="yellow") | ||
) | ||
return | ||
|
||
session.delete(default_projects[0]) | ||
session.commit() | ||
click.echo(click.style("Default project removed successfully.", fg="green")) | ||
|
||
except Exception as e: | ||
session.rollback() | ||
click.echo(click.style(f"ERROR: {traceback.format_exc()}", fg="red")) | ||
raise e | ||
|
||
|
||
def add_project_id_to_db_entities(): | ||
engine = create_engine(os.getenv("POSTGRES_URI")) | ||
sync_session = sessionmaker(engine, expire_on_commit=False) | ||
|
||
with sync_session() as session: | ||
try: | ||
default_project = check_for_multiple_default_projects(session)[0] | ||
for model in MODELS: | ||
offset = 0 | ||
while True: | ||
records = ( | ||
session.execute( | ||
select(model) | ||
.where(model.project_id == None) | ||
.offset(offset) | ||
.limit(BATCH_SIZE) | ||
) | ||
.scalars() | ||
.all() | ||
) | ||
if not records: | ||
break | ||
|
||
# Update records with default project_id | ||
for record in records: | ||
record.project_id = default_project.id | ||
|
||
session.commit() | ||
offset += BATCH_SIZE | ||
|
||
except Exception as e: | ||
session.rollback() | ||
click.echo( | ||
click.style( | ||
f"ERROR adding project_id to db entities: {traceback.format_exc()}", | ||
fg="red", | ||
) | ||
) | ||
raise e | ||
|
||
|
||
def remove_project_id_from_db_entities(): | ||
engine = create_engine(os.getenv("POSTGRES_URI")) | ||
sync_session = sessionmaker(engine, expire_on_commit=False) | ||
|
||
with sync_session() as session: | ||
try: | ||
for model in MODELS: | ||
offset = 0 | ||
while True: | ||
records = ( | ||
session.execute( | ||
select(model) | ||
.where(model.project_id != None) | ||
.offset(offset) | ||
.limit(BATCH_SIZE) | ||
) | ||
.scalars() | ||
.all() | ||
) | ||
if not records: | ||
break | ||
|
||
# Update records project_id column with None | ||
for record in records: | ||
record.project_id = None | ||
|
||
session.commit() | ||
offset += BATCH_SIZE | ||
|
||
except Exception as e: | ||
session.rollback() | ||
click.echo( | ||
click.style( | ||
f"ERROR removing project_id to db entities: {traceback.format_exc()}", | ||
fg="red", | ||
) | ||
) | ||
raise e |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
32 changes: 32 additions & 0 deletions
32
...ckend/migrations/postgres/versions/22d29365f5fc_update_evaluators_names_with_app_name_.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
"""Update evaluators names with app name as prefix | ||
Revision ID: 22d29365f5fc | ||
Revises: 6cfe239894fb | ||
Create Date: 2024-09-16 11:38:33.886908 | ||
""" | ||
|
||
from typing import Sequence, Union | ||
|
||
from agenta_backend.migrations.postgres.data_migrations.applications import ( | ||
update_evaluators_with_app_name, | ||
) | ||
|
||
|
||
# revision identifiers, used by Alembic. | ||
revision: str = "22d29365f5fc" | ||
down_revision: Union[str, None] = "6cfe239894fb" | ||
branch_labels: Union[str, Sequence[str], None] = None | ||
depends_on: Union[str, Sequence[str], None] = None | ||
|
||
|
||
def upgrade() -> None: | ||
# ### custom command ### | ||
update_evaluators_with_app_name() | ||
# ### end custom command ### | ||
|
||
|
||
def downgrade() -> None: | ||
# ### custom command ### | ||
pass | ||
# ### end custom command ### |
Oops, something went wrong.