From a353547809878030ffea2f3999dd99974b7ab95e Mon Sep 17 00:00:00 2001
From: Ivan Leo
Date: Fri, 8 Nov 2024 12:58:15 +0800
Subject: [PATCH] feat: added new mixin to modify partial parsing behaviour

---
 instructor/dsl/partial.py            |  20 ++++-
 tests/llm/test_openai/test_stream.py | 124 ++++++++++++++++++++++++++-
 2 files changed, 140 insertions(+), 4 deletions(-)

diff --git a/instructor/dsl/partial.py b/instructor/dsl/partial.py
index e869dbdac..0f8ec798e 100644
--- a/instructor/dsl/partial.py
+++ b/instructor/dsl/partial.py
@@ -34,6 +34,10 @@ class MakeFieldsOptional:
     pass
 
 
+class LiteralPartialMixin:
+    pass
+
+
 def _make_field_optional(
     field: FieldInfo,
 ) -> tuple[Any, FieldInfo]:
@@ -127,9 +131,14 @@ def model_from_chunks(
     ) -> Generator[T_Model, None, None]:
         potential_object = ""
         partial_model = cls.get_partial_model()
+        partial_mode = (
+            "on" if issubclass(cls, LiteralPartialMixin) else "trailing-strings"
+        )
         for chunk in json_chunks:
             potential_object += chunk
-            obj = from_json((potential_object.strip() or "{}").encode(), partial_mode="on")
+            obj = from_json(
+                (potential_object.strip() or "{}").encode(), partial_mode=partial_mode
+            )
             obj = partial_model.model_validate(obj, strict=None, **kwargs)
             yield obj
 
@@ -139,9 +148,14 @@ async def model_from_chunks_async(
     ) -> AsyncGenerator[T_Model, None]:
         potential_object = ""
         partial_model = cls.get_partial_model()
+        partial_mode = (
+            "on" if issubclass(cls, LiteralPartialMixin) else "trailing-strings"
+        )
         async for chunk in json_chunks:
             potential_object += chunk
-            obj = from_json((potential_object.strip() or "{}").encode(), partial_mode="on")
+            obj = from_json(
+                (potential_object.strip() or "{}").encode(), partial_mode=partial_mode
+            )
             obj = partial_model.model_validate(obj, strict=None, **kwargs)
             yield obj
 
@@ -163,7 +177,7 @@ def extract_json(
                     import json
 
                     resp = chunk.candidates[0].content.parts[0].function_call
-                    resp_dict = type(resp).to_dict(resp) # type:ignore
+                    resp_dict = type(resp).to_dict(resp)  # type:ignore
                     if "args" in resp_dict:
                         yield json.dumps(resp_dict["args"])
                 elif chunk.choices:
diff --git a/tests/llm/test_openai/test_stream.py b/tests/llm/test_openai/test_stream.py
index 4110cc2c4..9082991b1 100644
--- a/tests/llm/test_openai/test_stream.py
+++ b/tests/llm/test_openai/test_stream.py
@@ -3,7 +3,7 @@
 from pydantic import BaseModel
 import pytest
 import instructor
-from instructor.dsl.partial import Partial
+from instructor.dsl.partial import Partial, LiteralPartialMixin
 from .util import models, modes
 
 
@@ -81,3 +81,125 @@ async def test_partial_model_async(model, mode, aclient):
     )
     async for m in model:
         assert isinstance(m, UserExtract)
+
+
+@pytest.mark.parametrize("model,mode", product(models, modes))
+def test_literal_partial_mixin(model, mode, client):
+    # Test with LiteralPartialMixin
+    class UserWithMixin(BaseModel, LiteralPartialMixin):
+        name: str
+        age: int
+
+    client = instructor.patch(client, mode=mode)
+    resp = client.chat.completions.create(
+        model=model,
+        response_model=Partial[UserWithMixin],
+        max_retries=2,
+        stream=True,
+        messages=[
+            {"role": "user", "content": "Jason Liu is 12 years old"},
+        ],
+    )
+
+    changes = 0
+    last_name = None
+    last_age = None
+    for m in resp:
+        assert isinstance(m, UserWithMixin)
+        if m.name != last_name:
+            last_name = m.name
+            changes += 1
+        if m.age != last_age:
+            last_age = m.age
+            changes += 1
+
+    assert changes == 2  # With the mixin, each field should change exactly once
+
+    class UserWithoutMixin(BaseModel):
+        name: str
+        age: int
+
+    resp = client.chat.completions.create(
+        model=model,
+        response_model=Partial[UserWithoutMixin],
+        max_retries=2,
+        stream=True,
+        messages=[
+            {"role": "user", "content": "Jason Liu is 12 years old"},
+        ],
+    )
+
+    changes = 0
+    last_name = None
+    last_age = None
+    for m in resp:
+        assert isinstance(m, UserWithoutMixin)
+        if m.name != last_name:
+            last_name = m.name
+            changes += 1
+        if m.age != last_age:
+            last_age = m.age
+            changes += 1
+
+    assert changes > 3
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model,mode", product(models, modes))
+async def test_literal_partial_mixin_async(model, mode, aclient):
+    # Test with LiteralPartialMixin
+    class UserWithMixin(BaseModel, LiteralPartialMixin):
+        name: str
+        age: int
+
+    client = instructor.patch(aclient, mode=mode)
+    resp = await client.chat.completions.create(
+        model=model,
+        response_model=Partial[UserWithMixin],
+        max_retries=2,
+        stream=True,
+        messages=[
+            {"role": "user", "content": "Jason Liu is 12 years old"},
+        ],
+    )
+
+    changes = 0
+    last_name = None
+    last_age = None
+    async for m in resp:
+        assert isinstance(m, UserWithMixin)
+        if m.name != last_name:
+            last_name = m.name
+            changes += 1
+        if m.age != last_age:
+            last_age = m.age
+            changes += 1
+
+    assert changes == 2  # With the mixin, each field should change exactly once
+
+    class UserWithoutMixin(BaseModel):
+        name: str
+        age: int
+
+    resp = await client.chat.completions.create(
+        model=model,
+        response_model=Partial[UserWithoutMixin],
+        max_retries=2,
+        stream=True,
+        messages=[
+            {"role": "user", "content": "Jason Liu is 12 years old"},
+        ],
+    )
+
+    changes = 0
+    last_name = None
+    last_age = None
+    async for m in resp:
+        assert isinstance(m, UserWithoutMixin)
+        if m.name != last_name:
+            last_name = m.name
+            changes += 1
+        if m.age != last_age:
+            last_age = m.age
+            changes += 1
+
+    assert changes > 3
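
For context, a minimal usage sketch of the behaviour the new mixin enables: it assumes an OpenAI client with credentials configured, the model name is a placeholder, and the imports mirror the test file above.

    from openai import OpenAI
    from pydantic import BaseModel

    import instructor
    from instructor.dsl.partial import Partial, LiteralPartialMixin


    class User(BaseModel, LiteralPartialMixin):
        name: str
        age: int


    client = instructor.patch(OpenAI())

    # With LiteralPartialMixin, model_from_chunks parses chunks with
    # partial_mode="on", so a field only appears once its value is complete
    # and each field changes at most once across the stream.
    stream = client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model name
        response_model=Partial[User],
        stream=True,
        messages=[{"role": "user", "content": "Jason Liu is 12 years old"}],
    )

    for partial_user in stream:
        print(partial_user)

Without the mixin, model_from_chunks now falls back to partial_mode="trailing-strings", so partially streamed strings are surfaced and the same field can change value on every chunk, which is what the non-mixin branches of the tests above assert.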