Skip to content

Commit

Permalink
Merge pull request #1681 from Agenta-AI/tests-for-llm-apps-service
Browse files Browse the repository at this point in the history
Tests for llm apps service
  • Loading branch information
aakrem authored May 21, 2024
2 parents 4d8f12a + 2087286 commit ac064cf
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 1 deletion.
3 changes: 2 additions & 1 deletion agenta-backend/agenta_backend/services/llm_apps_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ async def run_with_retry(
last_exception = e
logger.info(f"Error processing datapoint: {input_data}. {str(e)}")
logger.info("".join(traceback.format_exception_only(type(e), e)))
retries += 1
common.capture_exception_in_sentry(e)

# If max retries is reached or an exception that isn't in the second block,
Expand All @@ -186,7 +187,7 @@ async def run_with_retry(
result=Result(
type="error",
value=None,
error=Error(message=exception_message, stacktrace=last_exception),
error=Error(message=exception_message, stacktrace=str(last_exception)),
)
)

Expand Down
162 changes: 162 additions & 0 deletions agenta-backend/agenta_backend/tests/unit/test_llm_apps_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import pytest
from unittest.mock import patch, AsyncMock
import asyncio
import aiohttp

from agenta_backend.services.llm_apps_service import (
batch_invoke,
InvokationResult,
Result,
Error,
)


@pytest.mark.asyncio
async def test_batch_invoke_success():
"""
Test the successful invocation of batch_invoke function.
This test mocks the get_parameters_from_openapi and invoke_app functions
to simulate successful invocations. It verifies that the batch_invoke
function correctly returns the expected results for the given test data.
"""
with patch(
"agenta_backend.services.llm_apps_service.get_parameters_from_openapi",
new_callable=AsyncMock,
) as mock_get_parameters_from_openapi, patch(
"agenta_backend.services.llm_apps_service.invoke_app", new_callable=AsyncMock
) as mock_invoke_app, patch(
"asyncio.sleep", new_callable=AsyncMock
) as mock_sleep:
mock_get_parameters_from_openapi.return_value = [
{"name": "param1", "type": "input"},
{"name": "param2", "type": "input"},
]

# Mock the response of invoke_app to always succeed
def invoke_app_side_effect(uri, datapoint, parameters, openapi_parameters):
return InvokationResult(
result=Result(type="text", value="Success", error=None),
latency=0.1,
cost=0.01,
)

mock_invoke_app.side_effect = invoke_app_side_effect

uri = "http://example.com"
testset_data = [
{"id": 1, "param1": "value1", "param2": "value2"},
{"id": 2, "param1": "value1", "param2": "value2"},
]
parameters = {}
rate_limit_config = {
"batch_size": 10,
"max_retries": 3,
"retry_delay": 3,
"delay_between_batches": 5,
}

results = await batch_invoke(uri, testset_data, parameters, rate_limit_config)

assert len(results) == 2
assert results[0].result.type == "text"
assert results[0].result.value == "Success"
assert results[1].result.type == "text"
assert results[1].result.value == "Success"


@pytest.mark.asyncio
async def test_batch_invoke_retries_and_failure():
"""
Test the batch_invoke function with retries and eventual failure.
This test mocks the get_parameters_from_openapi and invoke_app functions
to simulate failures that trigger retries. It verifies that the batch_invoke
function correctly retries the specified number of times and returns an error
result after reaching the maximum retries.
"""
with patch(
"agenta_backend.services.llm_apps_service.get_parameters_from_openapi",
new_callable=AsyncMock,
) as mock_get_parameters_from_openapi, patch(
"agenta_backend.services.llm_apps_service.invoke_app", new_callable=AsyncMock
) as mock_invoke_app, patch(
"asyncio.sleep", new_callable=AsyncMock
) as mock_sleep:
mock_get_parameters_from_openapi.return_value = [
{"name": "param1", "type": "input"},
{"name": "param2", "type": "input"},
]

# Mock the response of invoke_app to always fail
def invoke_app_side_effect(uri, datapoint, parameters, openapi_parameters):
raise aiohttp.ClientError("Test Error")

mock_invoke_app.side_effect = invoke_app_side_effect

uri = "http://example.com"
testset_data = [
{"id": 1, "param1": "value1", "param2": "value2"},
{"id": 2, "param1": "value1", "param2": "value2"},
]
parameters = {}
rate_limit_config = {
"batch_size": 10,
"max_retries": 3,
"retry_delay": 3,
"delay_between_batches": 5,
}

results = await batch_invoke(uri, testset_data, parameters, rate_limit_config)

assert len(results) == 2
assert results[0].result.type == "error"
assert results[0].result.error.message == "Max retries reached"
assert results[1].result.type == "error"
assert results[1].result.error.message == "Max retries reached"


@pytest.mark.asyncio
async def test_batch_invoke_generic_exception():
"""
Test the batch_invoke function with a generic exception.
This test mocks the get_parameters_from_openapi and invoke_app functions
to simulate a generic exception during invocation. It verifies that the
batch_invoke function correctly handles the exception and returns an error
result with the appropriate error message.
"""
with patch(
"agenta_backend.services.llm_apps_service.get_parameters_from_openapi",
new_callable=AsyncMock,
) as mock_get_parameters_from_openapi, patch(
"agenta_backend.services.llm_apps_service.invoke_app", new_callable=AsyncMock
) as mock_invoke_app, patch(
"asyncio.sleep", new_callable=AsyncMock
) as mock_sleep:
mock_get_parameters_from_openapi.return_value = [
{"name": "param1", "type": "input"},
{"name": "param2", "type": "input"},
]

# Mock the response of invoke_app to raise a generic exception
def invoke_app_side_effect(uri, datapoint, parameters, openapi_parameters):
raise Exception("Generic Error")

mock_invoke_app.side_effect = invoke_app_side_effect

uri = "http://example.com"
testset_data = [{"id": 1, "param1": "value1", "param2": "value2"}]
parameters = {}
rate_limit_config = {
"batch_size": 1,
"max_retries": 3,
"retry_delay": 1,
"delay_between_batches": 1,
}

results = await batch_invoke(uri, testset_data, parameters, rate_limit_config)

assert len(results) == 1
assert results[0].result.type == "error"
assert results[0].result.error.message == "Max retries reached"

0 comments on commit ac064cf

Please sign in to comment.