Commit 762cc9d
refacto: Use native OpenAI API JSON validation
clemlesne committed Jun 15, 2024
1 parent 5da1201 commit 762cc9d
Showing 3 changed files with 11 additions and 5 deletions.
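In short: instead of asking callers to strip markdown code fences (```json ... ```) from the model's reply before parsing it, the OpenAI backend now requests JSON output natively through the API's response_format option, exposed to callers as a new validate_json flag on generate().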
function_app.py (4 changes: 1 addition & 3 deletions)
@@ -405,16 +405,14 @@ async def page_to_fact(input: BlobClientTrigger) -> None:
     facts: list[FactModel] = []
     for _ in range(CONFIG.features.fact_iterations):  # We will generate facts 10 times
         def _validate(req: Optional[str]) -> tuple[bool, Optional[str], Optional[FactedLlmModel]]:
-            if not req:
-                return False, "Empty response", None
-            req = req.strip().strip("```json\n").strip("\n```").strip()
             try:
                 return True, None, FactedLlmModel.model_validate_json(req)
             except ValidationError as e:
                 return False, str(e), None
         facted_llm_model = await llm_client.generate(
             res_object=FactedLlmModel,
             temperature=1,  # We want creative answers
+            validate_json=True,
             validation_callback=_validate,
             prompt=f"""
         Assistant is an expert data analyst with 20 years of experience.
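Note that JSON mode guarantees only that the response parses as JSON, not that it matches the expected schema, which is why the Pydantic validation callback stays. A minimal self-contained sketch of the pattern (the Fact model is a hypothetical stand-in for the repo's FactedLlmModel, whose fields this diff does not show):

from typing import Optional

from pydantic import BaseModel, ValidationError


class Fact(BaseModel):  # hypothetical stand-in for FactedLlmModel
    question: str
    answer: str


def validate(req: Optional[str]) -> tuple[bool, Optional[str], Optional[Fact]]:
    # No fence stripping needed: JSON mode returns bare JSON text
    if not req:
        return False, "Empty response", None
    try:
        return True, None, Fact.model_validate_json(req)
    except ValidationError as e:
        return False, str(e), None


# Syntactically valid JSON can still fail the schema check:
ok, err, fact = validate('{"question": "2+2?"}')
assert not ok and "answer" in str(err)  # missing required field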
persistence/illm.py (3 changes: 2 additions & 1 deletion)
@@ -12,8 +12,9 @@ async def generate(
         prompt: str,
         res_object: type[T],
         validation_callback: Callable[[Optional[str]], tuple[bool, Optional[str], Optional[T]]],
-        temperature: float = 0,
         max_tokens: Optional[int] = None,
+        temperature: float = 0,
+        validate_json: bool = False,
         _previous_result: Optional[str] = None,
         _retries_remaining: Optional[int] = None,
         _validation_error: Optional[str] = None,
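For context, a sketch of the full interface method after this change; the class name ILlm, the ABC scaffolding, and the Optional[T] return type are assumptions inferred from the file name and call sites, not shown in the diff:

from abc import ABC, abstractmethod
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


class ILlm(ABC):  # class name assumed from the file name illm.py
    @abstractmethod
    async def generate(
        self,
        prompt: str,
        res_object: type[T],
        validation_callback: Callable[
            [Optional[str]], tuple[bool, Optional[str], Optional[T]]
        ],
        max_tokens: Optional[int] = None,
        temperature: float = 0,
        validate_json: bool = False,  # new: request native JSON output from the backend
        _previous_result: Optional[str] = None,
        _retries_remaining: Optional[int] = None,
        _validation_error: Optional[str] = None,
    ) -> Optional[T]:  # return type assumed
        ...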
persistence/openai.py (9 changes: 8 additions & 1 deletion)
@@ -24,8 +24,9 @@ async def generate(
         prompt: str,
         res_object: type[T],
         validation_callback: Callable[[Optional[str]], tuple[bool, Optional[str], Optional[T]]],
-        temperature: float = 0,
         max_tokens: Optional[int] = None,
+        temperature: float = 0,
+        validate_json: bool = False,
         _previous_result: Optional[str] = None,
         _retries_remaining: Optional[int] = None,
         _validation_error: Optional[str] = None,
@@ -59,13 +60,18 @@ async def generate(
             )
         )

+        extra = {}
+        if validate_json:
+            extra["response_format"] = {"type": "json_object"}
+
         # Generate
         client = self._use_client()
         res = await client.chat.completions.create(
             max_tokens=max_tokens,
             messages=messages,
             model=self._config.model,
             temperature=temperature,
+            **extra,
         )
         res_content = res.choices[0].message.content  # type: ignore

@@ -81,6 +87,7 @@
                 prompt=prompt,
                 res_object=res_object,
                 temperature=temperature,
+                validate_json=validate_json,
                 validation_callback=validation_callback,
                 _previous_result=res_content,
                 _retries_remaining=_retries_remaining - 1,
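The native validation this commit relies on is OpenAI's JSON mode: passing response_format={"type": "json_object"} to chat.completions.create makes the API guarantee that the completion parses as JSON, so no markdown fences ever appear in the output. A standalone sketch (model name and prompts are illustrative, not taken from this repo):

import asyncio

from openai import AsyncOpenAI  # official openai SDK, v1+


async def main() -> None:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    res = await client.chat.completions.create(
        model="gpt-4o",  # illustrative model name
        messages=[
            # JSON mode requires the word "JSON" to appear in the messages
            {"role": "system", "content": "Reply with a JSON object."},
            {"role": "user", "content": "List three primary colors."},
        ],
        response_format={"type": "json_object"},
    )
    print(res.choices[0].message.content)  # bare JSON, no ```json fences


asyncio.run(main())

Building the extra dict and splatting it with **extra keeps response_format out of the request entirely when validate_json is False, which avoids sending the parameter to models or backends that do not accept it.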
