From 95b82ef2f2b2e5fd7feb6b66fd35abf8e1adae30 Mon Sep 17 00:00:00 2001 From: aakrem Date: Thu, 16 May 2024 09:23:21 +0200 Subject: [PATCH] convert correct_answer_keys to list --- .../resources/evaluators/evaluators.py | 14 ++++++------- .../agenta_backend/tasks/evaluations.py | 5 +---- .../tests/unit/test_evaluators.py | 20 +++++++++---------- 3 files changed, 18 insertions(+), 21 deletions(-) diff --git a/agenta-backend/agenta_backend/resources/evaluators/evaluators.py b/agenta-backend/agenta_backend/resources/evaluators/evaluators.py index 21e67eee86..97a24600ed 100644 --- a/agenta-backend/agenta_backend/resources/evaluators/evaluators.py +++ b/agenta-backend/agenta_backend/resources/evaluators/evaluators.py @@ -8,8 +8,8 @@ "description": "Settings for the Exact Match evaluator", "correct_answer_keys": { "label": "Correct Answer", - "default": "correct_answer", - "type": "string", + "default": ["correct_answer"], + "type": "array", }, }, "description": "Exact Match evaluator determines if the output exactly matches the specified correct answer, ensuring precise alignment with expected results.", @@ -40,8 +40,8 @@ }, "correct_answer_keys": { "label": "Correct Answer", - "default": "correct_answer", - "type": "string", + "default": ["correct_answer"], + "type": "array", }, }, "description": "Similarity Match evaluator checks if the generated answer is similar to the expected answer. You need to provide the similarity threshold. It uses the Jaccard similarity to compare the answers.", @@ -81,7 +81,7 @@ }, "correct_answer_keys": { "label": "Correct Answer", - "default": "correct_answer", + "default": ["correct_answer"], "type": "string", }, }, @@ -130,7 +130,7 @@ }, "correct_answer_keys": { "label": "Correct Answer", - "default": "correct_answer", + "default": ["correct_answer"], "type": "string", }, }, @@ -250,7 +250,7 @@ "threshold": {"label": "Threshold", "type": "number", "required": False}, "correct_answer_keys": { "label": "Correct Answer", - "default": "correct_answer", + "default": ["correct_answer"], "type": "string", }, }, diff --git a/agenta-backend/agenta_backend/tasks/evaluations.py b/agenta-backend/agenta_backend/tasks/evaluations.py index 540009922f..381aa18e25 100644 --- a/agenta-backend/agenta_backend/tasks/evaluations.py +++ b/agenta-backend/agenta_backend/tasks/evaluations.py @@ -429,10 +429,7 @@ def parse_correct_answers(evaluator_config_db, data_point) -> List[CorrectAnswer if not correct_answer_keys: return [] - # In case one evaluator has multiple correct answers - correct_answer_keys_list = [key.strip() for key in correct_answer_keys.split(",")] - - for key in correct_answer_keys_list: + for key in correct_answer_keys: correct_answer_value = data_point.get(key, "") correct_answers.append(CorrectAnswer(key=key, value=correct_answer_value)) diff --git a/agenta-backend/agenta_backend/tests/unit/test_evaluators.py b/agenta-backend/agenta_backend/tests/unit/test_evaluators.py index bda05163f6..a8e5e58391 100644 --- a/agenta-backend/agenta_backend/tests/unit/test_evaluators.py +++ b/agenta-backend/agenta_backend/tests/unit/test_evaluators.py @@ -19,7 +19,7 @@ { "prefix": "He", "case_sensitive": True, - "correct_answer_keys": "correct_answer", + "correct_answer_keys": ["correct_answer"], }, True, ), @@ -28,7 +28,7 @@ { "prefix": "He", "case_sensitive": False, - "correct_answer_keys": "correct_answer", + "correct_answer_keys": ["correct_answer"], }, True, ), @@ -37,7 +37,7 @@ { "prefix": "he", "case_sensitive": False, - "correct_answer_keys": "correct_answer", + "correct_answer_keys": ["correct_answer"], }, True, ), @@ -46,7 +46,7 @@ { "prefix": "world", "case_sensitive": True, - "correct_answer_keys": "correct_answer", + "correct_answer_keys": ["correct_answer"], }, False, ), @@ -180,37 +180,37 @@ def test_auto_contains_json(output, expected): ( "hello world", {"correct_answer": "hello world"}, - {"threshold": 5, "correct_answer_keys": "correct_answer"}, + {"threshold": 5, "correct_answer_keys": ["correct_answer"]}, True, ), ( "hello world", {"correct_answer": "hola mundo"}, - {"threshold": 5, "correct_answer_keys": "correct_answer"}, + {"threshold": 5, "correct_answer_keys": ["correct_answer"]}, False, ), ( "hello world", {"correct_answer": "hello world!"}, - {"threshold": 2, "correct_answer_keys": "correct_answer"}, + {"threshold": 2, "correct_answer_keys": ["correct_answer"]}, True, ), ( "hello world", {"correct_answer": "hello wor"}, - {"threshold": 10, "correct_answer_keys": "correct_answer"}, + {"threshold": 10, "correct_answer_keys": ["correct_answer"]}, True, ), ( "hello world", {"correct_answer": "hello worl"}, - {"correct_answer_keys": "correct_answer"}, + {"correct_answer_keys": ["correct_answer"]}, 1, ), ( "hello world", {"correct_answer": "helo world"}, - {"correct_answer_keys": "correct_answer"}, + {"correct_answer_keys": ["correct_answer"]}, 1, ), ],