From 7e24866de9515ec57c6a59077cf0526be1410386 Mon Sep 17 00:00:00 2001 From: aakrem Date: Thu, 9 May 2024 17:58:59 +0200 Subject: [PATCH] fix evaluators tests --- .../tests/unit/test_evaluators.py | 80 +++++++++++++------ 1 file changed, 55 insertions(+), 25 deletions(-) diff --git a/agenta-backend/agenta_backend/tests/unit/test_evaluators.py b/agenta-backend/agenta_backend/tests/unit/test_evaluators.py index 0b427c19b4..bda05163f6 100644 --- a/agenta-backend/agenta_backend/tests/unit/test_evaluators.py +++ b/agenta-backend/agenta_backend/tests/unit/test_evaluators.py @@ -14,10 +14,42 @@ @pytest.mark.parametrize( "output, settings_values, expected", [ - ("Hello world", {"prefix": "He", "case_sensitive": True}, True), - ("hello world", {"prefix": "He", "case_sensitive": False}, True), - ("Hello world", {"prefix": "he", "case_sensitive": False}, True), - ("Hello world", {"prefix": "world", "case_sensitive": True}, False), + ( + "Hello world", + { + "prefix": "He", + "case_sensitive": True, + "correct_answer_keys": "correct_answer", + }, + True, + ), + ( + "hello world", + { + "prefix": "He", + "case_sensitive": False, + "correct_answer_keys": "correct_answer", + }, + True, + ), + ( + "Hello world", + { + "prefix": "he", + "case_sensitive": False, + "correct_answer_keys": "correct_answer", + }, + True, + ), + ( + "Hello world", + { + "prefix": "world", + "case_sensitive": True, + "correct_answer_keys": "correct_answer", + }, + False, + ), ], ) def test_auto_starts_with(output, settings_values, expected): @@ -25,7 +57,6 @@ def test_auto_starts_with(output, settings_values, expected): inputs={}, output=output, data_point={}, - correct_answer_key="", app_params={}, settings_values=settings_values, lm_providers_keys={}, @@ -50,7 +81,6 @@ def test_auto_ends_with(output, suffix, case_sensitive, expected): {}, output, {}, - "correct_answer", {}, {"suffix": suffix, "case_sensitive": case_sensitive}, {}, @@ -74,7 +104,6 @@ def test_auto_contains(output, substring, case_sensitive, expected): {}, output, {}, - "correct_answer", {}, {"substring": substring, "case_sensitive": case_sensitive}, {}, @@ -99,7 +128,6 @@ def test_auto_contains_any(output, substrings, case_sensitive, expected): {}, output, {}, - "correct_answer", {}, {"substrings": substrings, "case_sensitive": case_sensitive}, {}, @@ -124,7 +152,6 @@ def test_auto_contains_all(output, substrings, case_sensitive, expected): {}, output, {}, - "correct_answer", {}, {"substrings": substrings, "case_sensitive": case_sensitive}, {}, @@ -143,53 +170,56 @@ def test_auto_contains_all(output, substrings, case_sensitive, expected): ], ) def test_auto_contains_json(output, expected): - result = auto_contains_json({}, output, {}, "", {}, {}, {}) + result = auto_contains_json({}, output, {}, {}, {}, {}) assert result.value == expected @pytest.mark.parametrize( - "output, data_point, correct_answer_key, settings_values, expected", + "output, data_point, settings_values, expected", [ ( "hello world", {"correct_answer": "hello world"}, - "correct_answer", - {"threshold": 5}, + {"threshold": 5, "correct_answer_keys": "correct_answer"}, True, ), ( "hello world", {"correct_answer": "hola mundo"}, - "correct_answer", - {"threshold": 5}, + {"threshold": 5, "correct_answer_keys": "correct_answer"}, False, ), ( "hello world", {"correct_answer": "hello world!"}, - "correct_answer", - {"threshold": 2}, + {"threshold": 2, "correct_answer_keys": "correct_answer"}, True, ), ( "hello world", {"correct_answer": "hello wor"}, - "correct_answer", - {"threshold": 10}, + {"threshold": 10, "correct_answer_keys": "correct_answer"}, True, ), - ("hello world", {"correct_answer": "hello worl"}, "correct_answer", {}, 1), - ("hello world", {"correct_answer": "helo world"}, "correct_answer", {}, 1), + ( + "hello world", + {"correct_answer": "hello worl"}, + {"correct_answer_keys": "correct_answer"}, + 1, + ), + ( + "hello world", + {"correct_answer": "helo world"}, + {"correct_answer_keys": "correct_answer"}, + 1, + ), ], ) -def test_auto_levenshtein_distance( - output, data_point, correct_answer_key, settings_values, expected -): +def test_auto_levenshtein_distance(output, data_point, settings_values, expected): result = auto_levenshtein_distance( inputs={}, output=output, data_point=data_point, - correct_answer_key=correct_answer_key, app_params={}, settings_values=settings_values, lm_providers_keys={},