Merge pull request #1437 from Agenta-AI/levenshtein-evaluator
levenshtein distance
aakrem authored Mar 25, 2024
2 parents c313cb9 + 1b57abb commit c2ca487
Showing 6 changed files with 102 additions and 25 deletions.
28 changes: 24 additions & 4 deletions agenta-backend/agenta_backend/resources/evaluators/evaluators.py
@@ -28,6 +28,9 @@
"type": "number",
"default": 0.5,
"description": "The threshold value for similarity comparison",
"min": 0,
"max": 1,
"required": True,
}
},
"description": "Similarity Match evaluator checks if the generated answer is similar to the expected answer. You need to provide the similarity threshold. It uses the Jaccard similarity to compare the answers.",
@@ -43,6 +46,7 @@
"type": "regex",
"default": "",
"description": "Pattern for regex testing (ex: ^this_word\\d{3}$)",
"required": True,
},
"regex_should_match": {
"label": "Match/Mismatch",
@@ -62,6 +66,7 @@
"type": "string",
"default": "",
"description": "The name of the field in the JSON output that you wish to evaluate",
"required": True,
}
},
"description": "JSON Field Match evaluator compares specific fields within JSON (JavaScript Object Notation) data. This matching can involve finding similarities or correspondences between fields in different JSON objects.",
@@ -76,6 +81,7 @@
"type": "text",
"default": "We have an LLM App that we want to evaluate its outputs. Based on the prompt and the parameters provided below evaluate the output based on the evaluation strategy below: Evaluation strategy: 0 to 10 0 is very bad and 10 is very good. Prompt: {llm_app_prompt_template} Inputs: country: {country} Correct Answer:{correct_answer} Evaluate this: {variant_output} Answer ONLY with one of the given grading or evaluation options.",
"description": "Template for AI critique prompts",
"required": True,
}
},
"description": "AI Critique evaluator sends the generated answer and the correct_answer to an LLM model and uses it to evaluate the correctness of the answer. You need to provide the evaluation prompt (or use the default prompt).",
@@ -90,6 +96,7 @@
"type": "code",
"default": "from typing import Dict\n\ndef evaluate(\n app_params: Dict[str, str],\n inputs: Dict[str, str],\n output: str,\n correct_answer: str\n) -> float:\n # ...\n return 0.75 # Replace with your calculated score",
"description": "Code for evaluating submissions",
"required": True,
}
},
"description": "Code Evaluation allows you to write your own evaluator in Python. You need to provide the Python code for the evaluator.",
@@ -103,6 +110,7 @@
"label": "Webhook URL",
"type": "string",
"description": "https://your-webhook-url.com",
"required": True,
},
},
"description": "Webhook test evaluator sends the generated answer and the correct_answer to a webhook and expects a response indicating the correctness of the answer. You need to provide the URL of the webhook and the response of the webhook must be between 0 and 1.",
@@ -132,10 +140,7 @@
"settings_template": {
"label": "Single Model Testing Settings",
"description": "Checks if the output starts with the specified prefix.",
"prefix": {
"label": "prefix",
"type": "string",
},
"prefix": {"label": "prefix", "type": "string", "required": True},
"case_sensitive": {
"label": "Case Sensitive",
"type": "boolean",
@@ -161,6 +166,7 @@
"label": "suffix",
"type": "string",
"description": "The string to match at the end of the output.",
"required": True,
},
},
"description": "Ends With evaluator checks if the output ends with a specified suffix, considering case sensitivity based on the settings.",
@@ -182,6 +188,7 @@
"label": "substring",
"type": "string",
"description": "The string to check if it is contained in the output.",
"required": True,
},
},
"description": "Contains evaluator checks if the output contains a specified substring, considering case sensitivity based on the settings.",
@@ -203,6 +210,7 @@
"label": "substrings",
"type": "string",
"description": "Provide a comma-separated list of strings to check if any is contained in the output.",
"required": True,
},
},
"description": "Contains Any evaluator checks if the output contains any of the specified substrings from a comma-separated list, considering case sensitivity based on the settings.",
@@ -224,10 +232,22 @@
"label": "substrings",
"type": "string",
"description": "Provide a comma-separated list of strings to check if all are contained in the output.",
"required": True,
},
},
"description": "Contains All evaluator checks if the output contains all of the specified substrings from a comma-separated list, considering case sensitivity based on the settings.",
},
{
"name": "Levenshtein Distance",
"key": "auto_levenshtein_distance",
"direct_use": False,
"settings_template": {
"label": "Levenshtein Distance Settings",
"description": "Evaluates the Levenshtein distance between the output and the correct answer. If a threshold is specified, it checks if the distance is below this threshold and returns a boolean value. If no threshold is specified, it returns the numerical Levenshtein distance.",
"threshold": {"label": "Threshold", "type": "number", "required": False},
},
"description": "This evaluator calculates the Levenshtein distance between the output and the correct answer. If a threshold is provided in the settings, it returns a boolean indicating whether the distance is within the threshold. If no threshold is provided, it returns the actual Levenshtein distance as a numerical value.",
},
]


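For readers unfamiliar with the Jaccard similarity referenced in the Similarity Match description above, the comparison boils down to the ratio of shared tokens to total distinct tokens across the two answers. The sketch below is a minimal illustration only, not the project's actual implementation; the whitespace tokenization and the `jaccard_similarity` name are assumptions made here for clarity.

```python
def jaccard_similarity(output: str, correct_answer: str) -> float:
    """Minimal Jaccard similarity over lowercased, whitespace-separated tokens."""
    a, b = set(output.lower().split()), set(correct_answer.lower().split())
    if not a and not b:
        return 1.0  # treat two empty answers as identical
    return len(a & b) / len(a | b)


# A score at or above the configured threshold (default 0.5, now bounded by the
# new "min"/"max" settings of 0 and 1) would count as a similarity match.
print(jaccard_similarity("the capital of France is Paris", "Paris is the capital of France"))  # 1.0
```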
49 changes: 49 additions & 0 deletions agenta-backend/agenta_backend/services/evaluators_service.py
@@ -421,6 +421,55 @@ def auto_contains_json(
)


def levenshtein_distance(s1, s2):
if len(s1) < len(s2):
return levenshtein_distance(s2, s1)

if len(s2) == 0:
return len(s1)

previous_row = range(len(s2) + 1)
for i, c1 in enumerate(s1):
current_row = [i + 1]
for j, c2 in enumerate(s2):
insertions = previous_row[j + 1] + 1
deletions = current_row[j] + 1
substitutions = previous_row[j] + (c1 != c2)
current_row.append(min(insertions, deletions, substitutions))
previous_row = current_row

return previous_row[-1]


def auto_levenshtein_distance(
inputs: Dict[str, Any],
output: str,
correct_answer: str,
app_params: Dict[str, Any],
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any],
) -> Result:
try:
distance = levenshtein_distance(output, correct_answer)

if "threshold" in settings_values:
threshold = settings_values["threshold"]
is_within_threshold = distance <= threshold
return Result(type="bool", value=is_within_threshold)

return Result(type="number", value=distance)

except Exception as e:
return Result(
type="error",
value=None,
error=Error(
message="Error during Levenshtein threshold evaluation",
stacktrace=str(e),
),
)


def evaluate(
evaluator_key: str,
inputs: Dict[str, Any],
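To make the branching in `auto_levenshtein_distance` concrete: with a `threshold` in `settings_values` the evaluator returns a boolean result, otherwise it returns the raw distance as a number. A small usage sketch, in the same spirit as the unit tests added below (the literal strings and expected values are illustrative):

```python
from agenta_backend.services.evaluators_service import auto_levenshtein_distance

# With a threshold: Result(type="bool", ...). "hello world" vs "hola mundo" has distance 8.
result = auto_levenshtein_distance(
    inputs={}, output="hello world", correct_answer="hola mundo",
    app_params={}, settings_values={"threshold": 5}, lm_providers_keys={},
)
print(result.type, result.value)  # bool False

# Without a threshold: Result(type="number", ...) carrying the raw distance (1 here).
result = auto_levenshtein_distance(
    inputs={}, output="hello world", correct_answer="hello worl",
    app_params={}, settings_values={}, lm_providers_keys={},
)
print(result.type, result.value)  # number 1
```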
1 change: 1 addition & 0 deletions agenta-backend/agenta_backend/tasks/evaluations.py
@@ -355,6 +355,7 @@ async def aggregate_evaluator_results(
"auto_contains_any",
"auto_contains_all",
"auto_contains_json",
"auto_levenshtein_distance",
]:
result = aggregation_service.aggregate_float(results)

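The one-line change above adds `auto_levenshtein_distance` to the set of evaluator keys whose per-test-case results are combined with `aggregation_service.aggregate_float`. That service is not shown in this diff; a plausible reading is a simple mean in which boolean (threshold-mode) results coerce to 0/1, and the sketch below illustrates that assumption only; it is not the repository's code.

```python
from typing import List, Union


def aggregate_float_sketch(values: List[Union[bool, int, float]]) -> float:
    """Assumed behaviour: average the per-test-case values; booleans count as 0 or 1."""
    if not values:
        return 0.0
    return sum(float(v) for v in values) / len(values)


# Three threshold-mode results of True, True, False would aggregate to ~0.67.
print(round(aggregate_float_sketch([True, True, False]), 2))  # 0.67
```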
20 changes: 20 additions & 0 deletions agenta-backend/agenta_backend/tests/unit/test_evaluators.py
@@ -1,6 +1,7 @@
import pytest

from agenta_backend.services.evaluators_service import (
auto_levenshtein_distance,
auto_starts_with,
auto_ends_with,
auto_contains,
@@ -129,3 +130,22 @@ def test_auto_contains_all(output, substrings, case_sensitive, expected):
def test_auto_contains_json(output, expected):
result = auto_contains_json({}, output, "", {}, {}, {})
assert result.value == expected


@pytest.mark.parametrize(
"output, correct_answer, threshold, expected",
[
("hello world", "hello world", 5, True),
("hello world", "hola mundo", 5, False),
("hello world", "hello world!", 2, True),
("hello world", "hello wor", 10, True),
("hello world", "hello worl", None, 1),
("hello world", "helo world", None, 1),
],
)
def test_auto_levenshtein_distance(output, correct_answer, threshold, expected):
settings_values = {"threshold": threshold} if threshold is not None else {}
result = auto_levenshtein_distance(
{}, output, correct_answer, {}, settings_values, {}
)
assert result.value == expected
@@ -117,12 +117,15 @@ const DynamicFormField: React.FC<DynamicFormFieldProps> = ({
type,
default: defaultVal,
description,
min,
max,
required,
}) => {
const {appTheme} = useAppTheme()
const classes = useStyles()
const {token} = theme.useToken()

const rules: Rule[] = [{required: true, message: "This field is required"}]
const rules: Rule[] = [{required: required ?? true, message: "This field is required"}]
if (type === "regex")
rules.push({
validator: (_, value) =>
@@ -167,7 +170,7 @@ const DynamicFormField: React.FC<DynamicFormFieldProps> = ({
{type === "string" || type === "regex" ? (
<Input />
) : type === "number" ? (
<InputNumber min={0} max={1} step={0.1} />
<InputNumber min={min} max={max} step={0.1} />
) : type === "boolean" || type === "bool" ? (
<Switch />
) : type === "text" ? (
@@ -295,25 +298,6 @@ const NewEvaluatorModal: React.FC<Props> = ({
)
},
},
{
title: "Type",
dataIndex: "type",
key: "type",
render(_, record) {
const template = Object.keys(record?.settings_template || {})
.filter((key) => !!record?.settings_template[key]?.type)
.map((key) => ({
key,
...record?.settings_template[key]!,
}))

return (
<>
<Tag color={record.color}>{template[0].type}</Tag>
</>
)
},
},
{
title: "Description",
dataIndex: "description",
3 changes: 3 additions & 0 deletions agenta-web/src/lib/Types.ts
@@ -382,6 +382,9 @@ export interface EvaluationSettingsTemplate {
label: string
default?: ValueType
description: string
min?: number
max?: number
required?: boolean
}

export interface Evaluator {
