Skip to content

Commit

Permalink
Merge pull request #7 from 0xqt/retries-and-simple-validation
Browse files Browse the repository at this point in the history
Adding evals for retries/re-asking and simple validation
  • Loading branch information
jxnl authored Apr 18, 2024
2 parents b716af8 + c1ed7e2 commit ce53701
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 0 deletions.
49 changes: 49 additions & 0 deletions tests/test_retries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
from itertools import product
from pydantic import AfterValidator, BaseModel, Field
from typing import Annotated
from util import clients


def uppercase_validator(v):
if v.islower():
raise ValueError("Name must be ALL CAPS")
return v


class UserDetail(BaseModel):
name: Annotated[str, AfterValidator(uppercase_validator)] = Field(
..., description="The name of the user"
)
age: int


data = [
(
"Extract `jason is 12`",
("JASON", 12),
),
(
"Extract `danny is 125 years old`",
("DANNY", 125),
),
(
"Extract `DONALD is a 45 year old man`",
("DONALD", 45),
),
]


@pytest.mark.asyncio_cooperative
@pytest.mark.parametrize("client, data", product(clients, data))
async def test_retries(client, data):
query, expected = data
response = await client.create(
response_model=UserDetail,
messages=[
{"role": "user", "content": query},
],
max_retries=3,
)
assert response.name == expected[0]
assert response.age == expected[1]
55 changes: 55 additions & 0 deletions tests/test_simple_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pytest
from itertools import product
from typing import Annotated

from openai import OpenAI
from pydantic import BaseModel, BeforeValidator, ValidationError

import instructor
from instructor import llm_validator
from util import clients

sync_client = instructor.from_openai(OpenAI())


class QuestionAnswerNoEvil(BaseModel):
question: str
answer: Annotated[
str,
BeforeValidator(
llm_validator(
"don't say objectionable or sinful things",
client=sync_client,
)
),
]


data = [
(
"What is the meaning of life?",
"The according to the devil the meaning of live is to live a life of sin and debauchery.",
),
]


@pytest.mark.asyncio_cooperative
@pytest.mark.parametrize("client, data", product(clients, data))
async def test_simple_validation(client, data):
question, context = data

with pytest.raises(ValidationError):
resp = await client.create(
messages=[
{
"role": "system",
"content": "You are a system that answers questions based on the context. answer exactly what the question asks using the context.",
},
{
"role": "user",
"content": f"using the context: {context}\n\nAnswer the following question: {question}",
},
],
response_model=QuestionAnswerNoEvil,
max_retries=0,
)

0 comments on commit ce53701

Please sign in to comment.