-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #7 from 0xqt/retries-and-simple-validation
Adding evals for retries/re-asking and simple validation
- Loading branch information
Showing
2 changed files
with
104 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import pytest | ||
from itertools import product | ||
from pydantic import AfterValidator, BaseModel, Field | ||
from typing import Annotated | ||
from util import clients | ||
|
||
|
||
def uppercase_validator(v): | ||
if v.islower(): | ||
raise ValueError("Name must be ALL CAPS") | ||
return v | ||
|
||
|
||
class UserDetail(BaseModel): | ||
name: Annotated[str, AfterValidator(uppercase_validator)] = Field( | ||
..., description="The name of the user" | ||
) | ||
age: int | ||
|
||
|
||
data = [ | ||
( | ||
"Extract `jason is 12`", | ||
("JASON", 12), | ||
), | ||
( | ||
"Extract `danny is 125 years old`", | ||
("DANNY", 125), | ||
), | ||
( | ||
"Extract `DONALD is a 45 year old man`", | ||
("DONALD", 45), | ||
), | ||
] | ||
|
||
|
||
@pytest.mark.asyncio_cooperative | ||
@pytest.mark.parametrize("client, data", product(clients, data)) | ||
async def test_retries(client, data): | ||
query, expected = data | ||
response = await client.create( | ||
response_model=UserDetail, | ||
messages=[ | ||
{"role": "user", "content": query}, | ||
], | ||
max_retries=3, | ||
) | ||
assert response.name == expected[0] | ||
assert response.age == expected[1] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import pytest | ||
from itertools import product | ||
from typing import Annotated | ||
|
||
from openai import OpenAI | ||
from pydantic import BaseModel, BeforeValidator, ValidationError | ||
|
||
import instructor | ||
from instructor import llm_validator | ||
from util import clients | ||
|
||
sync_client = instructor.from_openai(OpenAI()) | ||
|
||
|
||
class QuestionAnswerNoEvil(BaseModel): | ||
question: str | ||
answer: Annotated[ | ||
str, | ||
BeforeValidator( | ||
llm_validator( | ||
"don't say objectionable or sinful things", | ||
client=sync_client, | ||
) | ||
), | ||
] | ||
|
||
|
||
data = [ | ||
( | ||
"What is the meaning of life?", | ||
"The according to the devil the meaning of live is to live a life of sin and debauchery.", | ||
), | ||
] | ||
|
||
|
||
@pytest.mark.asyncio_cooperative | ||
@pytest.mark.parametrize("client, data", product(clients, data)) | ||
async def test_simple_validation(client, data): | ||
question, context = data | ||
|
||
with pytest.raises(ValidationError): | ||
resp = await client.create( | ||
messages=[ | ||
{ | ||
"role": "system", | ||
"content": "You are a system that answers questions based on the context. answer exactly what the question asks using the context.", | ||
}, | ||
{ | ||
"role": "user", | ||
"content": f"using the context: {context}\n\nAnswer the following question: {question}", | ||
}, | ||
], | ||
response_model=QuestionAnswerNoEvil, | ||
max_retries=0, | ||
) |