From a7f8be59ab90494683c31396b8f35b422b792f46 Mon Sep 17 00:00:00 2001 From: Akrem Abayed Date: Fri, 19 Jan 2024 17:30:23 +0100 Subject: [PATCH] fix ai critique --- .../services/evaluators_service.py | 6 +-- .../agenta_backend/tasks/evaluations.py | 45 +++++++++++-------- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/agenta-backend/agenta_backend/services/evaluators_service.py b/agenta-backend/agenta_backend/services/evaluators_service.py index 15c1d7832f..372f2172f7 100644 --- a/agenta-backend/agenta_backend/services/evaluators_service.py +++ b/agenta-backend/agenta_backend/services/evaluators_service.py @@ -150,10 +150,8 @@ def auto_ai_critique( "correct_answer": correct_answer, } - for input_item in app_params.get("inputs", []): - input_name = input_item.get("name") - if input_name and input_name in inputs: - chain_run_args[input_name] = inputs[input_name] + for key, value in inputs.items(): + chain_run_args[key] = value prompt = PromptTemplate( input_variables=list(chain_run_args.keys()), # Use the keys from chain_run_args diff --git a/agenta-backend/agenta_backend/tasks/evaluations.py b/agenta-backend/agenta_backend/tasks/evaluations.py index 6bb45e0924..dea2ee5a42 100644 --- a/agenta-backend/agenta_backend/tasks/evaluations.py +++ b/agenta-backend/agenta_backend/tasks/evaluations.py @@ -1,6 +1,7 @@ import asyncio import logging import os +import re import traceback from collections import defaultdict from typing import Any, Dict, List @@ -208,28 +209,34 @@ async def aggregate_evaluator_results( for config_id, val in evaluators_aggregated_data.items(): evaluator_key = val["evaluator_key"] or "" results = val["results"] or [] - if evaluator_key != "auto_ai_critique": + + if not results: + average_value = 0 + if evaluator_key == "auto_ai_critique": + numeric_scores = [] + for result in results: + # Extract the first number found in the result value + match = re.search(r"\d+", result.value) + if match: + try: + score = int(match.group()) + numeric_scores.append(score) + except ValueError: + # Ignore if the extracted value is not an integer + continue + + # Calculate the average of numeric scores if any are present average_value = ( - sum([result.value for result in results]) / len(results) - if results - else 0 + sum(numeric_scores) / len(numeric_scores) if numeric_scores else None ) - elif evaluator_key == "auto_ai_critique": - try: - average_value = ( - sum( - [ - int(result.value) - for result in results - if isinstance(int(result.value), int) - ] - ) - / len(results) - if results - else 0 - ) - except TypeError: + else: + # Handle boolean values for auto_regex_test and other evaluators + if all(isinstance(result.value, bool) for result in results): + average_value = sum(result.value for result in results) / len(results) + else: + # Handle other data types or mixed results average_value = None + evaluator_config = await fetch_evaluator_config(config_id) aggregated_result = AggregatedResult( evaluator_config=evaluator_config.id,