Merge pull request #1155 from confident-ai/features/json-correctness
New metric
Showing 6 changed files with 235 additions and 1 deletion.
Empty file.
@@ -0,0 +1,196 @@
from typing import List, Optional, Union
import json
from pydantic import BaseModel, ValidationError

from deepeval.test_case import (
    LLMTestCase,
    LLMTestCaseParams,
    ConversationalTestCase,
)
from deepeval.metrics import BaseMetric
from deepeval.metrics.utils import (
    construct_verbose_logs,
    check_llm_test_case_params,
    initialize_model,
    trimAndLoadJson,
)
from deepeval.models import DeepEvalBaseLLM
from deepeval.metrics.indicator import metric_progress_indicator
from deepeval.metrics.json_correctness.template import JsonCorrectnessTemplate
from deepeval.metrics.json_correctness.schema import Reason
from deepeval.utils import get_or_create_event_loop

DEFAULT_CORRECT_REASON = "The generated JSON matches the expected schema and is syntactically correct."

required_params: List[LLMTestCaseParams] = [
    LLMTestCaseParams.INPUT,
    LLMTestCaseParams.ACTUAL_OUTPUT,
]

class JsonCorrectnessMetric(BaseMetric):
    def __init__(
        self,
        expected_schema: BaseModel,
        model: Optional[Union[str, DeepEvalBaseLLM]] = None,
        threshold: float = 0.5,
        async_mode: bool = True,
        include_reason: bool = True,
        strict_mode: bool = True,
        verbose_mode: bool = False,
    ):
        self.threshold = 1 if strict_mode else threshold
        self.model, self.using_native_model = initialize_model(model)
        self.include_reason = include_reason
        self.strict_mode = strict_mode
        self.async_mode = async_mode
        self.verbose_mode = verbose_mode
        self.expected_schema = expected_schema

    def measure(
        self,
        test_case: Union[LLMTestCase, ConversationalTestCase],
        _show_indicator: bool = True,
    ) -> float:
        if isinstance(test_case, ConversationalTestCase):
            test_case = test_case.turns[0]
        check_llm_test_case_params(test_case, required_params, self)

        self.evaluation_cost = 0 if self.using_native_model else None
        with metric_progress_indicator(self, _show_indicator=_show_indicator):
            if self.async_mode:
                loop = get_or_create_event_loop()
                loop.run_until_complete(
                    self.a_measure(test_case, _show_indicator=False)
                )
            else:
                # Validate the raw output string against the expected pydantic schema
                valid_json = True
                try:
                    self.expected_schema.model_validate_json(
                        test_case.actual_output
                    )
                except ValidationError as e:
                    valid_json = False

                self.score = 1 if valid_json else 0
                self.reason = self.generate_reason(test_case.actual_output)
                self.success = self.score >= self.threshold
                self.verbose_logs = construct_verbose_logs(
                    self,
                    steps=[
                        f"LLM outputted JSON:\n{test_case.actual_output}",
                        # f"Expected Json Schema:\n{json.dumps(self.expected_schema.model_json_schema(), indent=4)}",
                        f"Score: {self.score}\nReason: {self.reason}",
                    ],
                )

            return self.score

    async def a_measure(
        self,
        test_case: Union[LLMTestCase, ConversationalTestCase],
        _show_indicator: bool = True,
    ) -> float:
        if isinstance(test_case, ConversationalTestCase):
            test_case = test_case.turns[0]
        check_llm_test_case_params(test_case, required_params, self)

        self.evaluation_cost = 0 if self.using_native_model else None
        with metric_progress_indicator(
            self, async_mode=True, _show_indicator=_show_indicator
        ):
            # Validate the raw output string against the expected pydantic schema
            valid_json = True
            try:
                self.expected_schema.model_validate_json(
                    test_case.actual_output
                )
            except ValidationError as e:
                valid_json = False

            self.score = 1 if valid_json else 0
            self.reason = await self.a_generate_reason(test_case.actual_output)
            self.success = self.score >= self.threshold
            self.verbose_logs = construct_verbose_logs(
                self,
                steps=[
                    f"LLM outputted JSON:\n{test_case.actual_output}",
                    # f"Expected Json Schema:\n{json.dumps(self.expected_schema.model_json_schema(), indent=4)}",
                    f"Score: {self.score}\nReason: {self.reason}",
                ],
            )

            return self.score

    async def a_generate_reason(self, actual_output: str) -> str:
        if self.include_reason is False:
            return None

        is_valid_json = self.score == 1
        if is_valid_json:
            return DEFAULT_CORRECT_REASON

        prompt: dict = JsonCorrectnessTemplate.generate_reason(
            actual_output=actual_output,
            expected_schema=json.dumps(
                self.expected_schema.model_json_schema(), indent=4
            ),
            is_valid_json=is_valid_json,
        )

        if self.using_native_model:
            res, cost = await self.model.a_generate(prompt)
            self.evaluation_cost += cost
            data = trimAndLoadJson(res, self)
            return data["reason"]
        else:
            try:
                res: Reason = await self.model.a_generate(prompt, schema=Reason)
                return res.reason
            except TypeError:
                # Custom model does not accept a `schema` kwarg: fall back to
                # parsing the raw completion
                res = await self.model.a_generate(prompt)
                data = trimAndLoadJson(res, self)
                return data["reason"]

    def generate_reason(self, actual_output: str) -> str:
        if self.include_reason is False:
            return None

        is_valid_json = self.score == 1
        if is_valid_json:
            return DEFAULT_CORRECT_REASON

        prompt: dict = JsonCorrectnessTemplate.generate_reason(
            actual_output=actual_output,
            expected_schema=json.dumps(
                self.expected_schema.model_json_schema(), indent=4
            ),
            is_valid_json=is_valid_json,
        )

        if self.using_native_model:
            res, cost = self.model.generate(prompt)
            self.evaluation_cost += cost
            data = trimAndLoadJson(res, self)
            return data["reason"]
        else:
            try:
                res: Reason = self.model.generate(prompt, schema=Reason)
                return res.reason
            except TypeError:
                # Custom model does not accept a `schema` kwarg: fall back to
                # parsing the raw completion
                res = self.model.generate(prompt)
                data = trimAndLoadJson(res, self)
                return data["reason"]

    def is_successful(self) -> bool:
        if self.error is not None:
            self.success = False
        else:
            try:
                self.success = self.score >= self.threshold
            except:
                self.success = False
        return self.success

    @property
    def __name__(self):
        return "Json Correctness"
@@ -0,0 +1,5 @@
from pydantic import BaseModel


class Reason(BaseModel):
    reason: str
@@ -0,0 +1,31 @@
from typing import Optional


class JsonCorrectnessTemplate:
    @staticmethod
    def generate_reason(
        actual_output: str, expected_schema: str, is_valid_json: bool
    ):
        return f"""Based on the given generated json, generated by an LLM, and a boolean stating whether it is a valid JSON based on the expected json schema, give a reason why it is OR is not a valid Json.

**
IMPORTANT: Please make sure to only return in JSON format, with the 'reason' key providing the reason.
Example JSON:
{{
    "reason": "The generated Json is <is_valid_json> because <your_reason>."
}}

If the json is not a valid one, your reason MUST compare `Expected Json Schema` and `Generated Json` in your reason. Keep it SHORT and CONCISE while being very FACTUAL and ACTIONABLE.
**

Generated Json:
{actual_output}

Expected Json Schema:
{expected_schema}

Is Valid Json?
{is_valid_json}

JSON:
"""