diff --git a/agenta-backend/agenta_backend/routers/app_variant.py b/agenta-backend/agenta_backend/routers/app_variant.py index 2c30a2ddd9..8db9184f8a 100644 --- a/agenta-backend/agenta_backend/routers/app_variant.py +++ b/agenta-backend/agenta_backend/routers/app_variant.py @@ -61,6 +61,48 @@ async def list_app_variants( raise HTTPException(status_code=500, detail=str(e)) +@router.get("/get_variant_by_name/", response_model=AppVariant) +async def get_variant_by_name( + app_name: str, + variant_name: str, + stoken_session: SessionContainer = Depends(verify_session()), +): + """Fetches a specific app variant based on the given app_name and variant_name. + + Arguments: + app_name (str): The name of the app to query. + variant_name (str): The name of the variant to query. + + Raises: + HTTPException: Raises 404 if no matching variant is found, + 400 for ValueError, or 500 for any other exceptions. + + Returns: + AppVariant: The fetched app variant. + """ + + try: + # Retrieve the user and organization ID based on the session token + kwargs = await get_user_and_org_id(stoken_session) + + # Fetch the app variant using the provided app_name and variant_name + app_variant = await db_manager.get_app_variant_by_app_name_and_variant_name( + app_name=app_name, variant_name=variant_name, **kwargs + ) + # Check if the fetched app variant is None and raise 404 if it is + if app_variant is None: + raise HTTPException(status_code=500, detail="App Variant not found") + return app_variant + except ValueError as e: + # Handle ValueErrors and return 400 status code + raise HTTPException(status_code=400, detail=str(e)) + except HTTPException as e: + raise e + except Exception as e: + # Handle all other exceptions and return 500 status code + raise HTTPException(status_code=500, detail=str(e)) + + @router.get("/list_apps/", response_model=List[App]) async def list_apps( stoken_session: SessionContainer = Depends(verify_session()), diff --git a/agenta-backend/agenta_backend/services/db_manager.py b/agenta-backend/agenta_backend/services/db_manager.py index 93e28bac3b..3e456840b0 100644 --- a/agenta-backend/agenta_backend/services/db_manager.py +++ b/agenta-backend/agenta_backend/services/db_manager.py @@ -237,6 +237,54 @@ async def list_app_variants( return app_variants +async def get_app_variant_by_app_name_and_variant_name( + app_name: str, variant_name: str, show_soft_deleted: bool = False, **kwargs: dict +) -> AppVariant: + """Fetches an app variant based on app_name and variant_name. + + Args: + app_name (str): Name of the app. + variant_name (str): Name of the variant. + show_soft_deleted: if true, returns soft deleted variants as well + **kwargs (dict): Additional keyword arguments. + + Returns: + AppVariant: The fetched app variant. + """ + + # Get the user object using the user ID + user = await get_user_object(kwargs["uid"]) + + # Construct the base query for the user + users_query = query.eq(AppVariantDB.user_id, user.id) + + # Construct the query for soft-deleted items + soft_delete_query = query.eq(AppVariantDB.is_deleted, show_soft_deleted) + + # Construct the final query filters + query_filters = ( + query.eq(AppVariantDB.app_name, app_name) + & query.eq(AppVariantDB.variant_name, variant_name) + & users_query + & soft_delete_query + ) + + # Perform the database query + app_variants_db = await engine.find( + AppVariantDB, + query_filters, + sort=(AppVariantDB.app_name, AppVariantDB.variant_name), + ) + + # Convert the database object to AppVariant and return it + # Assuming that find will return a list, take the first element if it exists + app_variant: AppVariant = ( + app_variant_db_to_pydantic(app_variants_db[0]) if app_variants_db else None + ) + + return app_variant + + async def list_apps(**kwargs: dict) -> List[App]: """ Lists all the unique app names from the database diff --git a/agenta-backend/tests/test_router_app_variant.py b/agenta-backend/tests/test_router_app_variant.py index c1dac99068..65ee36b6e7 100644 --- a/agenta-backend/tests/test_router_app_variant.py +++ b/agenta-backend/tests/test_router_app_variant.py @@ -82,6 +82,13 @@ def test_list_app_variant(): assert response.json() == [] +def test_get_variant_by_name_with_invalid_names(): + response = client.get( + "/app_variant/get_variant_by_name/?app_name=invalid&variant_name=invalid" + ) + assert response.status_code == 500 + + def test_list_app_variant_after_manual_add(app_variant, image): # This is the function from db_manager.py add_variant_based_on_image(app_variant, image) diff --git a/agenta-cli/agenta/client/client.py b/agenta-cli/agenta/client/client.py index 21cf23186f..125410ff9c 100644 --- a/agenta-cli/agenta/client/client.py +++ b/agenta-cli/agenta/client/client.py @@ -83,6 +83,29 @@ def list_variants(app_name: str, host: str) -> List[AppVariant]: return [AppVariant(**variant) for variant in app_variants] +def get_variant_by_name(app_name: str, variant_name: str, host: str) -> AppVariant: + """Gets a variant by name + + Arguments: + app_name -- the app name + variant_name -- the variant name + + Returns: + the variant using the pydantic model + """ + response = requests.get( + f"{host}/{BACKEND_URL_SUFFIX}/app_variant/get_variant_by_name/?app_name={app_name}&variant_name={variant_name}", + timeout=600, + ) + + # Check for successful request + if response.status_code != 200: + error_message = response.json() + raise APIRequestError( + f"Request to get_variant_by_name endpoint failed with status code {response.status_code} and error message: {error_message}." + ) + + def remove_variant(app_name: str, variant_name: str, host: str): """Removes a variant from the backend