-
Notifications
You must be signed in to change notification settings - Fork 227
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1347 from Agenta-AI/fix/custom-code-evaluation-ag…
…gregation Bug Fix: Custom code evaluation aggregation
- Loading branch information
Showing
2 changed files
with
93 additions
and
37 deletions.
There are no files selected for viewing
75 changes: 75 additions & 0 deletions
75
agenta-backend/agenta_backend/services/aggregation_service.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import re | ||
import traceback | ||
from typing import List | ||
|
||
from agenta_backend.models.db_models import Result, Error | ||
|
||
|
||
def aggregate_ai_critique(results: List[Result]) -> Result: | ||
"""Aggregates the results for the ai critique evaluation. | ||
Args: | ||
results (List[Result]): list of result objects | ||
Returns: | ||
Result: aggregated result | ||
""" | ||
|
||
numeric_scores = [] | ||
for result in results: | ||
# Extract the first number found in the result value | ||
match = re.search(r"\d+", result.value) | ||
if match: | ||
try: | ||
score = int(match.group()) | ||
numeric_scores.append(score) | ||
except ValueError: | ||
# Ignore if the extracted value is not an integer | ||
continue | ||
|
||
# Calculate the average of numeric scores if any are present | ||
average_value = ( | ||
sum(numeric_scores) / len(numeric_scores) if numeric_scores else None | ||
) | ||
return Result( | ||
type="number", | ||
value=average_value, | ||
) | ||
|
||
|
||
def aggregate_binary(results: List[Result]) -> Result: | ||
"""Aggregates the results for the binary (auto regex) evaluation. | ||
Args: | ||
results (List[Result]): list of result objects | ||
Returns: | ||
Result: aggregated result | ||
""" | ||
|
||
if all(isinstance(result.value, bool) for result in results): | ||
average_value = sum(int(result.value) for result in results) / len(results) | ||
else: | ||
average_value = None | ||
return Result(type="number", value=average_value) | ||
|
||
|
||
def aggregate_float(results: List[Result]) -> Result: | ||
"""Aggregates the results for evaluations aside from auto regex and ai critique. | ||
Args: | ||
results (List[Result]): list of result objects | ||
Returns: | ||
Result: aggregated result | ||
""" | ||
|
||
try: | ||
average_value = sum(result.value for result in results) / len(results) | ||
return Result(type="number", value=average_value) | ||
except Exception as exc: | ||
return Result( | ||
type="error", | ||
value=None, | ||
error=Error(message=str(exc), stacktrace=str(traceback.format_exc())), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters