diff --git a/docs/concepts/partial.md b/docs/concepts/partial.md index a83cd7870..27237cf2d 100644 --- a/docs/concepts/partial.md +++ b/docs/concepts/partial.md @@ -5,6 +5,23 @@ description: Learn to utilize field-level streaming with Instructor and OpenAI f # Streaming Partial Responses +!!! info "Literal" + + If the data structure you're using has literal values, you need to make sure to import the `PartialStringHandlingMixin` mixin. + + ```python + from instructor.dsl.partial import PartialStringHandlingMixin + + class User(BaseModel, PartialStringHandlingMixin): + name: str + age: int + category: Literal["admin", "user", "guest"] + + // The rest of your code below + ``` + + This is because `jiter` throws an error otherwise if it encounters a incomplete Literal value while it's being streamed in + Field level streaming provides incremental snapshots of the current state of the response model that are immediately useable. This approach is particularly relevant in contexts like rendering UI components. Instructor supports this pattern by making use of `create_partial`. This lets us dynamically create a new class that treats all of the original model's fields as `Optional`. diff --git a/instructor/dsl/partial.py b/instructor/dsl/partial.py index 0f8ec798e..7efcc2c2e 100644 --- a/instructor/dsl/partial.py +++ b/instructor/dsl/partial.py @@ -34,7 +34,7 @@ class MakeFieldsOptional: pass -class LiteralPartialMixin: +class PartialStringHandlingMixin: pass @@ -132,7 +132,7 @@ def model_from_chunks( potential_object = "" partial_model = cls.get_partial_model() partial_mode = ( - "on" if issubclass(cls, LiteralPartialMixin) else "trailing-strings" + "on" if issubclass(cls, PartialStringHandlingMixin) else "trailing-strings" ) for chunk in json_chunks: potential_object += chunk @@ -149,7 +149,7 @@ async def model_from_chunks_async( potential_object = "" partial_model = cls.get_partial_model() partial_mode = ( - "on" if issubclass(cls, LiteralPartialMixin) else "trailing-strings" + "on" if issubclass(cls, PartialStringHandlingMixin) else "trailing-strings" ) async for chunk in json_chunks: potential_object += chunk diff --git a/tests/dsl/test_partial.py b/tests/dsl/test_partial.py index b18bfc3e9..5d3208b3e 100644 --- a/tests/dsl/test_partial.py +++ b/tests/dsl/test_partial.py @@ -1,6 +1,6 @@ # type: ignore[all] from pydantic import BaseModel, Field -from instructor.dsl.partial import Partial, LiteralPartialMixin +from instructor.dsl.partial import Partial, PartialStringHandlingMixin import pytest import instructor from openai import OpenAI, AsyncOpenAI @@ -116,7 +116,7 @@ async def async_generator(): def test_summary_extraction(): - class Summary(BaseModel, LiteralPartialMixin): + class Summary(BaseModel, PartialStringHandlingMixin): summary: str = Field(description="A detailed summary") client = OpenAI() @@ -143,7 +143,7 @@ class Summary(BaseModel, LiteralPartialMixin): @pytest.mark.asyncio async def test_summary_extraction_async(): - class Summary(BaseModel, LiteralPartialMixin): + class Summary(BaseModel, PartialStringHandlingMixin): summary: str = Field(description="A detailed summary") client = AsyncOpenAI() diff --git a/tests/llm/test_openai/test_stream.py b/tests/llm/test_openai/test_stream.py index 9082991b1..b4ef47b37 100644 --- a/tests/llm/test_openai/test_stream.py +++ b/tests/llm/test_openai/test_stream.py @@ -3,7 +3,7 @@ from pydantic import BaseModel import pytest import instructor -from instructor.dsl.partial import Partial, LiteralPartialMixin +from instructor.dsl.partial import Partial, PartialStringHandlingMixin from .util import models, modes @@ -85,8 +85,7 @@ async def test_partial_model_async(model, mode, aclient): @pytest.mark.parametrize("model,mode", product(models, modes)) def test_literal_partial_mixin(model, mode, client): - # Test with LiteralPartialMixin - class UserWithMixin(BaseModel, LiteralPartialMixin): + class UserWithMixin(BaseModel, PartialStringHandlingMixin): name: str age: int @@ -146,8 +145,7 @@ class UserWithoutMixin(BaseModel): @pytest.mark.asyncio @pytest.mark.parametrize("model,mode", product(models, modes)) async def test_literal_partial_mixin_async(model, mode, client): - # Test with LiteralPartialMixin - class UserWithMixin(BaseModel, LiteralPartialMixin): + class UserWithMixin(BaseModel, PartialStringHandlingMixin): name: str age: int