Skip to content

Commit

Permalink
small refactor for correct answers logic
Browse files Browse the repository at this point in the history
  • Loading branch information
aakrem committed May 15, 2024
1 parent 34f2436 commit 14c8ee2
Showing 1 changed file with 26 additions and 28 deletions.
54 changes: 26 additions & 28 deletions agenta-backend/agenta_backend/tasks/evaluations.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,6 @@ def evaluate(
)
for evaluator_config_db in evaluator_config_dbs
]
correct_answer = (
data_point[correct_answer_column]
if correct_answer_column in data_point
else ""
)

loop.run_until_complete(
create_new_evaluation_scenario(
Expand All @@ -195,7 +190,7 @@ def evaluate(
inputs=inputs,
is_pinned=False,
note="",
correct_answer=correct_answer,
correct_answer=None,
outputs=[
EvaluationScenarioOutputDB(
result=Result(
Expand All @@ -215,30 +210,13 @@ def evaluate(

# 3. We evaluate
evaluators_results: List[EvaluationScenarioResult] = []
correct_answers = []

# Loop over each evaluator configuration to gather the correct answers
# Loop over each evaluator configuration to gather the correct answers and evaluate
all_correct_answers = []
for evaluator_config_db in evaluator_config_dbs:
correct_answer_keys = evaluator_config_db.settings_values.get(
"correct_answer_keys"
)
if not correct_answer_keys:
return {
"type": "error",
"value": None,
"error": "No correct answer keys provided.",
}

## In case one evaluator has multiple correct answers
correct_answer_keys_list = [
key.strip() for key in correct_answer_keys.split(",")
]
correct_answers = parse_correct_answers(evaluator_config_db, data_point)

for key in correct_answer_keys_list:
correct_answer_value = data_point.get(key, "")

correct_answer = CorrectAnswer(key=key, value=correct_answer_value)
correct_answers.append(correct_answer)
all_correct_answers.extend(correct_answers)

logger.debug(f"Evaluating with evaluator: {evaluator_config_db}")

Expand Down Expand Up @@ -276,7 +254,7 @@ def evaluate(
inputs=inputs,
is_pinned=False,
note="",
correct_answers=correct_answers,
correct_answers=all_correct_answers,
outputs=[
EvaluationScenarioOutputDB(
result=Result(type="text", value=app_output.result.value),
Expand Down Expand Up @@ -439,3 +417,23 @@ def get_app_inputs(app_variant_parameters, openapi_parameters) -> List[Dict[str,
elif param["type"] == "file_url":
list_inputs.append({"name": param["name"], "type": "file_url"})
return list_inputs


def parse_correct_answers(evaluator_config_db, data_point) -> List[CorrectAnswer]:
"""
Extracts correct answers based on the evaluator configuration.
"""
correct_answers = []
correct_answer_keys = evaluator_config_db.settings_values.get("correct_answer_keys")

if not correct_answer_keys:
return []

# In case one evaluator has multiple correct answers
correct_answer_keys_list = [key.strip() for key in correct_answer_keys.split(",")]

for key in correct_answer_keys_list:
correct_answer_value = data_point.get(key, "")
correct_answers.append(CorrectAnswer(key=key, value=correct_answer_value))

return correct_answers

0 comments on commit 14c8ee2

Please sign in to comment.