diff --git a/agenta-backend/agenta_backend/main.py b/agenta-backend/agenta_backend/main.py
index c556824a1e..dc0885e483 100644
--- a/agenta-backend/agenta_backend/main.py
+++ b/agenta-backend/agenta_backend/main.py
@@ -47,12 +47,12 @@
 @asynccontextmanager
 async def lifespan(application: FastAPI, cache=True):
     """
+    Lifespan initializes the database engine and loads the default LLM templates.
 
     Args:
         application: FastAPI application.
         cache: A boolean value that indicates whether to use the cached data or not.
     """
-    # initialize the database
     await DBEngine().init_db()
     await templates_manager.update_and_sync_templates(cache=cache)
     yield
diff --git a/agenta-backend/agenta_backend/resources/evaluators/evaluators.py b/agenta-backend/agenta_backend/resources/evaluators/evaluators.py
index 8b16064d2d..cf00af533d 100644
--- a/agenta-backend/agenta_backend/resources/evaluators/evaluators.py
+++ b/agenta-backend/agenta_backend/resources/evaluators/evaluators.py
@@ -8,6 +8,16 @@
             "description": "Settings for the Exact Match evaluator",
         },
     },
+    {
+        "name": "Contains JSON",
+        "key": "auto_contains_json",
+        "direct_use": True,
+        "settings_template": {
+            "label": "Contains JSON Settings",
+            "description": "Checks if the output contains valid JSON.",
+        },
+        "description": "Contains JSON evaluator checks if the output contains valid JSON.",
+    },
     {
         "name": "Similarity Match",
         "key": "auto_similarity_match",
@@ -115,6 +125,109 @@
             "description": "Settings for single model testing configurations",
         },
     },
+    {
+        "name": "Starts With",
+        "key": "auto_starts_with",
+        "direct_use": False,
+        "settings_template": {
+            "label": "Starts With Settings",
+            "description": "Checks if the output starts with the specified prefix.",
+            "prefix": {
+                "label": "Prefix",
+                "type": "string",
+            },
+            "case_sensitive": {
+                "label": "Case Sensitive",
+                "type": "boolean",
+                "default": True,
+            },
+        },
+        "description": "Starts With evaluator checks if the output starts with a specified prefix, considering case sensitivity based on the settings.",
+    },
+    {
+        "name": "Ends With",
+        "key": "auto_ends_with",
+        "direct_use": False,
+        "settings_template": {
+            "label": "Ends With Settings",
+            "description": "Checks if the output ends with the specified suffix.",
+            "case_sensitive": {
+                "label": "Case Sensitive",
+                "type": "boolean",
+                "default": True,
+                "description": "If the evaluation should be case sensitive.",
+            },
+            "suffix": {
+                "label": "Suffix",
+                "type": "string",
+                "description": "The string to match at the end of the output.",
+            },
+        },
+        "description": "Ends With evaluator checks if the output ends with a specified suffix, considering case sensitivity based on the settings.",
+    },
+    {
+        "name": "Contains",
+        "key": "auto_contains",
+        "direct_use": False,
+        "settings_template": {
+            "label": "Contains Settings",
+            "description": "Checks if the output contains the specified substring.",
+            "case_sensitive": {
+                "label": "Case Sensitive",
+                "type": "boolean",
+                "default": True,
+                "description": "If the evaluation should be case sensitive.",
+            },
+            "substring": {
+                "label": "Substring",
+                "type": "string",
+                "description": "The string to check if it is contained in the output.",
+            },
+        },
+        "description": "Contains evaluator checks if the output contains a specified substring, considering case sensitivity based on the settings.",
+    },
+    {
+        "name": "Contains Any",
+        "key": "auto_contains_any",
+        "direct_use": False,
+        "settings_template": {
+            "label": "Contains Any Settings",
+            "description": "Checks if the output contains any of the specified substrings.",
+            "case_sensitive": {
+                "label": "Case Sensitive",
+                "type": "boolean",
+                "default": True,
+                "description": "If the evaluation should be case sensitive.",
+            },
+            "substrings": {
+                "label": "Substrings",
+                "type": "string",
+                "description": "Provide a comma-separated list of strings to check if any is contained in the output.",
+            },
+        },
+        "description": "Contains Any evaluator checks if the output contains any of the specified substrings from a comma-separated list, considering case sensitivity based on the settings.",
+    },
+    {
+        "name": "Contains All",
+        "key": "auto_contains_all",
+        "direct_use": False,
+        "settings_template": {
+            "label": "Contains All Settings",
+            "description": "Checks if the output contains all of the specified substrings.",
+            "case_sensitive": {
+                "label": "Case Sensitive",
+                "type": "boolean",
+                "default": True,
+                "description": "If the evaluation should be case sensitive.",
+            },
+            "substrings": {
+                "label": "Substrings",
+                "type": "string",
+                "description": "Provide a comma-separated list of strings to check if all are contained in the output.",
+            },
+        },
+        "description": "Contains All evaluator checks if the output contains all of the specified substrings from a comma-separated list, considering case sensitivity based on the settings.",
+    },
 ]
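A note on how these templates are consumed: the field keys under `settings_template` (`prefix`, `case_sensitive`, `substrings`, ...) reach the matching evaluator function as a plain `settings_values` dict. A minimal sketch using the Contains Any evaluator added in the next file; the harness values here are illustrative, not part of the patch:

```python
# Minimal sketch: settings_template keys map 1:1 to the settings_values dict
# that the evaluator function receives. Values below are illustrative.
from agenta_backend.services.evaluators_service import auto_contains_any

settings_values = {
    "substrings": "price, total, currency",  # comma-separated, per the template
    "case_sensitive": False,                 # overrides the template default (True)
}

result = auto_contains_any(
    inputs={},              # unused by this evaluator
    output="The Total comes to 42 USD.",
    correct_answer="",      # unused by this evaluator
    app_params={},
    settings_values=settings_values,
    lm_providers_keys={},
)
print(result.type, result.value)  # bool True, since "total" matches case-insensitively
```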
Testing Settings", + "description": "Checks if the output contains any of the specified substrings.", + "case_sensitive": { + "label": "Case Sensitive", + "type": "boolean", + "default": True, + "description": "If the evaluation should be case sensitive.", + }, + "substrings": { + "label": "substrings", + "type": "string", + "description": "Provide a comma-separated list of strings to check if any is contained in the output.", + }, + }, + "description": "Contains Any evaluator checks if the output contains any of the specified substrings from a comma-separated list, considering case sensitivity based on the settings.", + }, + { + "name": "Contains All", + "key": "auto_contains_all", + "direct_use": False, + "settings_template": { + "label": "Single Model Testing Settings", + "description": "Checks if the output contains all of the specified substrings.", + "case_sensitive": { + "label": "Case Sensitive", + "type": "boolean", + "default": True, + "description": "If the evaluation should be case sensitive.", + }, + "substrings": { + "label": "substrings", + "type": "string", + "description": "Provide a comma-separated list of strings to check if all are contained in the output.", + }, + }, + "description": "Contains All evaluator checks if the output contains all of the specified substrings from a comma-separated list, considering case sensitivity based on the settings.", + }, ] diff --git a/agenta-backend/agenta_backend/services/evaluators_service.py b/agenta-backend/agenta_backend/services/evaluators_service.py index 111d8cfd87..bab421c9c5 100644 --- a/agenta-backend/agenta_backend/services/evaluators_service.py +++ b/agenta-backend/agenta_backend/services/evaluators_service.py @@ -249,6 +249,178 @@ def auto_ai_critique( ) +def auto_starts_with( + inputs: Dict[str, Any], + output: str, + correct_answer: str, + app_params: Dict[str, Any], + settings_values: Dict[str, Any], + lm_providers_keys: Dict[str, Any], +) -> Result: + try: + prefix = settings_values.get("prefix", "") + case_sensitive = settings_values.get("case_sensitive", True) + + if not case_sensitive: + output = output.lower() + prefix = prefix.lower() + + result = Result(type="bool", value=output.startswith(prefix)) + return result + except Exception as e: + return Result( + type="error", + value=None, + error=Error( + message="Error during Starts With evaluation", stacktrace=str(e) + ), + ) + + +def auto_ends_with( + inputs: Dict[str, Any], + output: str, + correct_answer: str, + app_params: Dict[str, Any], + settings_values: Dict[str, Any], + lm_providers_keys: Dict[str, Any], +) -> Result: + try: + suffix = settings_values.get("suffix", "") + case_sensitive = settings_values.get("case_sensitive", True) + + if not case_sensitive: + output = output.lower() + suffix = suffix.lower() + + result = Result(type="bool", value=output.endswith(suffix)) + return result + except Exception as e: + return Result( + type="error", + value=None, + error=Error(message="Error during Ends With evaluation", stacktrace=str(e)), + ) + + +def auto_contains( + inputs: Dict[str, Any], + output: str, + correct_answer: str, + app_params: Dict[str, Any], + settings_values: Dict[str, Any], + lm_providers_keys: Dict[str, Any], +) -> Result: + try: + substring = settings_values.get("substring", "") + case_sensitive = settings_values.get("case_sensitive", True) + + if not case_sensitive: + output = output.lower() + substring = substring.lower() + + result = Result(type="bool", value=substring in output) + return result + except Exception as e: + return 
diff --git a/agenta-backend/agenta_backend/tasks/evaluations.py b/agenta-backend/agenta_backend/tasks/evaluations.py
index 54069c5527..1882f35564 100644
--- a/agenta-backend/agenta_backend/tasks/evaluations.py
+++ b/agenta-backend/agenta_backend/tasks/evaluations.py
@@ -349,11 +349,19 @@ async def aggregate_evaluator_results(
         "field_match_test",
         "auto_webhook_test",
         "auto_custom_code_run",
+        "auto_starts_with",
+        "auto_ends_with",
+        "auto_contains",
+        "auto_contains_any",
+        "auto_contains_all",
+        "auto_contains_json",
     ]:
         result = aggregation_service.aggregate_float(results)
     else:
-        raise Exception(f"Evaluator {evaluator_key} aggregation does not exist")
+        result = Result(
+            type="error", value=None, error=Error(message=f"Aggregation for evaluator {evaluator_key} failed")
+        )
 
     evaluator_config = await fetch_evaluator_config(config_id)
     aggregated_result = AggregatedResult(
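Routing the new boolean evaluators to `aggregate_float` suggests their Results are averaged as 0/1 values, making the aggregate the fraction of passing test cases. A standalone sketch of that assumption; `aggregate_float`'s actual implementation is not part of this diff:

```python
# Sketch only: assumes aggregation_service.aggregate_float averages numeric
# Result values, so the booleans produced by the new evaluators aggregate
# to a pass rate. The real implementation is not shown in this patch.
from typing import List

def aggregate_float_sketch(values: List[bool]) -> float:
    """Average boolean outcomes as 0/1 floats; 0.0 for an empty run."""
    if not values:
        return 0.0
    return sum(float(v) for v in values) / len(values)

print(aggregate_float_sketch([True, True, False, True]))  # 0.75, i.e. 3 of 4 cases passed
```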
diff --git a/agenta-backend/agenta_backend/tests/unit/test_evaluators.py b/agenta-backend/agenta_backend/tests/unit/test_evaluators.py
new file mode 100644
index 0000000000..894233d7d4
--- /dev/null
+++ b/agenta-backend/agenta_backend/tests/unit/test_evaluators.py
@@ -0,0 +1,131 @@
+import pytest
+
+from agenta_backend.services.evaluators_service import (
+    auto_starts_with,
+    auto_ends_with,
+    auto_contains,
+    auto_contains_any,
+    auto_contains_all,
+    auto_contains_json,
+)
+
+
+@pytest.mark.parametrize(
+    "output, prefix, case_sensitive, expected",
+    [
+        ("Hello world", "He", True, True),
+        ("hello world", "He", False, True),
+        ("Hello world", "he", False, True),
+        ("Hello world", "world", True, False),
+    ],
+)
+def test_auto_starts_with(output, prefix, case_sensitive, expected):
+    result = auto_starts_with(
+        {}, output, "", {}, {"prefix": prefix, "case_sensitive": case_sensitive}, {}
+    )
+    assert result.value == expected
+
+
+# Test for auto_ends_with
+
+
+@pytest.mark.parametrize(
+    "output, suffix, case_sensitive, expected",
+    [
+        ("Hello world", "world", True, True),
+        ("hello world", "World", False, True),
+        ("Hello world", "World", True, False),
+        ("Hello world", "Hello", True, False),
+    ],
+)
+def test_auto_ends_with(output, suffix, case_sensitive, expected):
+    result = auto_ends_with(
+        {}, output, "", {}, {"suffix": suffix, "case_sensitive": case_sensitive}, {}
+    )
+    assert result.value == expected
+
+
+# Test for auto_contains
+
+
+@pytest.mark.parametrize(
+    "output, substring, case_sensitive, expected",
+    [
+        ("Hello world", "lo wo", True, True),
+        ("Hello world", "LO WO", False, True),
+        ("Hello world", "abc", True, False),
+    ],
+)
+def test_auto_contains(output, substring, case_sensitive, expected):
+    result = auto_contains(
+        {},
+        output,
+        "",
+        {},
+        {"substring": substring, "case_sensitive": case_sensitive},
+        {},
+    )
+    assert result.value == expected
+
+
+# Test for auto_contains_any
+
+
+@pytest.mark.parametrize(
+    "output, substrings, case_sensitive, expected",
+    [
+        ("Hello world", "hello,world", True, True),
+        ("Hello world", "world,universe", True, True),
+        ("Hello world", "world,universe", False, True),
+        ("Hello world", "abc,xyz", True, False),
+    ],
+)
+def test_auto_contains_any(output, substrings, case_sensitive, expected):
+    result = auto_contains_any(
+        {},
+        output,
+        "",
+        {},
+        {"substrings": substrings, "case_sensitive": case_sensitive},
+        {},
+    )
+    assert result.value == expected
+
+
+# Test for auto_contains_all
+
+
+@pytest.mark.parametrize(
+    "output, substrings, case_sensitive, expected",
+    [
+        ("Hello world", "hello,world", True, False),
+        ("Hello world", "Hello,world", True, True),
+        ("Hello world", "hello,world", False, True),
+        ("Hello world", "world,universe", True, False),
+    ],
+)
+def test_auto_contains_all(output, substrings, case_sensitive, expected):
+    result = auto_contains_all(
+        {},
+        output,
+        "",
+        {},
+        {"substrings": substrings, "case_sensitive": case_sensitive},
+        {},
+    )
+    assert result.value == expected
+
+
+# Test for auto_contains_json
+@pytest.mark.parametrize(
+    "output, expected",
+    [
+        ('Some random text {"key": "value"} more text', True),
+        ("No JSON here!", False),
+        ("{Malformed JSON, nope!}", False),
+        ('{"valid": "json", "number": 123}', True),
+    ],
+)
+def test_auto_contains_json(output, expected):
+    result = auto_contains_json({}, output, "", {}, {}, {})
+    assert result.value == expected
diff --git a/agenta-backend/agenta_backend/tests/variants_main_router/test_variant_evaluators_router.py b/agenta-backend/agenta_backend/tests/variants_main_router/test_variant_evaluators_router.py
index 51e47fe295..e45b17fbda 100644
--- a/agenta-backend/agenta_backend/tests/variants_main_router/test_variant_evaluators_router.py
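The substrings cases above ("hello,world", "world,universe") exercise the comma-split-and-strip parsing in the service; one consequence is that a target string containing a literal comma can never match as a whole. For illustration:

```python
# How the service parses the "substrings" setting: split on commas, strip whitespace.
substrings_str = "hello, world , universe"
substrings = [s.strip() for s in substrings_str.split(",")]
print(substrings)  # ['hello', 'world', 'universe']

# Caveat: a target containing a literal comma is split apart, so it cannot
# be matched as a single substring.
print([s.strip() for s in "salt, and pepper".split(",")])  # ['salt', 'and pepper']
```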
+++ b/agenta-backend/agenta_backend/tests/variants_main_router/test_variant_evaluators_router.py
@@ -49,7 +49,7 @@ async def test_get_evaluators_endpoint():
         timeout=timeout,
     )
     assert response.status_code == 200
-    assert len(response.json()) == 9  # currently we have 9 evaluators
+    assert len(response.json()) > 0
 
 
 @pytest.mark.asyncio
@@ -237,7 +237,7 @@ async def test_fetch_evaluation_results():
 
     assert response.status_code == 200
     assert response_data["evaluation_id"] == str(evaluation.id)
-    assert len(response_data["results"]) == 6
+    assert len(response_data["results"]) == 7
 
 
 @pytest.mark.asyncio
diff --git a/agenta-web/cypress/e2e/eval.evaluators.cy.ts b/agenta-web/cypress/e2e/eval.evaluators.cy.ts
index 7f33ea1331..0708d157d5 100644
--- a/agenta-web/cypress/e2e/eval.evaluators.cy.ts
+++ b/agenta-web/cypress/e2e/eval.evaluators.cy.ts
@@ -17,7 +17,7 @@ describe("Evaluators CRUD Operations Test", function () {
     })
 
     it("Should successfully create an Evaluator", () => {
-        cy.get('[data-cy="evaluator-card"]').should("have.length", 1)
+        cy.get('[data-cy="evaluator-card"]').should("exist")
         cy.get(".ant-space > :nth-child(2) > .ant-btn").click()
         cy.get('[data-cy="new-evaluator-modal"]').should("exist")
         cy.get('[data-cy^="select-new-evaluator"]').eq(0).click()
diff --git a/agenta-web/src/media/bracket-curly.png b/agenta-web/src/media/bracket-curly.png
new file mode 100644
index 0000000000..9f82842021
Binary files /dev/null and b/agenta-web/src/media/bracket-curly.png differ
diff --git a/agenta-web/src/services/evaluations/index.ts b/agenta-web/src/services/evaluations/index.ts
index 54d7b3fa0f..ae7b3b506a 100644
--- a/agenta-web/src/services/evaluations/index.ts
+++ b/agenta-web/src/services/evaluations/index.ts
@@ -20,6 +20,7 @@ import regexImg from "@/media/programming.png"
 import webhookImg from "@/media/link.png"
 import aiImg from "@/media/artificial-intelligence.png"
 import codeImg from "@/media/browser.png"
+import bracketCurlyImg from "@/media/bracket-curly.png"
 import dayjs from "dayjs"
 import {loadTestset} from "@/lib/services/api"
 import {runningStatuses} from "@/components/pages/evaluations/cellRenderers/cellRenderers"
@@ -40,6 +41,7 @@ const evaluatorIconsMap = {
     auto_webhook_test: webhookImg,
     auto_ai_critique: aiImg,
     auto_custom_code_run: codeImg,
+    auto_contains_json: bracketCurlyImg,
 }
 
 //Evaluators