From 7352677b5a5c0ee97794c5a844e3d326459bb8c0 Mon Sep 17 00:00:00 2001 From: Akrem Abayed Date: Thu, 7 Dec 2023 15:46:47 +0100 Subject: [PATCH] improve logic for default testset --- ...testsets.json => chat_openai_testset.json} | 0 .../agenta_backend/services/db_manager.py | 42 ++++++++++++------- 2 files changed, 26 insertions(+), 16 deletions(-) rename agenta-backend/agenta_backend/resources/default_testsets/{single_prompt_testsets.json => chat_openai_testset.json} (100%) diff --git a/agenta-backend/agenta_backend/resources/default_testsets/single_prompt_testsets.json b/agenta-backend/agenta_backend/resources/default_testsets/chat_openai_testset.json similarity index 100% rename from agenta-backend/agenta_backend/resources/default_testsets/single_prompt_testsets.json rename to agenta-backend/agenta_backend/resources/default_testsets/chat_openai_testset.json diff --git a/agenta-backend/agenta_backend/services/db_manager.py b/agenta-backend/agenta_backend/services/db_manager.py index ee0bedc982..538132e87a 100644 --- a/agenta-backend/agenta_backend/services/db_manager.py +++ b/agenta-backend/agenta_backend/services/db_manager.py @@ -65,23 +65,33 @@ async def add_testset_to_app_variant( **kwargs (dict): Additional keyword arguments """ - app_db = await get_app_instance_by_id(app_id) - org_db = await get_organization_object(org_id) - user_db = await get_user(user_uid=kwargs["uid"]) - - if template_name == "chat_openai": - json_path = ( - f"{PARENT_DIRECTORY}/resources/default_testsets/single_prompt_testsets.json" + try: + app_db = await get_app_instance_by_id(app_id) + org_db = await get_organization_object(org_id) + user_db = await get_user(user_uid=kwargs["uid"]) + + json_path = os.path.join( + PARENT_DIRECTORY, + "resources", + "default_testsets", + f"{template_name}_testset.json", ) - csvdata = get_json(json_path) - testset = { - "name": f"{app_name}_testset", - "app_name": app_name, - "created_at": datetime.now().isoformat(), - "csvdata": csvdata, - } - testset = TestSetDB(**testset, app=app_db, user=user_db, organization=org_db) - await engine.save(testset) + + if os.path.exists(json_path): + csvdata = get_json(json_path) + testset = { + "name": f"{app_name}_testset", + "app_name": app_name, + "created_at": datetime.now().isoformat(), + "csvdata": csvdata, + } + testset_db = TestSetDB( + **testset, app=app_db, user=user_db, organization=org_db + ) + await engine.save(testset_db) + + except Exception as e: + print(f"An error occurred in adding the default testset: {e}") async def get_image(app_variant: AppVariant, **kwargs: dict) -> ImageExtended: