Perf: speed up batch invoke by rewriting HTTP call logic using aiohttp
aybruhm committed Apr 23, 2024
1 parent 8bea241 commit 3cf5577
Showing 1 changed file with 36 additions and 49 deletions.
agenta-backend/agenta_backend/services/llm_apps_service.py
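The gist of the change: each HTTP call to the LLM app used to be awaited one at a time; with aiohttp, the calls can be scheduled as asyncio tasks and gathered concurrently, so a batch costs roughly one round trip instead of many. A minimal, self-contained sketch of that pattern (the URLS list, payload, and fetch_one helper are illustrative, not part of the commit):

    import asyncio
    import aiohttp

    URLS = [f"https://example.com/generate?i={i}" for i in range(10)]

    async def fetch_one(session: aiohttp.ClientSession, url: str) -> dict:
        # One POST per datapoint; a bare number for timeout is the total budget.
        async with session.post(url, json={"input": "hi"}, timeout=5) as resp:
            resp.raise_for_status()
            return await resp.json()

    async def main() -> None:
        async with aiohttp.ClientSession() as session:
            # Sequential version: for url in URLS: await fetch_one(session, url)
            # Concurrent version: schedule everything, then await the lot.
            tasks = [asyncio.ensure_future(fetch_one(session, url)) for url in URLS]
            results = await asyncio.gather(*tasks)
            print(len(results))

    asyncio.run(main())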
@@ -1,8 +1,7 @@
 import json
-import httpx
 import logging
 import asyncio
-import traceback
+import aiohttp
 from typing import Any, Dict, List
 
 
@@ -74,19 +73,17 @@ async def invoke_app(
         InvokationResult: The output of the app.
     Raises:
-        httpx.HTTPError: If the POST request fails.
+        aiohttp.ClientError: If the POST request fails.
     """
     url = f"{uri}/generate"
     payload = await make_payload(datapoint, parameters, openapi_parameters)
 
-    async with httpx.AsyncClient() as client:
+    async with aiohttp.ClientSession() as client:
         try:
             logger.debug(f"Invoking app {uri} with payload {payload}")
-            response = await client.post(
-                url, json=payload, timeout=httpx.Timeout(timeout=5, read=None, write=5)
-            )
+            response = await client.post(url, json=payload, timeout=5)
             response.raise_for_status()
-            app_response = response.json()
+            app_response = await response.json()
             return InvokationResult(
                 result=Result(
                     type="text",
@@ -97,17 +94,11 @@ async def invoke_app(
                 cost=app_response.get("cost"),
             )
 
-        except httpx.HTTPStatusError as e:
+        except aiohttp.ClientResponseError as e:
             # Parse error details from the API response
             error_message = "Error in invoking the LLM App:"
             try:
-                error_body = e.response.json()
-                if "message" in error_body:
-                    error_message = error_body["message"]
-                elif (
-                    "error" in error_body
-                ):  # Some APIs return error information under an 'error' key
-                    error_message = error_body["error"]
+                error_message = e.message
             except ValueError:
                 # Fallback if the error response is not JSON or doesn't have the expected structure
                 logger.error(f"Failed to parse error response: {e}")
@@ -117,20 +108,7 @@ async def invoke_app(
             result=Result(
                 type="error",
                 error=Error(
-                    message=error_message,
-                    stacktrace=str(e),
-                ),
-            )
-        )
-
-        except httpx.RequestError as e:
-            # Handle other request errors (e.g., network issues)
-            logger.error(f"Request error: {e}")
-            return InvokationResult(
-                result=Result(
-                    type="error",
-                    error=Error(
-                        message="Network error while invoking the LLM App",
+                    message=f"{e.code}: {error_message}",
                     stacktrace=str(e),
                 ),
             )
         )
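The consolidated error path leans on raise_for_status(): for a 4xx/5xx response, aiohttp raises ClientResponseError, which already carries the status code and reason phrase, so there is no need to parse the response body for a message by hand (the diff's e.code is an older alias for e.status). A short sketch of that behavior, using an illustrative endpoint:

    import asyncio
    import aiohttp

    async def demo() -> None:
        async with aiohttp.ClientSession() as session:
            async with session.post("https://httpbin.org/status/404", json={}) as resp:
                try:
                    # Raises aiohttp.ClientResponseError for 4xx/5xx responses.
                    resp.raise_for_status()
                except aiohttp.ClientResponseError as e:
                    print(f"{e.status}: {e.message}")  # e.g. "404: NOT FOUND"

    asyncio.run(demo())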
@@ -179,19 +157,27 @@ async def run_with_retry(
         try:
             result = await invoke_app(uri, input_data, parameters, openapi_parameters)
             return result
-        except (httpx.TimeoutException, httpx.ConnectTimeout, httpx.ConnectError) as e:
+        except aiohttp.ClientError as e:
             last_exception = e
             print(f"Error in evaluation. Retrying in {retry_delay} seconds:", e)
             await asyncio.sleep(retry_delay)
             retries += 1
 
-    # If max retries reached, return the last exception
-    # return AppOutput(output=None, status=str(last_exception))
+        except Exception as e:
+            last_exception = e
+            logger.info(f"Error processing datapoint: {input_data}")
 
+    # If max retries is reached or an exception that isn't in the second block,
+    # update & return the last exception
+    exception_message = (
+        "Max retries reached"
+        if retries == max_retry_count
+        else f"Error processing {input_data} datapoint"
+    )
     return InvokationResult(
         result=Result(
             type="error",
             value=None,
-            error=Error(message="max retries reached", stacktrace=last_exception),
+            error=Error(message=exception_message, stacktrace=last_exception),
         )
     )
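run_with_retry now retries only on aiohttp.ClientError (the common base class of aiohttp's connection and response errors) and records any other exception as a plain failure. A generic version of the same loop, reduced to its essentials (retry_on_client_error and make_call are hypothetical names, not from the commit):

    import asyncio
    from typing import Any, Awaitable, Callable, Optional

    import aiohttp

    async def retry_on_client_error(
        make_call: Callable[[], Awaitable[Any]],
        max_retry_count: int = 3,
        retry_delay: float = 1.0,
    ) -> Any:
        last_exception: Optional[Exception] = None
        for _ in range(max_retry_count):
            try:
                return await make_call()
            except aiohttp.ClientError as e:
                # Transient network/HTTP error: wait, then try again.
                last_exception = e
                await asyncio.sleep(retry_delay)
        raise RuntimeError("Max retries reached") from last_exception

The function in the diff returns an error InvokationResult instead of raising, but the control flow is the same.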

@@ -230,25 +216,27 @@ async def batch_invoke(
     openapi_parameters = await get_parameters_from_openapi(uri + "/openapi.json")
 
     async def run_batch(start_idx: int):
+        tasks = []
         print(f"Preparing {start_idx} batch...")
         end_idx = min(start_idx + batch_size, len(testset_data))
         for index in range(start_idx, end_idx):
-            try:
-                batch_output: InvokationResult = await run_with_retry(
+            task = asyncio.ensure_future(
+                run_with_retry(
                     uri,
                     testset_data[index],
                     parameters,
                     max_retries,
                     retry_delay,
                     openapi_parameters,
                 )
-                list_of_app_outputs.append(batch_output)
-                print(f"Adding outputs to batch {start_idx}")
-            except Exception as exc:
-                traceback.print_exc()
-                logger.info(
-                    f"Error processing batch[{start_idx}]:[{end_idx}] ==> {str(exc)}"
-                )
+            )
+            tasks.append(task)
+
+        # Gather results of all tasks
+        results = await asyncio.gather(*tasks)
+        for result in results:
+            list_of_app_outputs.append(result)
+            print(f"Adding outputs to batch {start_idx}")
 
         # Schedule the next batch with a delay
         next_batch_start_idx = end_idx
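This hunk is where the performance win comes from: run_batch no longer awaits each run_with_retry call before starting the next, but wraps every call in asyncio.ensure_future and gathers them, so a batch takes roughly the latency of its slowest request rather than the sum of all of them. gather also preserves input order, so outputs still line up with testset rows. A toy timing demonstration of the difference (fake_invoke simulates a 1-second request):

    import asyncio
    import time

    async def fake_invoke(i: int) -> int:
        await asyncio.sleep(1)  # stand-in for one HTTP round trip
        return i

    async def main() -> None:
        start = time.perf_counter()
        sequential = [await fake_invoke(i) for i in range(10)]
        print(f"sequential: {time.perf_counter() - start:.1f}s")  # ~10s

        start = time.perf_counter()
        tasks = [asyncio.ensure_future(fake_invoke(i)) for i in range(10)]
        gathered = await asyncio.gather(*tasks)
        print(f"gathered:   {time.perf_counter() - start:.1f}s")  # ~1s
        assert sequential == gathered  # gather preserves order

    asyncio.run(main())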
@@ -307,9 +295,8 @@ async def get_parameters_from_openapi(uri: str) -> List[Dict]:
 
 
 async def _get_openai_json_from_uri(uri):
-    async with httpx.AsyncClient() as client:
-        resp = await client.get(uri)
-        timeout = httpx.Timeout(timeout=5, read=None, write=5)
-        resp = await client.get(uri, timeout=timeout)
-        json_data = json.loads(resp.text)
+    async with aiohttp.ClientSession() as client:
+        resp = await client.get(uri, timeout=5)
+        resp_text = await resp.text()
+        json_data = json.loads(resp_text)
         return json_data
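The OpenAPI fetch follows the same pattern: a plain number passed as timeout is treated by aiohttp as the total time budget for the request. For finer control, aiohttp.ClientTimeout makes that explicit; a sketch of an equivalent helper (get_json and the 5-second total are illustrative choices, not from the commit):

    import asyncio
    import aiohttp

    async def get_json(uri: str) -> dict:
        timeout = aiohttp.ClientTimeout(total=5)  # same budget as timeout=5
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(uri) as resp:
                resp.raise_for_status()
                # content_type=None skips the application/json header check,
                # mirroring the diff's json.loads(await resp.text()).
                return await resp.json(content_type=None)

    # usage: spec = asyncio.run(get_json("https://example.com/openapi.json"))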
