Skip to content

Commit

Permalink
fix: modified the naming of the mixin
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanleomk committed Nov 8, 2024
1 parent 89d62f2 commit b52ec6a
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 11 deletions.
17 changes: 17 additions & 0 deletions docs/concepts/partial.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,23 @@ description: Learn to utilize field-level streaming with Instructor and OpenAI f

# Streaming Partial Responses

!!! info "Literal"

If the data structure you're using has literal values, you need to make sure to import the `PartialStringHandlingMixin` mixin.

```python
from instructor.dsl.partial import PartialStringHandlingMixin

class User(BaseModel, PartialStringHandlingMixin):
name: str
age: int
category: Literal["admin", "user", "guest"]

// The rest of your code below
```

This is because `jiter` throws an error otherwise if it encounters a incomplete Literal value while it's being streamed in

Field level streaming provides incremental snapshots of the current state of the response model that are immediately useable. This approach is particularly relevant in contexts like rendering UI components.

Instructor supports this pattern by making use of `create_partial`. This lets us dynamically create a new class that treats all of the original model's fields as `Optional`.
Expand Down
6 changes: 3 additions & 3 deletions instructor/dsl/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class MakeFieldsOptional:
pass


class LiteralPartialMixin:
class PartialStringHandlingMixin:
pass


Expand Down Expand Up @@ -132,7 +132,7 @@ def model_from_chunks(
potential_object = ""
partial_model = cls.get_partial_model()
partial_mode = (
"on" if issubclass(cls, LiteralPartialMixin) else "trailing-strings"
"on" if issubclass(cls, PartialStringHandlingMixin) else "trailing-strings"
)
for chunk in json_chunks:
potential_object += chunk
Expand All @@ -149,7 +149,7 @@ async def model_from_chunks_async(
potential_object = ""
partial_model = cls.get_partial_model()
partial_mode = (
"on" if issubclass(cls, LiteralPartialMixin) else "trailing-strings"
"on" if issubclass(cls, PartialStringHandlingMixin) else "trailing-strings"
)
async for chunk in json_chunks:
potential_object += chunk
Expand Down
6 changes: 3 additions & 3 deletions tests/dsl/test_partial.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# type: ignore[all]
from pydantic import BaseModel, Field
from instructor.dsl.partial import Partial, LiteralPartialMixin
from instructor.dsl.partial import Partial, PartialStringHandlingMixin
import pytest
import instructor
from openai import OpenAI, AsyncOpenAI
Expand Down Expand Up @@ -116,7 +116,7 @@ async def async_generator():


def test_summary_extraction():
class Summary(BaseModel, LiteralPartialMixin):
class Summary(BaseModel, PartialStringHandlingMixin):
summary: str = Field(description="A detailed summary")

client = OpenAI()
Expand All @@ -143,7 +143,7 @@ class Summary(BaseModel, LiteralPartialMixin):

@pytest.mark.asyncio
async def test_summary_extraction_async():
class Summary(BaseModel, LiteralPartialMixin):
class Summary(BaseModel, PartialStringHandlingMixin):
summary: str = Field(description="A detailed summary")

client = AsyncOpenAI()
Expand Down
8 changes: 3 additions & 5 deletions tests/llm/test_openai/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pydantic import BaseModel
import pytest
import instructor
from instructor.dsl.partial import Partial, LiteralPartialMixin
from instructor.dsl.partial import Partial, PartialStringHandlingMixin

from .util import models, modes

Expand Down Expand Up @@ -85,8 +85,7 @@ async def test_partial_model_async(model, mode, aclient):

@pytest.mark.parametrize("model,mode", product(models, modes))
def test_literal_partial_mixin(model, mode, client):
# Test with LiteralPartialMixin
class UserWithMixin(BaseModel, LiteralPartialMixin):
class UserWithMixin(BaseModel, PartialStringHandlingMixin):
name: str
age: int

Expand Down Expand Up @@ -146,8 +145,7 @@ class UserWithoutMixin(BaseModel):
@pytest.mark.asyncio
@pytest.mark.parametrize("model,mode", product(models, modes))
async def test_literal_partial_mixin_async(model, mode, client):
# Test with LiteralPartialMixin
class UserWithMixin(BaseModel, LiteralPartialMixin):
class UserWithMixin(BaseModel, PartialStringHandlingMixin):
name: str
age: int

Expand Down

0 comments on commit b52ec6a

Please sign in to comment.