Skip to content

Commit

Permalink
Merge branch 'main' into fix-cerebras
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanleomk authored Nov 14, 2024
2 parents 6620dfa + 78a1926 commit e832bc2
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 6 deletions.
17 changes: 17 additions & 0 deletions docs/concepts/partial.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,23 @@ description: Learn to utilize field-level streaming with Instructor and OpenAI f

# Streaming Partial Responses

!!! info "Literal"

    If the data structure you're using has `Literal` values, make sure your model inherits from the `PartialLiteralMixin` mixin.

```python
from typing import Literal

from pydantic import BaseModel

from instructor.dsl.partial import PartialLiteralMixin

class User(BaseModel, PartialLiteralMixin):
    name: str
    age: int
    category: Literal["admin", "user", "guest"]

# The rest of your code below
```

    This is because `jiter` otherwise throws an error if it encounters an incomplete Literal value while the response is being streamed in.

Field level streaming provides incremental snapshots of the current state of the response model that are immediately useable. This approach is particularly relevant in contexts like rendering UI components.

Instructor supports this pattern by making use of `create_partial`. This lets us dynamically create a new class that treats all of the original model's fields as `Optional`.
Expand Down
14 changes: 12 additions & 2 deletions instructor/dsl/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ class MakeFieldsOptional:
pass


class PartialLiteralMixin:
    """Opt-in marker for response models that contain ``Literal`` fields.

    Models inheriting from this mixin are streamed with
    ``partial_mode="on"`` instead of ``"trailing-strings"`` (see the
    ``issubclass`` check in ``model_from_chunks`` /
    ``model_from_chunks_async``), so an incomplete Literal value is never
    surfaced mid-stream.  The class carries no behavior of its own.
    """

    pass


def _make_field_optional(
field: FieldInfo,
) -> tuple[Any, FieldInfo]:
Expand Down Expand Up @@ -127,10 +131,13 @@ def model_from_chunks(
) -> Generator[T_Model, None, None]:
potential_object = ""
partial_model = cls.get_partial_model()
partial_mode = (
"on" if issubclass(cls, PartialLiteralMixin) else "trailing-strings"
)
for chunk in json_chunks:
potential_object += chunk
obj = from_json(
(potential_object.strip() or "{}").encode(), partial_mode="on"
(potential_object.strip() or "{}").encode(), partial_mode=partial_mode
)
obj = partial_model.model_validate(obj, strict=None, **kwargs)
yield obj
Expand All @@ -141,10 +148,13 @@ async def model_from_chunks_async(
) -> AsyncGenerator[T_Model, None]:
potential_object = ""
partial_model = cls.get_partial_model()
partial_mode = (
"on" if issubclass(cls, PartialLiteralMixin) else "trailing-strings"
)
async for chunk in json_chunks:
potential_object += chunk
obj = from_json(
(potential_object.strip() or "{}").encode(), partial_mode="on"
(potential_object.strip() or "{}").encode(), partial_mode=partial_mode
)
obj = partial_model.model_validate(obj, strict=None, **kwargs)
yield obj
Expand Down
6 changes: 3 additions & 3 deletions tests/dsl/test_partial.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# type: ignore[all]
from pydantic import BaseModel, Field
from instructor.dsl.partial import Partial
from instructor.dsl.partial import Partial, PartialLiteralMixin
import pytest
import instructor
from openai import OpenAI, AsyncOpenAI
Expand Down Expand Up @@ -116,7 +116,7 @@ async def async_generator():


def test_summary_extraction():
class Summary(BaseModel):
class Summary(BaseModel, PartialLiteralMixin):
summary: str = Field(description="A detailed summary")

client = OpenAI()
Expand All @@ -143,7 +143,7 @@ class Summary(BaseModel):

@pytest.mark.asyncio
async def test_summary_extraction_async():
class Summary(BaseModel):
class Summary(BaseModel, PartialLiteralMixin):
summary: str = Field(description="A detailed summary")

