From 14c8ee2aad1d54e6543b6cc17b9fe338586eab56 Mon Sep 17 00:00:00 2001
From: aakrem
Date: Wed, 15 May 2024 19:16:03 +0200
Subject: [PATCH] small refactor for correct answers logic

---
 .../agenta_backend/tasks/evaluations.py | 54 +++++++++----------
 1 file changed, 26 insertions(+), 28 deletions(-)

diff --git a/agenta-backend/agenta_backend/tasks/evaluations.py b/agenta-backend/agenta_backend/tasks/evaluations.py
index 8807c0d70f..540009922f 100644
--- a/agenta-backend/agenta_backend/tasks/evaluations.py
+++ b/agenta-backend/agenta_backend/tasks/evaluations.py
@@ -178,11 +178,6 @@ def evaluate(
         )
         for evaluator_config_db in evaluator_config_dbs
     ]
-    correct_answer = (
-        data_point[correct_answer_column]
-        if correct_answer_column in data_point
-        else ""
-    )
 
     loop.run_until_complete(
         create_new_evaluation_scenario(
@@ -195,7 +190,7 @@ def evaluate(
             inputs=inputs,
             is_pinned=False,
             note="",
-            correct_answer=correct_answer,
+            correct_answer=None,
             outputs=[
                 EvaluationScenarioOutputDB(
                     result=Result(
@@ -215,30 +210,13 @@ def evaluate(
 
     # 3. We evaluate
     evaluators_results: List[EvaluationScenarioResult] = []
-    correct_answers = []
 
-    # Loop over each evaluator configuration to gather the correct answers
+    # Loop over each evaluator configuration to gather the correct answers and evaluate
+    all_correct_answers = []
     for evaluator_config_db in evaluator_config_dbs:
-        correct_answer_keys = evaluator_config_db.settings_values.get(
-            "correct_answer_keys"
-        )
-        if not correct_answer_keys:
-            return {
-                "type": "error",
-                "value": None,
-                "error": "No correct answer keys provided.",
-            }
-
-        ## In case one evaluator has multiple correct answers
-        correct_answer_keys_list = [
-            key.strip() for key in correct_answer_keys.split(",")
-        ]
+        correct_answers = parse_correct_answers(evaluator_config_db, data_point)
 
-        for key in correct_answer_keys_list:
-            correct_answer_value = data_point.get(key, "")
-
-            correct_answer = CorrectAnswer(key=key, value=correct_answer_value)
-            correct_answers.append(correct_answer)
+        all_correct_answers.extend(correct_answers)
 
         logger.debug(f"Evaluating with evaluator: {evaluator_config_db}")
 
@@ -276,7 +254,7 @@ def evaluate(
             inputs=inputs,
             is_pinned=False,
             note="",
-            correct_answers=correct_answers,
+            correct_answers=all_correct_answers,
             outputs=[
                 EvaluationScenarioOutputDB(
                     result=Result(type="text", value=app_output.result.value),
@@ -439,3 +417,23 @@ def get_app_inputs(app_variant_parameters, openapi_parameters) -> List[Dict[str,
         elif param["type"] == "file_url":
             list_inputs.append({"name": param["name"], "type": "file_url"})
     return list_inputs
+
+
+def parse_correct_answers(evaluator_config_db, data_point) -> List[CorrectAnswer]:
+    """
+    Extracts correct answers based on the evaluator configuration.
+    """
+    correct_answers = []
+    correct_answer_keys = evaluator_config_db.settings_values.get("correct_answer_keys")
+
+    if not correct_answer_keys:
+        return []
+
+    # In case one evaluator has multiple correct answers
+    correct_answer_keys_list = [key.strip() for key in correct_answer_keys.split(",")]
+
+    for key in correct_answer_keys_list:
+        correct_answer_value = data_point.get(key, "")
+        correct_answers.append(CorrectAnswer(key=key, value=correct_answer_value))
+
+    return correct_answers
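
Reviewer note: below is a minimal, self-contained sketch of how the extracted helper behaves, runnable outside the codebase. CorrectAnswer and EvaluatorConfigStub here are hypothetical stand-ins (plain dataclasses modeling only the attributes the helper touches), not the project's actual models; the helper body itself mirrors the patch.

from dataclasses import dataclass, field
from typing import Dict, List


# Hypothetical stand-ins for the real agenta_backend models; only the
# attributes used by parse_correct_answers are modeled here.
@dataclass
class CorrectAnswer:
    key: str
    value: str


@dataclass
class EvaluatorConfigStub:
    settings_values: Dict[str, str] = field(default_factory=dict)


def parse_correct_answers(evaluator_config_db, data_point) -> List[CorrectAnswer]:
    """Extracts correct answers based on the evaluator configuration (mirrors the patch)."""
    correct_answers = []
    correct_answer_keys = evaluator_config_db.settings_values.get("correct_answer_keys")

    if not correct_answer_keys:
        return []

    # In case one evaluator has multiple correct answers
    correct_answer_keys_list = [key.strip() for key in correct_answer_keys.split(",")]

    for key in correct_answer_keys_list:
        correct_answer_value = data_point.get(key, "")
        correct_answers.append(CorrectAnswer(key=key, value=correct_answer_value))

    return correct_answers


if __name__ == "__main__":
    # "expected"/"reference" are illustrative key names, not fixtures from the repo.
    config = EvaluatorConfigStub(settings_values={"correct_answer_keys": "expected, reference"})
    data_point = {"expected": "42", "reference": "forty-two", "input": "6 * 7"}

    for answer in parse_correct_answers(config, data_point):
        print(answer)
    # CorrectAnswer(key='expected', value='42')
    # CorrectAnswer(key='reference', value='forty-two')

    # Behavior change visible in the diff: a config with no correct_answer_keys
    # now yields [] instead of the pre-refactor error dict, so evaluation
    # proceeds rather than short-circuiting with an error result.
    print(parse_correct_answers(EvaluatorConfigStub(), data_point))  # []

One design observation grounded in the diff: the refactor also changes error handling, since the old inline code returned {"type": "error", ...} when correct_answer_keys was missing, while the new helper silently returns an empty list; whether that fallback is intended is worth confirming in review.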