Skip to content

Commit

Permalink
feat: added new mixin to modify partial parsing behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanleomk committed Nov 8, 2024
1 parent a54ea2d commit a353547
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 4 deletions.
20 changes: 17 additions & 3 deletions instructor/dsl/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ class MakeFieldsOptional:
pass


class LiteralPartialMixin:
pass


def _make_field_optional(
field: FieldInfo,
) -> tuple[Any, FieldInfo]:
Expand Down Expand Up @@ -127,9 +131,14 @@ def model_from_chunks(
) -> Generator[T_Model, None, None]:
potential_object = ""
partial_model = cls.get_partial_model()
partial_mode = (
"on" if issubclass(cls, LiteralPartialMixin) else "trailing-strings"
)
for chunk in json_chunks:
potential_object += chunk
obj = from_json((potential_object.strip() or "{}").encode(), partial_mode="on")
obj = from_json(
(potential_object.strip() or "{}").encode(), partial_mode=partial_mode
)
obj = partial_model.model_validate(obj, strict=None, **kwargs)
yield obj

Expand All @@ -139,9 +148,14 @@ async def model_from_chunks_async(
) -> AsyncGenerator[T_Model, None]:
potential_object = ""
partial_model = cls.get_partial_model()
partial_mode = (
"on" if issubclass(cls, LiteralPartialMixin) else "trailing-strings"
)
async for chunk in json_chunks:
potential_object += chunk
obj = from_json((potential_object.strip() or "{}").encode(), partial_mode="on")
obj = from_json(
(potential_object.strip() or "{}").encode(), partial_mode=partial_mode
)
obj = partial_model.model_validate(obj, strict=None, **kwargs)
yield obj

Expand All @@ -163,7 +177,7 @@ def extract_json(
import json

resp = chunk.candidates[0].content.parts[0].function_call
resp_dict = type(resp).to_dict(resp) # type:ignore
resp_dict = type(resp).to_dict(resp) # type:ignore
if "args" in resp_dict:
yield json.dumps(resp_dict["args"])
elif chunk.choices:
Expand Down
124 changes: 123 additions & 1 deletion tests/llm/test_openai/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pydantic import BaseModel
import pytest
import instructor
from instructor.dsl.partial import Partial
from instructor.dsl.partial import Partial, LiteralPartialMixin

from .util import models, modes

Expand Down Expand Up @@ -81,3 +81,125 @@ async def test_partial_model_async(model, mode, aclient):
)
async for m in model:
assert isinstance(m, UserExtract)


@pytest.mark.parametrize("model,mode", product(models, modes))
def test_literal_partial_mixin(model, mode, client):
# Test with LiteralPartialMixin
class UserWithMixin(BaseModel, LiteralPartialMixin):
name: str
age: int

client = instructor.patch(client, mode=mode)
resp = client.chat.completions.create(
model=model,
response_model=Partial[UserWithMixin],
max_retries=2,
stream=True,
messages=[
{"role": "user", "content": "Jason Liu is 12 years old"},
],
)

changes = 0
last_name = None
last_age = None
for m in resp:
assert isinstance(m, UserWithMixin)
if m.name != last_name:
last_name = m.name
changes += 1
if m.age != last_age:
last_age = m.age
changes += 1

assert changes == 2 # Ensure we got at least one field update

class UserWithoutMixin(BaseModel):
name: str
age: int

resp = client.chat.completions.create(
model=model,
response_model=Partial[UserWithoutMixin],
max_retries=2,
stream=True,
messages=[
{"role": "user", "content": "Jason Liu is 12 years old"},
],
)

changes = 0
last_name = None
last_age = None
for m in resp:
assert isinstance(m, UserWithoutMixin)
if m.name != last_name:
last_name = m.name
changes += 1
if m.age != last_age:
last_age = m.age
changes += 1

assert changes > 3

@pytest.mark.asyncio
@pytest.mark.parametrize("model,mode", product(models, modes))
async def test_literal_partial_mixin_async(model, mode, client):
# Test with LiteralPartialMixin
class UserWithMixin(BaseModel, LiteralPartialMixin):
name: str
age: int

client = instructor.patch(client, mode=mode)
resp = await client.chat.completions.create(
model=model,
response_model=Partial[UserWithMixin],
max_retries=2,
stream=True,
messages=[
{"role": "user", "content": "Jason Liu is 12 years old"},
],
)

changes = 0
last_name = None
last_age = None
async for m in resp:
assert isinstance(m, UserWithMixin)
if m.name != last_name:
last_name = m.name
changes += 1
if m.age != last_age:
last_age = m.age
changes += 1

assert changes == 2 # Ensure we got at least one field update

class UserWithoutMixin(BaseModel):
name: str
age: int

resp = await client.chat.completions.create(
model=model,
response_model=Partial[UserWithoutMixin],
max_retries=2,
stream=True,
messages=[
{"role": "user", "content": "Jason Liu is 12 years old"},
],
)

changes = 0
last_name = None
last_age = None
async for m in resp:
assert isinstance(m, UserWithoutMixin)
if m.name != last_name:
last_name = m.name
changes += 1
if m.age != last_age:
last_age = m.age
changes += 1

assert changes > 3

0 comments on commit a353547

Please sign in to comment.