Commit

formatting
mmabrouk committed May 28, 2024
1 parent 1a6e4ce commit 6b44f5c
Showing 4 changed files with 34 additions and 19 deletions.
19 changes: 13 additions & 6 deletions agenta-backend/agenta_backend/resources/evaluators/evaluators.py
@@ -19,8 +19,7 @@
"name": "Contains Json",
"key": "auto_contains_json",
"direct_use": True,
"settings_template": {
},
"settings_template": {},
"description": "Contains Json evaluator checks if the output contains the specified JSON structure.",
},
{
@@ -45,7 +44,6 @@
"ground_truth_key": True, # Tells the frontend that is the name of the column in the test set that should be shown as a ground truth to the user
"description": "The name of the column in the test data that contains the correct answer",
},

},
"description": "Similarity Match evaluator checks if the generated answer is similar to the expected answer. You need to provide the similarity threshold. It uses the Jaccard similarity to compare the answers.",
},
@@ -113,7 +111,6 @@
"ground_truth_key": True, # Tells the frontend that is the name of the column in the test set that should be shown as a ground truth to the user
"description": "The name of the column in the test data that contains the correct answer",
},

},
"description": "AI Critique evaluator sends the generated answer and the correct_answer to an LLM model and uses it to evaluate the correctness of the answer. You need to provide the evaluation prompt (or use the default prompt).",
},
@@ -159,7 +156,12 @@
"key": "auto_starts_with",
"direct_use": False,
"settings_template": {
"prefix": {"label": "prefix", "type": "string", "required": True, "description": "The string to match at the start of the output."},
"prefix": {
"label": "prefix",
"type": "string",
"required": True,
"description": "The string to match at the start of the output.",
},
"case_sensitive": {
"label": "Case Sensitive",
"type": "boolean",
@@ -254,7 +256,12 @@
"key": "auto_levenshtein_distance",
"direct_use": False,
"settings_template": {
"threshold": {"label": "Threshold", "type": "number", "required": False, "description": "The maximum allowed Levenshtein distance between the output and the correct answer."},
"threshold": {
"label": "Threshold",
"type": "number",
"required": False,
"description": "The maximum allowed Levenshtein distance between the output and the correct answer.",
},
"correct_answer_key": {
"label": "Correct Answer",
"default": "correct_answer",
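The Similarity Match evaluator defined in this file compares the generated answer to the expected answer with Jaccard similarity against a user-configured threshold. As a rough illustration only (the tokenization and default threshold below are assumptions, not the repository's actual implementation), such a check could look like:

def jaccard_similarity(output: str, correct_answer: str) -> float:
    # Tokenize on whitespace; a real implementation may also normalize punctuation.
    a, b = set(output.lower().split()), set(correct_answer.lower().split())
    if not a and not b:
        return 1.0
    return len(a & b) / len(a | b)


def similarity_match(output: str, correct_answer: str, threshold: float = 0.5) -> bool:
    # The evaluator passes when the similarity reaches the configured threshold.
    return jaccard_similarity(output, correct_answer) >= threshold
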
4 changes: 3 additions & 1 deletion agenta-backend/agenta_backend/services/evaluator_manager.py
@@ -153,7 +153,9 @@ async def create_ready_to_use_evaluators(app: AppDB):
}

for setting_name, default_value in settings_values.items():
assert default_value != "", f"Default value for ground truth key '{setting_name}' in Evaluator is empty"
assert (
default_value != ""
), f"Default value for ground truth key '{setting_name}' in Evaluator is empty"

await db_manager.create_evaluator_config(
app=app,
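The reformatted assert above enforces that every ground-truth setting ships with a non-empty default before ready-to-use evaluator configs are created. A toy illustration of that guard, using a hypothetical settings dictionary:

settings_values = {"correct_answer_key": "correct_answer"}  # hypothetical defaults

for setting_name, default_value in settings_values.items():
    assert (
        default_value != ""
    ), f"Default value for ground truth key '{setting_name}' in Evaluator is empty"
# An empty default, e.g. {"correct_answer_key": ""}, would raise this AssertionError.
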
18 changes: 9 additions & 9 deletions agenta-backend/agenta_backend/services/evaluators_service.py
@@ -13,7 +13,9 @@
logger.setLevel(logging.DEBUG)


def get_correct_answer(data_point: Dict[str, Any], settings_values: Dict[str, Any]) -> Any:
def get_correct_answer(
data_point: Dict[str, Any], settings_values: Dict[str, Any]
) -> Any:
"""
Helper function to retrieve the correct answer from the data point based on the settings values.
@@ -31,7 +33,9 @@ def get_correct_answer(data_point: Dict[str, Any], settings_values: Dict[str, Any]) -> Any:
if correct_answer_key is None:
raise ValueError("No correct answer keys provided.")
if correct_answer_key not in data_point:
raise ValueError(f"Correct answer column '{correct_answer_key}' not found in the test set.")
raise ValueError(
f"Correct answer column '{correct_answer_key}' not found in the test set."
)
return data_point[correct_answer_key]


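The get_correct_answer helper above looks up the configured ground-truth column in a test-set row and raises a ValueError if the column is missing. A usage sketch, assuming settings_values carries the "correct_answer_key" setting shown in evaluators.py (the sample row below is made up):

data_point = {"question": "What is 2 + 2?", "correct_answer": "4"}  # hypothetical test-set row
settings_values = {"correct_answer_key": "correct_answer"}

print(get_correct_answer(data_point, settings_values))  # -> "4"

# A key that is not present in the row raises:
#   ValueError: Correct answer column 'missing_column' not found in the test set.
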
@@ -115,9 +119,7 @@ def field_match_test(
try:
correct_answer = get_correct_answer(data_point, settings_values)
output_json = json.loads(output)
result = (
output_json[settings_values["json_field"]] == correct_answer
)
result = output_json[settings_values["json_field"]] == correct_answer
return Result(type="bool", value=result)
except ValueError as e:
return Result(
@@ -262,14 +264,12 @@ def auto_ai_critique(
prompt_template = settings_values["prompt_template"]
messages = [
{"role": "system", "content": prompt_template},
{"role": "user", "content": str(chain_run_args)}
{"role": "user", "content": str(chain_run_args)},
]

client = OpenAI(api_key=openai_api_key)
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=messages,
temperature=0.8
model="gpt-3.5-turbo", messages=messages, temperature=0.8
)

evaluation_output = response.choices[0].message["content"].strip()
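field_match_test, touched above, parses the model output as JSON and compares a single configured field to the ground truth. A standalone sketch of that comparison, leaving out the Result wrapper and error handling used in this service:

import json


def field_matches(output: str, json_field: str, correct_answer: str) -> bool:
    # Parse the raw LLM output and compare the configured field to the expected value.
    output_json = json.loads(output)
    return output_json[json_field] == correct_answer


# e.g. field_matches('{"answer": "Paris"}', "answer", "Paris") -> True
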
12 changes: 9 additions & 3 deletions agenta-backend/agenta_backend/tasks/evaluations.py
@@ -55,7 +55,8 @@
all_evaluators = get_evaluators()
ground_truth_keys_dict = {
evaluator["key"]: [
key for key, value in evaluator.get("settings_template", {}).items()
key
for key, value in evaluator.get("settings_template", {}).items()
if value.get("ground_truth_key") is True
]
for evaluator in all_evaluators
@@ -215,7 +216,9 @@ def evaluate(
# Loop over each evaluator configuration to gather the correct answers and evaluate
ground_truth_column_names = []
for evaluator_config_db in evaluator_config_dbs:
ground_truth_keys = ground_truth_keys_dict.get(evaluator_config_db.evaluator_key, [])
ground_truth_keys = ground_truth_keys_dict.get(
evaluator_config_db.evaluator_key, []
)
ground_truth_column_names.extend(
evaluator_config_db.settings_values.get(key, "")
for key in ground_truth_keys
@@ -246,7 +249,10 @@ def evaluate(
evaluators_results.append(result_object)

all_correct_answers = [
CorrectAnswer(key=ground_truth_column_name, value=data_point[ground_truth_column_name])
CorrectAnswer(
key=ground_truth_column_name,
value=data_point[ground_truth_column_name],
)
if ground_truth_column_name in data_point
else CorrectAnswer(key=ground_truth_column_name, value="")
for ground_truth_column_name in ground_truth_column_names
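The dictionary comprehension changed at the top of this file maps each evaluator key to the settings whose templates set ground_truth_key, which is how the task later knows which test-set columns hold correct answers. A small illustration with made-up evaluator metadata standing in for get_evaluators():

all_evaluators = [
    {
        "key": "auto_similarity_match",
        "settings_template": {
            "similarity_threshold": {"type": "number"},
            "correct_answer_key": {"ground_truth_key": True, "default": "correct_answer"},
        },
    },
    {"key": "auto_contains_json", "settings_template": {}},
]

ground_truth_keys_dict = {
    evaluator["key"]: [
        key
        for key, value in evaluator.get("settings_template", {}).items()
        if value.get("ground_truth_key") is True
    ]
    for evaluator in all_evaluators
}
# -> {"auto_similarity_match": ["correct_answer_key"], "auto_contains_json": []}
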
