diff --git a/agenta-backend/agenta_backend/services/aggregation_service.py b/agenta-backend/agenta_backend/services/aggregation_service.py
new file mode 100644
index 0000000000..d657bdf4de
--- /dev/null
+++ b/agenta-backend/agenta_backend/services/aggregation_service.py
@@ -0,0 +1,75 @@
+import re
+import traceback
+from typing import List
+
+from agenta_backend.models.db_models import Result, Error
+
+
+def aggregate_ai_critique(results: List[Result]) -> Result:
+    """Aggregates the results for the ai critique evaluation.
+
+    Args:
+        results (List[Result]): list of result objects
+
+    Returns:
+        Result: aggregated result
+    """
+
+    numeric_scores = []
+    for result in results:
+        # Extract the first number found in the result value
+        match = re.search(r"\d+", result.value)
+        if match:
+            try:
+                score = int(match.group())
+                numeric_scores.append(score)
+            except ValueError:
+                # Ignore if the extracted value is not an integer
+                continue
+
+    # Calculate the average of numeric scores if any are present
+    average_value = (
+        sum(numeric_scores) / len(numeric_scores) if numeric_scores else None
+    )
+    return Result(
+        type="number",
+        value=average_value,
+    )
+
+
+def aggregate_binary(results: List[Result]) -> Result:
+    """Aggregates the results for the binary (auto regex) evaluation.
+
+    Args:
+        results (List[Result]): list of result objects
+
+    Returns:
+        Result: aggregated result
+    """
+
+    if all(isinstance(result.value, bool) for result in results):
+        average_value = sum(int(result.value) for result in results) / len(results)
+    else:
+        average_value = None
+    return Result(type="number", value=average_value)
+
+
+def aggregate_float(results: List[Result]) -> Result:
+    """Aggregates the results for evaluations aside from auto regex and ai critique.
+
+    Args:
+        results (List[Result]): list of result objects
+
+    Returns:
+        Result: aggregated result
+    """
+
+    try:
+        average_value = sum(result.value for result in results) / len(results)
+        return Result(type="number", value=average_value)
+    except Exception as exc:
+        return Result(
+            type="error",
+            value=None,
+            error=Error(message=str(exc), stacktrace=str(traceback.format_exc())),
+        )
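For reference, all three helpers collapse a list of per-scenario `Result` objects into a single averaged `Result`. The sketch below illustrates the intended behavior using a stand-in dataclass rather than the real `agenta_backend.models.db_models.Result` (only the field names are taken from the diff; everything else is illustrative):

```python
import re
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class Result:
    """Stand-in for agenta_backend.models.db_models.Result."""
    type: str
    value: Any
    error: Optional[Any] = None


def first_int(text: str) -> Optional[int]:
    """Same extraction rule as aggregate_ai_critique: first run of digits."""
    match = re.search(r"\d+", text)
    return int(match.group()) if match else None


# aggregate_ai_critique: free-text critiques are reduced to their first
# integer; critiques with no digits are skipped rather than counted as 0.
critiques: List[Result] = [
    Result("text", "Score: 8"),
    Result("text", "6/10"),
    Result("text", "no score given"),
]
scores = [s for s in (first_int(r.value) for r in critiques) if s is not None]
assert sum(scores) / len(scores) == 7.0

# aggregate_binary: booleans average as 0/1; any non-bool in the list makes
# the aggregate None instead of a misleading number.
flags = [Result("bool", True), Result("bool", True), Result("bool", False)]
assert sum(int(r.value) for r in flags) / len(flags) == 2 / 3
```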
diff --git a/agenta-backend/agenta_backend/tasks/evaluations.py b/agenta-backend/agenta_backend/tasks/evaluations.py
index 0c387ac236..cfe05d06e3 100644
--- a/agenta-backend/agenta_backend/tasks/evaluations.py
+++ b/agenta-backend/agenta_backend/tasks/evaluations.py
@@ -23,6 +23,7 @@
     evaluators_service,
     llm_apps_service,
     deployment_manager,
+    aggregation_service,
 )
 from agenta_backend.services.db_manager import (
     create_new_evaluation_scenario,
@@ -312,45 +313,25 @@ async def aggregate_evaluator_results(

         if not results:
             result = Result(type="error", value=None, error=Error(message="-"))
-        else:
-            if evaluator_key == "auto_ai_critique":
-                numeric_scores = []
-                for result in results:
-                    # Extract the first number found in the result value
-                    match = re.search(r"\d+", result.value)
-                    if match:
-                        try:
-                            score = int(match.group())
-                            numeric_scores.append(score)
-                        except ValueError:
-                            # Ignore if the extracted value is not an integer
-                            continue
-
-                # Calculate the average of numeric scores if any are present
-                average_value = (
-                    sum(numeric_scores) / len(numeric_scores)
-                    if numeric_scores
-                    else None
-                )
-                result = Result(
-                    type="number",
-                    value=average_value,
-                )
+            continue

-            else:
-                # Handle boolean values for auto_regex_test and other evaluators
-                if all(isinstance(result.value, bool) for result in results):
-                    average_value = sum(result.value for result in results) / len(
-                        results
-                    )
-                else:
-                    # Handle other data types or mixed results
-                    average_value = None
+        if evaluator_key == "auto_ai_critique":
+            result = aggregation_service.aggregate_ai_critique(results)

-                result = Result(
-                    type="number",
-                    value=average_value,
-                )
+        elif evaluator_key == "auto_regex_test":
+            result = aggregation_service.aggregate_binary(results)
+
+        elif evaluator_key in [
+            "auto_exact_match",
+            "auto_similarity_match",
+            "field_match_test",
+            "auto_webhook_test",
+            "auto_custom_code_run",
+        ]:
+            result = aggregation_service.aggregate_float(results)
+
+        else:
+            raise Exception(f"Evaluator {evaluator_key} aggregation does not exist")

         evaluator_config = await fetch_evaluator_config(config_id)
         aggregated_result = AggregatedResult(
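Two notes on the `aggregate_evaluator_results` change: the new `continue` makes an empty result list short-circuit straight to the error `Result` instead of falling through the old `else` nesting, and the final `raise` turns an unknown `evaluator_key` into a loud failure where the old code silently averaged to `None`. If the `elif` chain keeps growing, the same dispatch could be written as a lookup table; this is only a sketch of an alternative, not what the PR does, and it assumes the agenta backend package is importable:

```python
from typing import Callable, Dict, List

from agenta_backend.models.db_models import Result
from agenta_backend.services import aggregation_service

# Hypothetical registry mapping evaluator keys to their aggregation function;
# registering a new evaluator becomes a one-line change.
AGGREGATORS: Dict[str, Callable[[List[Result]], Result]] = {
    "auto_ai_critique": aggregation_service.aggregate_ai_critique,
    "auto_regex_test": aggregation_service.aggregate_binary,
    "auto_exact_match": aggregation_service.aggregate_float,
    "auto_similarity_match": aggregation_service.aggregate_float,
    "field_match_test": aggregation_service.aggregate_float,
    "auto_webhook_test": aggregation_service.aggregate_float,
    "auto_custom_code_run": aggregation_service.aggregate_float,
}


def aggregate(evaluator_key: str, results: List[Result]) -> Result:
    aggregator = AGGREGATORS.get(evaluator_key)
    if aggregator is None:
        # Mirrors the raise in the PR: unknown keys fail loudly.
        raise Exception(f"Evaluator {evaluator_key} aggregation does not exist")
    return aggregator(results)
```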