client = AsyncOpenAI()
Expand Down
123 changes: 122 additions & 1 deletion tests/llm/test_openai/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pydantic import BaseModel
import pytest
import instructor
from instructor.dsl.partial import Partial
from instructor.dsl.partial import Partial, PartialLiteralMixin

from .util import models, modes

Expand Down Expand Up @@ -81,3 +81,124 @@ async def test_partial_model_async(model, mode, aclient):
)
async for m in model:
assert isinstance(m, UserExtract)


@pytest.mark.parametrize("model,mode", product(models, modes))
def test_literal_partial_mixin(model, mode, client):
    # With the mixin, streaming uses partial_mode="on": a field only appears
    # once its value is complete, so each field changes exactly once.
    class UserWithMixin(BaseModel, PartialLiteralMixin):
        name: str
        age: int

    client = instructor.patch(client, mode=mode)
    resp = client.chat.completions.create(
        model=model,
        response_model=Partial[UserWithMixin],
        max_retries=2,
        stream=True,
        messages=[
            {"role": "user", "content": "Jason Liu is 12 years old"},
        ],
    )

    # Count how many times each field's value changes across streamed snapshots.
    changes = 0
    last_name = None
    last_age = None
    for m in resp:
        assert isinstance(m, UserWithMixin)
        if m.name != last_name:
            last_name = m.name
            changes += 1
        if m.age != last_age:
            last_age = m.age
            changes += 1

    assert changes == 2  # Exactly one update per field (None -> final value)

    # Without the mixin, trailing-strings mode surfaces partial string values,
    # so the same extraction should produce more incremental updates.
    class UserWithoutMixin(BaseModel):
        name: str
        age: int

    resp = client.chat.completions.create(
        model=model,
        response_model=Partial[UserWithoutMixin],
        max_retries=2,
        stream=True,
        messages=[
            {"role": "user", "content": "Jason Liu is 12 years old"},
        ],
    )

    changes = 0
    last_name = None
    last_age = None
    for m in resp:
        assert isinstance(m, UserWithoutMixin)
        if m.name != last_name:
            last_name = m.name
            changes += 1
        if m.age != last_age:
            last_age = m.age
            changes += 1

    assert changes > 3  # Incremental string fragments mean more than one change per field


@pytest.mark.asyncio
@pytest.mark.parametrize("model,mode", product(models, modes))
async def test_literal_partial_mixin_async(model, mode, aclient):
    # Async counterpart of test_literal_partial_mixin: with the mixin,
    # streaming uses partial_mode="on", so each field changes exactly once.
    class UserWithMixin(BaseModel, PartialLiteralMixin):
        name: str
        age: int

    client = instructor.patch(aclient, mode=mode)
    resp = await client.chat.completions.create(
        model=model,
        response_model=Partial[UserWithMixin],
        max_retries=2,
        stream=True,
        messages=[
            {"role": "user", "content": "Jason Liu is 12 years old"},
        ],
    )

    # Count how many times each field's value changes across streamed snapshots.
    changes = 0
    last_name = None
    last_age = None
    async for m in resp:
        assert isinstance(m, UserWithMixin)
        if m.name != last_name:
            last_name = m.name
            changes += 1
        if m.age != last_age:
            last_age = m.age
            changes += 1

    assert changes == 2  # Exactly one update per field (None -> final value)

    # Without the mixin, trailing-strings mode surfaces partial string values,
    # so the same extraction should produce more incremental updates.
    class UserWithoutMixin(BaseModel):
        name: str
        age: int

    resp = await client.chat.completions.create(
        model=model,
        response_model=Partial[UserWithoutMixin],
        max_retries=2,
        stream=True,
        messages=[
            {"role": "user", "content": "Jason Liu is 12 years old"},
        ],
    )

    changes = 0
    last_name = None
    last_age = None
    async for m in resp:
        assert isinstance(m, UserWithoutMixin)
        if m.name != last_name:
            last_name = m.name
            changes += 1
        if m.age != last_age:
            last_age = m.age
            changes += 1

    assert changes > 3  # Incremental string fragments mean more than one change per field

0 comments on commit e832bc2

Please sign in to comment.