From b5300df25ca6f58d8524ae477a1eeca5693ef88d Mon Sep 17 00:00:00 2001 From: aakrem Date: Fri, 17 May 2024 22:09:04 +0200 Subject: [PATCH] add default correct answer in case its not provided --- .../resources/evaluators/evaluators.py | 15 +++++++++++++++ .../agenta_backend/services/evaluator_manager.py | 11 +++++++++++ 2 files changed, 26 insertions(+) diff --git a/agenta-backend/agenta_backend/resources/evaluators/evaluators.py b/agenta-backend/agenta_backend/resources/evaluators/evaluators.py index 97a24600ed..976e1f55dc 100644 --- a/agenta-backend/agenta_backend/resources/evaluators/evaluators.py +++ b/agenta-backend/agenta_backend/resources/evaluators/evaluators.py @@ -267,3 +267,18 @@ def get_all_evaluators(): List[dict]: A list of evaluator dictionaries. """ return evaluators + + +def get_evaluator_by_key(key: str): + """ + Returns an evaluator with the specified key + + Args: + key (str): The key of the evaluator to retrieve + + Returns: + dict or None: The evaluator dictionary if found, otherwise None + """ + return next( + (evaluator for evaluator in evaluators if evaluator["key"] == key), None + ) diff --git a/agenta-backend/agenta_backend/services/evaluator_manager.py b/agenta-backend/agenta_backend/services/evaluator_manager.py index c86a4ef946..de262987cc 100644 --- a/agenta-backend/agenta_backend/services/evaluator_manager.py +++ b/agenta-backend/agenta_backend/services/evaluator_manager.py @@ -17,6 +17,7 @@ from agenta_backend.models.converters import evaluator_config_db_to_pydantic from agenta_backend.resources.evaluators.evaluators import get_all_evaluators from agenta_backend.models.api.evaluation_model import Evaluator, EvaluatorConfig +from agenta_backend.resources.evaluators import evaluators def get_evaluators() -> Optional[List[Evaluator]]: @@ -79,6 +80,16 @@ async def create_evaluator_config( EvaluatorConfigDB: The newly created evaluator configuration object. """ app = await db_manager.fetch_app_by_id(app_id) + evaluator_config = evaluators.get_evaluator_by_key(evaluator_key) + + if evaluator_config is not None: + if "correct_answer_keys" in evaluator_config.get("settings_template", {}): + if settings_values is None: + settings_values = {} + settings_values["correct_answer_keys"] = evaluator_config[ + "settings_template" + ]["correct_answer_keys"]["default"] + evaluator_config = await db_manager.create_evaluator_config( app=app, organization=app.organization if isCloudEE() else None, # noqa,