Tests for llm apps service #1681

Merged 4 commits on May 21, 2024. The diff below shows the changes from 2 commits.
39 changes: 29 additions & 10 deletions agenta-backend/agenta_backend/services/llm_apps_service.py
```diff
@@ -165,28 +165,47 @@ async def run_with_retry(
             return result
         except aiohttp.ClientError as e:
             last_exception = e
-            print(f"Error in evaluation. Retrying in {retry_delay} seconds:", e)
+            error_message = f"HTTP error occurred during request: {str(e)}"
+            logger.error(error_message)
             await asyncio.sleep(retry_delay)
             retries += 1
+        except asyncio.TimeoutError as e:
+            last_exception = e
+            error_message = f"Request timed out: {str(e)}"
+            logger.error(error_message)
+            await asyncio.sleep(retry_delay)
+            retries += 1
+        except aiohttp.ClientConnectionError as e:
+            last_exception = e
+            error_message = f"Connection error: {str(e)}"
+            logger.error(error_message)
+            await asyncio.sleep(retry_delay)
+            retries += 1
+        except json.JSONDecodeError as e:
+            last_exception = e
+            error_message = f"Failed to decode JSON from response: {str(e)}"
+            logger.error(error_message)
+            common.capture_exception_in_sentry(e)
+            break  # Exit the loop for non-retriable exceptions
         except Exception as e:
             last_exception = e
-            logger.info(f"Error processing datapoint: {input_data}. {str(e)}")
-            logger.info("".join(traceback.format_exception_only(type(e), e)))
+            error_message = (
+                f"Error processing datapoint: {input_data}. Exception: {str(e)}"
+            )
+            logger.error(error_message)
+            logger.error("".join(traceback.format_exception_only(type(e), e)))
+            common.capture_exception_in_sentry(e)
             break  # Exit the loop for non-ClientError exceptions

     # If max retries is reached or an exception that isn't in the second block,
     # update & return the last exception
-    logging.info("Max retries reached")
+    logger.info("Max retries reached or a critical error occurred")
     exception_message = (
-        "Max retries reached"
-        if retries == max_retry_count
-        else f"Error processing {input_data} datapoint"
+        "Max retries reached" if retries == max_retry_count else error_message
     )
     return InvokationResult(
         result=Result(
             type="error",
             value=None,
-            error=Error(message=exception_message, stacktrace=last_exception),
+            error=Error(message=exception_message, stacktrace=str(last_exception)),
         )
     )
```

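To make the control flow easier to see outside the diff, here is a minimal, self-contained sketch of the retry pattern the hunk implements: transient network failures back off and retry, while malformed responses and unexpected exceptions are logged and abandoned immediately. Everything here (`call_with_retry`, `MAX_RETRIES`, `RETRY_DELAY`, the dict return shape) is a hypothetical stand-in, not the service's actual API.

```python
import asyncio
import json
import logging

import aiohttp

logger = logging.getLogger(__name__)

MAX_RETRIES = 3    # hypothetical stand-in for max_retry_count
RETRY_DELAY = 1.0  # hypothetical stand-in for retry_delay (seconds)


async def call_with_retry(func, *args):
    """Sketch: retry transient network errors, bail out on everything else."""
    retries = 0
    last_exception = None
    error_message = ""
    while retries < MAX_RETRIES:
        try:
            return await func(*args)
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            # Retriable: transient network trouble may resolve itself,
            # so log, back off, and try again.
            last_exception = e
            error_message = f"Transient error (attempt {retries + 1}): {e}"
            logger.error(error_message)
            await asyncio.sleep(RETRY_DELAY)
            retries += 1
        except Exception as e:
            # Non-retriable: json.JSONDecodeError and any other unexpected
            # exception land here; retrying will not help, so stop at once.
            last_exception = e
            error_message = f"Non-retriable error: {e}"
            logger.error(error_message)
            break
    message = "Max retries reached" if retries == MAX_RETRIES else error_message
    # The real service wraps this in an error-typed InvokationResult; the
    # sketch returns a plain dict carrying the same message/stacktrace idea.
    return {"type": "error", "message": message, "stacktrace": str(last_exception)}
```

One ordering detail worth noting: `aiohttp.ClientConnectionError` is a subclass of `aiohttp.ClientError`, so in the hunk above the earlier `ClientError` handler catches connection errors first and the dedicated `ClientConnectionError` branch is effectively unreachable. The behavior (retry) is the same either way; only the logged message differs.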
178 changes: 178 additions & 0 deletions agenta-backend/agenta_backend/tests/unit/test_llm_apps_service.py
@@ -0,0 +1,178 @@

```python
import pytest
from unittest.mock import patch, AsyncMock
import asyncio
import aiohttp
import json

from agenta_backend.services.llm_apps_service import (
    batch_invoke,
    InvokationResult,
    Result,
    Error,
)


@pytest.mark.asyncio
async def test_batch_invoke_success():
    with patch(
        "agenta_backend.services.llm_apps_service.get_parameters_from_openapi",
        new_callable=AsyncMock,
    ) as mock_get_parameters_from_openapi, patch(
        "agenta_backend.services.llm_apps_service.invoke_app", new_callable=AsyncMock
    ) as mock_invoke_app, patch(
        "asyncio.sleep", new_callable=AsyncMock
    ) as mock_sleep:
        mock_get_parameters_from_openapi.return_value = [
            {"name": "param1", "type": "input"},
            {"name": "param2", "type": "input"},
        ]

        # Mock the response of invoke_app to always succeed
        def invoke_app_side_effect(uri, datapoint, parameters, openapi_parameters):
            return InvokationResult(
                result=Result(type="text", value="Success", error=None),
                latency=0.1,
                cost=0.01,
            )

        mock_invoke_app.side_effect = invoke_app_side_effect

        uri = "http://example.com"
        testset_data = [
            {"id": 1, "param1": "value1", "param2": "value2"},
            {"id": 2, "param1": "value1", "param2": "value2"},
        ]
        parameters = {}
        rate_limit_config = {
            "batch_size": 10,
            "max_retries": 3,
            "retry_delay": 3,
            "delay_between_batches": 5,
        }

        results = await batch_invoke(uri, testset_data, parameters, rate_limit_config)

        assert len(results) == 2
        assert results[0].result.type == "text"
        assert results[0].result.value == "Success"
        assert results[1].result.type == "text"
        assert results[1].result.value == "Success"


@pytest.mark.asyncio
async def test_batch_invoke_retries_and_failure():
    with patch(
        "agenta_backend.services.llm_apps_service.get_parameters_from_openapi",
        new_callable=AsyncMock,
    ) as mock_get_parameters_from_openapi, patch(
        "agenta_backend.services.llm_apps_service.invoke_app", new_callable=AsyncMock
    ) as mock_invoke_app, patch(
        "asyncio.sleep", new_callable=AsyncMock
    ) as mock_sleep:
        mock_get_parameters_from_openapi.return_value = [
            {"name": "param1", "type": "input"},
            {"name": "param2", "type": "input"},
        ]

        # Mock the response of invoke_app to always fail
        def invoke_app_side_effect(uri, datapoint, parameters, openapi_parameters):
            raise aiohttp.ClientError("Test Error")

        mock_invoke_app.side_effect = invoke_app_side_effect

        uri = "http://example.com"
        testset_data = [
            {"id": 1, "param1": "value1", "param2": "value2"},
            {"id": 2, "param1": "value1", "param2": "value2"},
        ]
        parameters = {}
        rate_limit_config = {
            "batch_size": 10,
            "max_retries": 3,
            "retry_delay": 3,
            "delay_between_batches": 5,
        }

        results = await batch_invoke(uri, testset_data, parameters, rate_limit_config)

        assert len(results) == 2
        assert results[0].result.type == "error"
        assert results[0].result.error.message == "Max retries reached"
        assert results[1].result.type == "error"
        assert results[1].result.error.message == "Max retries reached"


@pytest.mark.asyncio
async def test_batch_invoke_json_decode_error():
    with patch(
        "agenta_backend.services.llm_apps_service.get_parameters_from_openapi",
        new_callable=AsyncMock,
    ) as mock_get_parameters_from_openapi, patch(
        "agenta_backend.services.llm_apps_service.invoke_app", new_callable=AsyncMock
    ) as mock_invoke_app, patch(
        "asyncio.sleep", new_callable=AsyncMock
    ) as mock_sleep:
        mock_get_parameters_from_openapi.return_value = [
            {"name": "param1", "type": "input"},
            {"name": "param2", "type": "input"},
        ]

        # Mock the response of invoke_app to raise json.JSONDecodeError
        def invoke_app_side_effect(uri, datapoint, parameters, openapi_parameters):
            raise json.JSONDecodeError("Expecting value", "", 0)

        mock_invoke_app.side_effect = invoke_app_side_effect

        uri = "http://example.com"
        testset_data = [{"id": 1, "param1": "value1", "param2": "value2"}]
        parameters = {}
        rate_limit_config = {
            "batch_size": 1,
            "max_retries": 3,
            "retry_delay": 1,
            "delay_between_batches": 1,
        }

        results = await batch_invoke(uri, testset_data, parameters, rate_limit_config)

        assert len(results) == 1
        assert results[0].result.type == "error"
        assert "Failed to decode JSON from response" in results[0].result.error.message


@pytest.mark.asyncio
async def test_batch_invoke_generic_exception():
    with patch(
        "agenta_backend.services.llm_apps_service.get_parameters_from_openapi",
        new_callable=AsyncMock,
    ) as mock_get_parameters_from_openapi, patch(
        "agenta_backend.services.llm_apps_service.invoke_app", new_callable=AsyncMock
    ) as mock_invoke_app, patch(
        "asyncio.sleep", new_callable=AsyncMock
    ) as mock_sleep:
        mock_get_parameters_from_openapi.return_value = [
            {"name": "param1", "type": "input"},
            {"name": "param2", "type": "input"},
        ]

        # Mock the response of invoke_app to raise a generic exception
        def invoke_app_side_effect(uri, datapoint, parameters, openapi_parameters):
            raise Exception("Generic Error")

        mock_invoke_app.side_effect = invoke_app_side_effect

        uri = "http://example.com"
        testset_data = [{"id": 1, "param1": "value1", "param2": "value2"}]
        parameters = {}
        rate_limit_config = {
            "batch_size": 1,
            "max_retries": 3,
            "retry_delay": 1,
            "delay_between_batches": 1,
        }

        results = await batch_invoke(uri, testset_data, parameters, rate_limit_config)

        assert len(results) == 1
        assert results[0].result.type == "error"
        assert "Generic Error" in results[0].result.error.message
```
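Assuming the usual pytest setup plus the pytest-asyncio plugin (the tests are marked `@pytest.mark.asyncio`, which plain pytest would not run natively), the new suite can presumably be invoked from the repository root with:

```sh
pytest agenta-backend/agenta_backend/tests/unit/test_llm_apps_service.py -v
```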