Skip to content

Commit

Permalink
Merge pull request #583 from yeokyeong-yanolja/feature/get_variant_by…
Browse files Browse the repository at this point in the history
…_name#571

feature : develop Endpoint to Fetch LLM Application Variant Parameters
  • Loading branch information
mmabrouk authored Sep 12, 2023
2 parents 38e4228 + 1e23277 commit 85d14c5
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 0 deletions.
42 changes: 42 additions & 0 deletions agenta-backend/agenta_backend/routers/app_variant.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,48 @@ async def list_app_variants(
raise HTTPException(status_code=500, detail=str(e))


@router.get("/get_variant_by_name/", response_model=AppVariant)
async def get_variant_by_name(
app_name: str,
variant_name: str,
stoken_session: SessionContainer = Depends(verify_session()),
):
"""Fetches a specific app variant based on the given app_name and variant_name.
Arguments:
app_name (str): The name of the app to query.
variant_name (str): The name of the variant to query.
Raises:
HTTPException: Raises 404 if no matching variant is found,
400 for ValueError, or 500 for any other exceptions.
Returns:
AppVariant: The fetched app variant.
"""

try:
# Retrieve the user and organization ID based on the session token
kwargs = await get_user_and_org_id(stoken_session)

# Fetch the app variant using the provided app_name and variant_name
app_variant = await db_manager.get_app_variant_by_app_name_and_variant_name(
app_name=app_name, variant_name=variant_name, **kwargs
)
# Check if the fetched app variant is None and raise 404 if it is
if app_variant is None:
raise HTTPException(status_code=500, detail="App Variant not found")
return app_variant
except ValueError as e:
# Handle ValueErrors and return 400 status code
raise HTTPException(status_code=400, detail=str(e))
except HTTPException as e:
raise e
except Exception as e:
# Handle all other exceptions and return 500 status code
raise HTTPException(status_code=500, detail=str(e))


@router.get("/list_apps/", response_model=List[App])
async def list_apps(
stoken_session: SessionContainer = Depends(verify_session()),
Expand Down
48 changes: 48 additions & 0 deletions agenta-backend/agenta_backend/services/db_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,54 @@ async def list_app_variants(
return app_variants


async def get_app_variant_by_app_name_and_variant_name(
app_name: str, variant_name: str, show_soft_deleted: bool = False, **kwargs: dict
) -> AppVariant:
"""Fetches an app variant based on app_name and variant_name.
Args:
app_name (str): Name of the app.
variant_name (str): Name of the variant.
show_soft_deleted: if true, returns soft deleted variants as well
**kwargs (dict): Additional keyword arguments.
Returns:
AppVariant: The fetched app variant.
"""

# Get the user object using the user ID
user = await get_user_object(kwargs["uid"])

# Construct the base query for the user
users_query = query.eq(AppVariantDB.user_id, user.id)

# Construct the query for soft-deleted items
soft_delete_query = query.eq(AppVariantDB.is_deleted, show_soft_deleted)

# Construct the final query filters
query_filters = (
query.eq(AppVariantDB.app_name, app_name)
& query.eq(AppVariantDB.variant_name, variant_name)
& users_query
& soft_delete_query
)

# Perform the database query
app_variants_db = await engine.find(
AppVariantDB,
query_filters,
sort=(AppVariantDB.app_name, AppVariantDB.variant_name),
)

# Convert the database object to AppVariant and return it
# Assuming that find will return a list, take the first element if it exists
app_variant: AppVariant = (
app_variant_db_to_pydantic(app_variants_db[0]) if app_variants_db else None
)

return app_variant


async def list_apps(**kwargs: dict) -> List[App]:
"""
Lists all the unique app names from the database
Expand Down
7 changes: 7 additions & 0 deletions agenta-backend/tests/test_router_app_variant.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ def test_list_app_variant():
assert response.json() == []


def test_get_variant_by_name_with_invalid_names():
response = client.get(
"/app_variant/get_variant_by_name/?app_name=invalid&variant_name=invalid"
)
assert response.status_code == 500


def test_list_app_variant_after_manual_add(app_variant, image):
# This is the function from db_manager.py
add_variant_based_on_image(app_variant, image)
Expand Down
23 changes: 23 additions & 0 deletions agenta-cli/agenta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,29 @@ def list_variants(app_name: str, host: str) -> List[AppVariant]:
return [AppVariant(**variant) for variant in app_variants]


def get_variant_by_name(app_name: str, variant_name: str, host: str) -> AppVariant:
"""Gets a variant by name
Arguments:
app_name -- the app name
variant_name -- the variant name
Returns:
the variant using the pydantic model
"""
response = requests.get(
f"{host}/{BACKEND_URL_SUFFIX}/app_variant/get_variant_by_name/?app_name={app_name}&variant_name={variant_name}",
timeout=600,
)

# Check for successful request
if response.status_code != 200:
error_message = response.json()
raise APIRequestError(
f"Request to get_variant_by_name endpoint failed with status code {response.status_code} and error message: {error_message}."
)


def remove_variant(app_name: str, variant_name: str, host: str):
"""Removes a variant from the backend
Expand Down

0 comments on commit 85d14c5

Please sign in to comment.