diff --git a/agenta-backend/agenta_backend/routers/app_router.py b/agenta-backend/agenta_backend/routers/app_router.py index a41c92b1e5..5ddda04154 100644 --- a/agenta-backend/agenta_backend/routers/app_router.py +++ b/agenta-backend/agenta_backend/routers/app_router.py @@ -8,8 +8,6 @@ from agenta_backend.services.selectors import get_user_own_org from agenta_backend.services import ( app_manager, - docker_utils, - container_manager, db_manager, ) from agenta_backend.utils.common import ( @@ -24,6 +22,7 @@ AppVariantOutput, AddVariantFromImagePayload, EnvironmentOutput, + Image, ) from agenta_backend.models import converters @@ -34,6 +33,17 @@ else: from agenta_backend.services.selectors import get_user_and_org_id +if os.environ["FEATURE_FLAG"] in ["cloud"]: + from agenta_backend.ee.services import ( + lambda_deployment_manager as deployment_manager, + ) # noqa pylint: disable-all +elif os.environ["FEATURE_FLAG"] in ["ee"]: + from agenta_backend.ee.services import ( + deployment_manager, + ) # noqa pylint: disable-all +else: + from agenta_backend.services import deployment_manager + router = APIRouter() logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -221,12 +231,16 @@ async def add_variant_from_image( """ if os.environ["FEATURE_FLAG"] not in ["cloud", "ee"]: + image = Image( + docker_id=payload.docker_id, + tags=payload.tags, + ) if not payload.tags.startswith(settings.registry): raise HTTPException( status_code=500, detail="Image should have a tag starting with the registry name (agenta-server)", ) - elif docker_utils.find_image_by_docker_id(payload.docker_id) is None: + elif deployment_manager.validate_image(image) is False: raise HTTPException(status_code=404, detail="Image not found") try: diff --git a/agenta-backend/agenta_backend/routers/container_router.py b/agenta-backend/agenta_backend/routers/container_router.py index 97d3bd7e00..5073945646 100644 --- a/agenta-backend/agenta_backend/routers/container_router.py +++ b/agenta-backend/agenta_backend/routers/container_router.py @@ -7,7 +7,6 @@ Template, ) from agenta_backend.services import db_manager -from agenta_backend.services.docker_utils import restart_container from fastapi import APIRouter, Request, UploadFile, HTTPException from fastapi.responses import JSONResponse @@ -18,12 +17,13 @@ else: from agenta_backend.services.selectors import get_user_and_org_id -if os.environ["FEATURE_FLAG"] in ["cloud", "ee"]: +if os.environ["FEATURE_FLAG"] in ["cloud"]: + from agenta_backend.ee.services import container_manager +if os.environ["FEATURE_FLAG"] in ["ee"]: from agenta_backend.ee.services import container_manager else: from agenta_backend.services import container_manager - import logging logger = logging.getLogger(__name__) @@ -93,7 +93,7 @@ async def restart_docker_container( container_id = deployment.container_id logger.debug(f"Restarting container with id: {container_id}") - restart_container(container_id) + container_manager.restart_container(container_id) return {"message": "Please wait a moment. The container is now restarting."} except Exception as ex: return JSONResponse({"message": str(ex)}, status_code=500) diff --git a/agenta-backend/agenta_backend/services/app_manager.py b/agenta-backend/agenta_backend/services/app_manager.py index f05cd505a7..bfd6d84a2f 100644 --- a/agenta-backend/agenta_backend/services/app_manager.py +++ b/agenta-backend/agenta_backend/services/app_manager.py @@ -20,6 +20,10 @@ from agenta_backend.ee.services import ( lambda_deployment_manager as deployment_manager, ) # noqa pylint: disable-all +elif os.environ["FEATURE_FLAG"] in ["ee"]: + from agenta_backend.ee.services import ( + deployment_manager, + ) # noqa pylint: disable-all else: from agenta_backend.services import deployment_manager diff --git a/agenta-backend/agenta_backend/services/container_manager.py b/agenta-backend/agenta_backend/services/container_manager.py index 63485460c9..f34b025224 100644 --- a/agenta-backend/agenta_backend/services/container_manager.py +++ b/agenta-backend/agenta_backend/services/container_manager.py @@ -19,6 +19,7 @@ from agenta_backend.models.db_models import ( AppDB, ) +from agenta_backend.services import docker_utils client = docker.from_env() @@ -127,61 +128,6 @@ def build_image_job( raise HTTPException(status_code=500, detail=str(ex)) -@backoff.on_exception(backoff.expo, (ConnectError, CancelledError), max_tries=5) -async def retrieve_templates_from_dockerhub( - url: str, repo_owner: str, repo_name: str -) -> Union[List[dict], dict]: - """ - Business logic to retrieve templates from DockerHub. - - Args: - url (str): The URL endpoint for retrieving templates. Should contain placeholders `{}` - for the `repo_owner` and `repo_name` values to be inserted. For example: - `https://hub.docker.com/v2/repositories/{}/{}/tags`. - repo_owner (str): The owner or organization of the repository from which templates are to be retrieved. - repo_name (str): The name of the repository where the templates are located. - - Returns: - tuple: A tuple containing two values. - """ - - async with httpx.AsyncClient() as client: - response = await client.get( - f"{url.format(repo_owner, repo_name)}/tags", timeout=10 - ) - if response.status_code == 200: - response_data = response.json() - return response_data - - response_data = response.json() - return response_data - - -@backoff.on_exception( - backoff.expo, (ConnectError, TimeoutException, CancelledError), max_tries=5 -) -async def get_templates_info_from_s3(url: str) -> Dict[str, Dict[str, Any]]: - """ - Business logic to retrieve templates information from S3. - - Args: - url (str): The URL endpoint for retrieving templates info. - - Returns: - response_data (Dict[str, Dict[str, Any]]): A dictionary \ - containing dictionaries of templates information. - """ - - async with httpx.AsyncClient() as client: - response = await client.get(url, timeout=10) - if response.status_code == 200: - response_data = response.json() - return response_data - - response_data = response.json() - return response_data - - async def check_docker_arch() -> str: """Checks the architecture of the Docker system. @@ -243,3 +189,12 @@ async def get_image_details_from_docker_hub( f"{repo_owner}/{repo_name}:{image_name}" ) return image_details["Id"] + + +def restart_container(container_id: str): + """Restart docker container. + + Args: + container_id (str): The id of the container to restart. + """ + docker_utils.restart_container(container_id) diff --git a/agenta-backend/agenta_backend/services/templates_manager.py b/agenta-backend/agenta_backend/services/templates_manager.py index 84b33fe3a8..28b793ce00 100644 --- a/agenta-backend/agenta_backend/services/templates_manager.py +++ b/agenta-backend/agenta_backend/services/templates_manager.py @@ -1,9 +1,18 @@ import json +import backoff from typing import Any, Dict, List - +import httpx +import os from agenta_backend.config import settings -from agenta_backend.services import container_manager, db_manager +from agenta_backend.services import db_manager from agenta_backend.utils import redis_utils +from httpx import ConnectError, TimeoutException +from asyncio.exceptions import CancelledError + +if os.environ["FEATURE_FLAG"] in ["oss", "cloud"]: + from agenta_backend.services import container_manager + +from typing import Union async def update_and_sync_templates(cache: bool = True) -> None: @@ -75,7 +84,7 @@ async def retrieve_templates_from_dockerhub_cached(cache: bool) -> List[dict]: return json.loads(cached_data.decode("utf-8")) # If not cached, fetch data from Docker Hub and cache it in Redis - response = await container_manager.retrieve_templates_from_dockerhub( + response = await retrieve_templates_from_dockerhub( settings.docker_hub_url, settings.docker_hub_repo_owner, settings.docker_hub_repo_name, @@ -107,7 +116,7 @@ async def retrieve_templates_info_from_s3( return json.loads(cached_data) # If not cached, fetch data from Docker Hub and cache it in Redis - response = await container_manager.get_templates_info_from_s3( + response = await get_templates_info_from_s3( "https://llm-app-json.s3.eu-central-1.amazonaws.com/llm_info.json" ) @@ -115,3 +124,58 @@ async def retrieve_templates_info_from_s3( r.set("temp_data", json.dumps(response), ex=900) print("Using network call...") return response + + +@backoff.on_exception(backoff.expo, (ConnectError, CancelledError), max_tries=5) +async def retrieve_templates_from_dockerhub( + url: str, repo_owner: str, repo_name: str +) -> Union[List[dict], dict]: + """ + Business logic to retrieve templates from DockerHub. + + Args: + url (str): The URL endpoint for retrieving templates. Should contain placeholders `{}` + for the `repo_owner` and `repo_name` values to be inserted. For example: + `https://hub.docker.com/v2/repositories/{}/{}/tags`. + repo_owner (str): The owner or organization of the repository from which templates are to be retrieved. + repo_name (str): The name of the repository where the templates are located. + + Returns: + tuple: A tuple containing two values. + """ + + async with httpx.AsyncClient() as client: + response = await client.get( + f"{url.format(repo_owner, repo_name)}/tags", timeout=10 + ) + if response.status_code == 200: + response_data = response.json() + return response_data + + response_data = response.json() + return response_data + + +@backoff.on_exception( + backoff.expo, (ConnectError, TimeoutException, CancelledError), max_tries=5 +) +async def get_templates_info_from_s3(url: str) -> Dict[str, Dict[str, Any]]: + """ + Business logic to retrieve templates information from S3. + + Args: + url (str): The URL endpoint for retrieving templates info. + + Returns: + response_data (Dict[str, Dict[str, Any]]): A dictionary \ + containing dictionaries of templates information. + """ + + async with httpx.AsyncClient() as client: + response = await client.get(url, timeout=10) + if response.status_code == 200: + response_data = response.json() + return response_data + + response_data = response.json() + return response_data