From 762cc9dae5766648c231e2b781bb81137fa7765f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Sat, 15 Jun 2024 20:34:40 +0200 Subject: [PATCH] refacto: Use native OpenAI API JSON validation --- function_app.py | 4 +--- persistence/illm.py | 3 ++- persistence/openai.py | 9 ++++++++- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/function_app.py b/function_app.py index f92f453..1ae3e3d 100644 --- a/function_app.py +++ b/function_app.py @@ -405,9 +405,6 @@ async def page_to_fact(input: BlobClientTrigger) -> None: facts: list[FactModel] = [] for _ in range(CONFIG.features.fact_iterations): # We will generate facts 10 times def _validate(req: Optional[str]) -> tuple[bool, Optional[str], Optional[FactedLlmModel]]: - if not req: - return False, "Empty response", None - req = req.strip().strip("```json\n").strip("\n```").strip() try: return True, None, FactedLlmModel.model_validate_json(req) except ValidationError as e: @@ -415,6 +412,7 @@ def _validate(req: Optional[str]) -> tuple[bool, Optional[str], Optional[FactedL facted_llm_model = await llm_client.generate( res_object=FactedLlmModel, temperature=1, # We want creative answers + validate_json=True, validation_callback=_validate, prompt=f""" Assistant is an expert data analyst with 20 years of experience. diff --git a/persistence/illm.py b/persistence/illm.py index 7bd70a5..9a9a9cf 100644 --- a/persistence/illm.py +++ b/persistence/illm.py @@ -12,8 +12,9 @@ async def generate( prompt: str, res_object: type[T], validation_callback: Callable[[Optional[str]], tuple[bool, Optional[str], Optional[T]]], - temperature: float = 0, max_tokens: Optional[int] = None, + temperature: float = 0, + validate_json: bool = False, _previous_result: Optional[str] = None, _retries_remaining: Optional[int] = None, _validation_error: Optional[str] = None, diff --git a/persistence/openai.py b/persistence/openai.py index 5b9c94d..ead1485 100644 --- a/persistence/openai.py +++ b/persistence/openai.py @@ -24,8 +24,9 @@ async def generate( prompt: str, res_object: type[T], validation_callback: Callable[[Optional[str]], tuple[bool, Optional[str], Optional[T]]], - temperature: float = 0, max_tokens: Optional[int] = None, + temperature: float = 0, + validate_json: bool = False, _previous_result: Optional[str] = None, _retries_remaining: Optional[int] = None, _validation_error: Optional[str] = None, @@ -59,6 +60,10 @@ async def generate( ) ) + extra = {} + if validate_json: + extra["response_format"] = {"type": "json_object"} + # Generate client = self._use_client() res = await client.chat.completions.create( @@ -66,6 +71,7 @@ async def generate( messages=messages, model=self._config.model, temperature=temperature, + **extra, ) res_content = res.choices[0].message.content # type: ignore @@ -81,6 +87,7 @@ async def generate( prompt=prompt, res_object=res_object, temperature=temperature, + validate_json=validate_json, validation_callback=validation_callback, _previous_result=res_content, _retries_remaining=_retries_remaining - 1,