From 3cf5577b4fb0c8627008cd8f9dba6edb61eb226f Mon Sep 17 00:00:00 2001
From: Abram
Date: Tue, 23 Apr 2024 09:39:51 +0100
Subject: [PATCH] Perf: speed up batch invoke by rewriting HTTP call logic
 using aiohttp

---
 .../services/llm_apps_service.py              | 85 ++++++++-----------
 1 file changed, 36 insertions(+), 49 deletions(-)

diff --git a/agenta-backend/agenta_backend/services/llm_apps_service.py b/agenta-backend/agenta_backend/services/llm_apps_service.py
index 7db418b60d..ea7373596c 100644
--- a/agenta-backend/agenta_backend/services/llm_apps_service.py
+++ b/agenta-backend/agenta_backend/services/llm_apps_service.py
@@ -1,8 +1,7 @@
 import json
-import httpx
 import logging
 import asyncio
-import traceback
+import aiohttp
 
 from typing import Any, Dict, List
 
@@ -74,19 +73,17 @@ async def invoke_app(
         InvokationResult: The output of the app.
 
     Raises:
-        httpx.HTTPError: If the POST request fails.
+        aiohttp.ClientError: If the POST request fails.
     """
     url = f"{uri}/generate"
     payload = await make_payload(datapoint, parameters, openapi_parameters)
 
-    async with httpx.AsyncClient() as client:
+    async with aiohttp.ClientSession() as client:
         try:
             logger.debug(f"Invoking app {uri} with payload {payload}")
-            response = await client.post(
-                url, json=payload, timeout=httpx.Timeout(timeout=5, read=None, write=5)
-            )
+            response = await client.post(url, json=payload, timeout=5)
             response.raise_for_status()
-            app_response = response.json()
+            app_response = await response.json()
             return InvokationResult(
                 result=Result(
                     type="text",
@@ -97,17 +94,11 @@
                 cost=app_response.get("cost"),
             )
 
-        except httpx.HTTPStatusError as e:
+        except aiohttp.ClientResponseError as e:
             # Parse error details from the API response
             error_message = "Error in invoking the LLM App:"
             try:
-                error_body = e.response.json()
-                if "message" in error_body:
-                    error_message = error_body["message"]
-                elif (
-                    "error" in error_body
-                ):  # Some APIs return error information under an 'error' key
-                    error_message = error_body["error"]
+                error_message = e.message
             except ValueError:
                 # Fallback if the error response is not JSON or doesn't have the expected structure
                 logger.error(f"Failed to parse error response: {e}")
@@ -117,20 +108,7 @@
                 result=Result(
                     type="error",
                     error=Error(
-                        message=error_message,
-                        stacktrace=str(e),
-                    ),
-                )
-            )
-
-        except httpx.RequestError as e:
-            # Handle other request errors (e.g., network issues)
-            logger.error(f"Request error: {e}")
-            return InvokationResult(
-                result=Result(
-                    type="error",
-                    error=Error(
-                        message="Network error while invoking the LLM App",
+                        message=f"{e.code}: {error_message}",
                         stacktrace=str(e),
                     ),
                 )
             )
@@ -179,19 +157,27 @@ async def run_with_retry(
         try:
             result = await invoke_app(uri, input_data, parameters, openapi_parameters)
             return result
-        except (httpx.TimeoutException, httpx.ConnectTimeout, httpx.ConnectError) as e:
+        except aiohttp.ClientError as e:
             last_exception = e
             print(f"Error in evaluation. Retrying in {retry_delay} seconds:", e)
             await asyncio.sleep(retry_delay)
             retries += 1
-
-    # If max retries reached, return the last exception
-    # return AppOutput(output=None, status=str(last_exception))
+        except Exception as e:
+            last_exception = e
+            logger.info(f"Error processing datapoint: {input_data}")
+
+    # If max retries have been reached, or an exception outside the retried
+    # class was raised, update & return the last exception
+    exception_message = (
+        "Max retries reached"
+        if retries == max_retry_count
+        else f"Error processing {input_data} datapoint"
+    )
     return InvokationResult(
         result=Result(
             type="error",
             value=None,
-            error=Error(message="max retries reached", stacktrace=last_exception),
+            error=Error(message=exception_message, stacktrace=last_exception),
         )
     )
@@ -230,11 +216,12 @@ async def batch_invoke(
     openapi_parameters = await get_parameters_from_openapi(uri + "/openapi.json")
 
     async def run_batch(start_idx: int):
+        tasks = []
         print(f"Preparing {start_idx} batch...")
        end_idx = min(start_idx + batch_size, len(testset_data))
         for index in range(start_idx, end_idx):
-            try:
-                batch_output: InvokationResult = await run_with_retry(
+            task = asyncio.ensure_future(
+                run_with_retry(
                     uri,
                     testset_data[index],
                     parameters,
@@ -242,13 +229,14 @@ async def run_batch(start_idx: int):
                     retry_delay,
                     openapi_parameters,
                 )
-                list_of_app_outputs.append(batch_output)
-                print(f"Adding outputs to batch {start_idx}")
-            except Exception as exc:
-                traceback.print_exc()
-                logger.info(
-                    f"Error processing batch[{start_idx}]:[{end_idx}] ==> {str(exc)}"
-                )
+            )
+            tasks.append(task)
+
+        # Gather results of all tasks
+        results = await asyncio.gather(*tasks)
+        for result in results:
+            list_of_app_outputs.append(result)
+            print(f"Adding outputs to batch {start_idx}")
 
         # Schedule the next batch with a delay
         next_batch_start_idx = end_idx
@@ -307,9 +295,8 @@ async def get_parameters_from_openapi(uri: str) -> List[Dict]:
 
 
 async def _get_openai_json_from_uri(uri):
-    async with httpx.AsyncClient() as client:
-        resp = await client.get(uri)
-        timeout = httpx.Timeout(timeout=5, read=None, write=5)
-        resp = await client.get(uri, timeout=timeout)
-        json_data = json.loads(resp.text)
+    async with aiohttp.ClientSession() as client:
+        resp = await client.get(uri, timeout=5)
+        resp_text = await resp.text()
+        json_data = json.loads(resp_text)
         return json_data
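
Note on the concurrency pattern: below is a minimal, self-contained sketch of
the fan-out that the rewritten run_batch performs — one task per datapoint via
asyncio.ensure_future, collected with asyncio.gather, which returns results in
input order so outputs stay aligned with testset rows. The endpoint URL,
call_generate helper, and toy datapoints are placeholders for illustration,
not part of the service.

    import asyncio
    import aiohttp

    async def call_generate(session: aiohttp.ClientSession, datapoint: dict) -> dict:
        # Stand-in for invoke_app: POST one datapoint and parse the JSON body.
        async with session.post(
            "http://localhost:8000/generate",  # placeholder URL
            json=datapoint,
            timeout=aiohttp.ClientTimeout(total=5),  # explicit form of the patch's timeout=5
        ) as response:
            response.raise_for_status()
            return await response.json()

    async def run_batch(datapoints: list) -> list:
        async with aiohttp.ClientSession() as session:
            # Schedule all requests up front, then await them together;
            # gather preserves the order of the input tasks.
            tasks = [asyncio.ensure_future(call_generate(session, dp)) for dp in datapoints]
            return await asyncio.gather(*tasks)

    if __name__ == "__main__":
        print(asyncio.run(run_batch([{"q": "1"}, {"q": "2"}, {"q": "3"}])))

Unlike the old loop, which awaited each run_with_retry call before starting the
next, the requests within a batch now overlap on the event loop — that overlap
is where the performance gain in this patch comes from.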