Add (Async)StreamedResponse for multi-part responses (#383)
* Add disable_parallel_tool_use to Anthropic tool_choice arg

* Update cassettes for disable_parallel_tool_use arg

* uv add --dev pydeps

* Add (Async)StreamedResponse type

* Implement StreamedResponse for AnthropicChatModel

* Add make dep-diagram

* Move parsing logic to _parsing. Use in openai, litellm models

* Add TODOs for exceptions and parsing

* Use contains_string_type in MistralChatModel

* Fix pyright issue with Mistral staticmethod

* Add docs for StreamedResponse

* Handle StreamedResponse in message_to_openai_message

* Load .env env vars first for testing

* Add (a)consume. Fix agroupby to consume unconsumed group

* Fix non-consumed streamed output groups. Test openai StreamedResponse

* Add make test-vcr-once

* Add tests for StreamedResponse with AnthropicChatModel

* Fix mypy errors

* Add missing comment re finishing OutputStream group
jackmpcollins authored Nov 30, 2024
1 parent c7a2fe3 commit 0cc46d9
Showing 28 changed files with 1,320 additions and 141 deletions.
7 changes: 6 additions & 1 deletion .gitignore
@@ -1,4 +1,9 @@
# https://github.com/github/gitignore/blob/4488915eec0b3a45b5c63ead28f286819c0917de/Python.gitignore
## Custom

# Dependency diagram
magentic.svg

## Python https://github.com/github/gitignore/blob/4488915eec0b3a45b5c63ead28f286819c0917de/Python.gitignore

# Byte-compiled / optimized / DLL files
__pycache__/
9 changes: 9 additions & 0 deletions Makefile
@@ -43,6 +43,10 @@ testcov: test # Run tests and generate a coverage report
	@echo "building coverage html"
	uv run coverage html --show-contexts

.PHONY: test-vcr-once
test-vcr-once: # Run the tests and record new VCR cassettes
	uv run pytest -vv --record-mode=once

.PHONY: test-fix-vcr
test-fix-vcr: # Run the last failed tests and rewrite the VCR cassettes
	uv run pytest -vv --last-failed --last-failed-no-failures=none --record-mode=rewrite
@@ -55,5 +59,10 @@ docs: # Build the documentation
docs-serve: # Build and serve the documentation
	uv run mkdocs serve

.PHONY: dep-diagram
dep-diagram: # Generate a dependency diagram
	uv run pydeps src/magentic --no-show --only "magentic." --rmprefix "magentic." -x "magentic.logger" --exclude-exact "magentic.chat_model"
	open -a Arc magentic.svg

.PHONY: all
all: format lint typecheck test
44 changes: 44 additions & 0 deletions docs/streaming.md
@@ -79,3 +79,47 @@ for hero in create_superhero_team("The Food Dudes"):
# 4.03s : name='Captain Carrot' age=35 power='Super strength and agility from eating carrots' enemies=['The Sugar Squad', 'The Greasy Gang']
# 6.05s : name='Ice Cream Girl' age=25 power='Can create ice cream out of thin air' enemies=['The Hot Sauce Squad', 'The Healthy Eaters']
```

## StreamedResponse

Some LLMs have the ability to generate text output and make tool calls in the same response. This allows them to perform chain-of-thought reasoning or provide additional context to the user. In magentic, the `StreamedResponse` (or `AsyncStreamedResponse`) class can be used to request this type of output. This object is an iterable of `StreamedStr` (or `AsyncStreamedStr`) and `FunctionCall` instances.

!!! warning "Consuming StreamedStr"

    The `StreamedStr` object must be iterated over before the next item in the `StreamedResponse` is processed; otherwise the string output will be lost. This is because the `StreamedResponse` and `StreamedStr` share the same underlying generator, so advancing the `StreamedResponse` iterator skips over the `StreamedStr` items. The `StreamedStr` object caches its chunks internally, so they remain available after it has been iterated once.

In the example below, we request that the LLM generate a greeting and then call a function to get the weather for two cities. The `StreamedResponse` object is then iterated over to print the output, with the `StreamedStr` and `FunctionCall` items processed separately.

```python
from magentic import prompt, FunctionCall, StreamedResponse, StreamedStr


def get_weather(city: str) -> str:
    return f"The weather in {city} is 20°C."


@prompt(
    "Say hello, then get the weather for: {cities}",
    functions=[get_weather],
)
def describe_weather(cities: list[str]) -> StreamedResponse: ...


response = describe_weather(["Cape Town", "San Francisco"])
for item in response:
    if isinstance(item, StreamedStr):
        for chunk in item:
            # print the chunks as they are received
            print(chunk, sep="", end="")
        print()
    if isinstance(item, FunctionCall):
        # print the function call, then call it and print the result
        print(item)
        print(item())

# Hello! I'll get the weather for Cape Town and San Francisco for you.
# FunctionCall(<function get_weather at 0x1109825c0>, 'Cape Town')
# The weather in Cape Town is 20°C.
# FunctionCall(<function get_weather at 0x1109825c0>, 'San Francisco')
# The weather in San Francisco is 20°C.
```
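
An async variant works the same way using the `AsyncStreamedResponse` and `AsyncStreamedStr` classes this commit also adds. The sketch below is illustrative rather than taken from the docs diff; the names `describe_weather_async` and `main` are placeholders.

```python
import asyncio

from magentic import AsyncStreamedResponse, AsyncStreamedStr, FunctionCall, prompt


def get_weather(city: str) -> str:
    return f"The weather in {city} is 20°C."


@prompt(
    "Say hello, then get the weather for: {cities}",
    functions=[get_weather],
)
async def describe_weather_async(cities: list[str]) -> AsyncStreamedResponse: ...


async def main() -> None:
    response = await describe_weather_async(["Cape Town", "San Francisco"])
    async for item in response:
        if isinstance(item, AsyncStreamedStr):
            async for chunk in item:
                print(chunk, end="")  # print chunks as they are received
            print()
        if isinstance(item, FunctionCall):
            print(item)    # the function call
            print(item())  # call it and print the result


asyncio.run(main())
```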
4 changes: 4 additions & 0 deletions docs/structured-outputs.md
@@ -148,6 +148,10 @@ print(hero_defeated)

## Chain-of-Thought Prompting

!!! warning "StreamedResponse"

    It is now recommended to use `StreamedResponse` for chain-of-thought prompting, as this uses the LLM provider's native chain-of-thought capabilities. See [StreamedResponse](streaming.md#StreamedResponse) for more information.

Using a simple Python type as the return annotation can produce poor results, because the LLM has no opportunity to arrange its thoughts before answering. To allow the LLM to work through this "chain of thought" you can instead return a pydantic model with initial fields for explaining the final response.

```python hl_lines="5-9 20"
1 change: 1 addition & 0 deletions pyproject.toml
@@ -37,6 +37,7 @@ dev = [
"coverage>=7.6.4",
"pytest-mock>=3.14.0",
"vcrpy>=6.0.2",
"pydeps>=2.0.1",
]
docs = [
"blacken-docs>=1.16.0",
2 changes: 2 additions & 0 deletions src/magentic/__init__.py
@@ -1,5 +1,7 @@
from ._pydantic import ConfigDict as ConfigDict
from ._pydantic import with_config as with_config
from ._streamed_response import AsyncStreamedResponse as AsyncStreamedResponse
from ._streamed_response import StreamedResponse as StreamedResponse
from .chat_model.message import AnyMessage as AnyMessage
from .chat_model.message import AssistantMessage as AssistantMessage
from .chat_model.message import FunctionResultMessage as FunctionResultMessage
27 changes: 27 additions & 0 deletions src/magentic/_parsing.py
@@ -0,0 +1,27 @@
"""Functions for parsing and checking return types."""

from collections.abc import Iterable

from magentic._streamed_response import AsyncStreamedResponse, StreamedResponse
from magentic.function_call import AsyncParallelFunctionCall, ParallelFunctionCall
from magentic.streaming import AsyncStreamedStr, StreamedStr
from magentic.typing import is_any_origin_subclass


def contains_string_type(types: Iterable[type]) -> bool:
    return is_any_origin_subclass(
        types,
        (str, StreamedStr, AsyncStreamedStr, StreamedResponse, AsyncStreamedResponse),
    )


def contains_parallel_function_call_type(types: Iterable[type]) -> bool:
    return is_any_origin_subclass(
        types,
        (
            ParallelFunctionCall,
            AsyncParallelFunctionCall,
            StreamedResponse,
            AsyncStreamedResponse,
        ),
    )
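
As a rough illustration of how these helpers feed the tool-choice logic further down, the sketch below exercises them directly. It is not part of the diff, and assumes `is_any_origin_subclass` accepts bare (unsubscripted) generic classes as it does here.

```python
# Illustrative only: exercising the helpers added in _parsing.py
from magentic._parsing import (
    contains_parallel_function_call_type,
    contains_string_type,
)
from magentic.function_call import FunctionCall, ParallelFunctionCall
from magentic.streaming import StreamedStr

# A string-capable return annotation means the model may answer in plain text,
# so a forced tool choice must not be sent.
assert contains_string_type([str, FunctionCall])
assert not contains_string_type([FunctionCall])

# Parallel tool calls are only allowed when the annotation permits them.
assert contains_parallel_function_call_type([ParallelFunctionCall])
assert not contains_parallel_function_call_type([StreamedStr, FunctionCall])
```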
63 changes: 63 additions & 0 deletions src/magentic/_streamed_response.py
@@ -0,0 +1,63 @@
from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator
from typing import Any

from magentic.function_call import FunctionCall
from magentic.streaming import (
    AsyncStreamedStr,
    CachedAsyncIterable,
    CachedIterable,
    StreamedStr,
)


class StreamedResponse:
    """A streamed LLM response consisting of text output and tool calls.

    This is an iterable of StreamedStr and FunctionCall instances.

    Examples
    --------
    >>> from magentic import prompt, StreamedResponse, StreamedStr, FunctionCall
    >>>
    >>> def get_weather(city: str) -> str:
    >>>     return f"The weather in {city} is 20°C."
    >>>
    >>> @prompt(
    >>>     "Say hello, then get the weather for: {cities}",
    >>>     functions=[get_weather],
    >>> )
    >>> def describe_weather(cities: list[str]) -> StreamedResponse: ...
    >>>
    >>> response = describe_weather(["Cape Town", "San Francisco"])
    >>>
    >>> for item in response:
    >>>     if isinstance(item, StreamedStr):
    >>>         for chunk in item:
    >>>             print(chunk, sep="", end="")
    >>>         print()
    >>>     if isinstance(item, FunctionCall):
    >>>         print(item)
    >>>         print(item())
    Hello! I'll get the weather for Cape Town and San Francisco for you.
    FunctionCall(<function get_weather at 0x1109825c0>, 'Cape Town')
    The weather in Cape Town is 20°C.
    FunctionCall(<function get_weather at 0x1109825c0>, 'San Francisco')
    The weather in San Francisco is 20°C.
    """

    def __init__(self, stream: Iterable[StreamedStr | FunctionCall[Any]]):
        self._stream = CachedIterable(stream)

    def __iter__(self) -> Iterator[StreamedStr | FunctionCall[Any]]:
        yield from self._stream


class AsyncStreamedResponse:
    """Async version of `StreamedResponse`."""

    def __init__(self, stream: AsyncIterable[AsyncStreamedStr | FunctionCall[Any]]):
        self._stream = CachedAsyncIterable(stream)

    async def __aiter__(self) -> AsyncIterator[AsyncStreamedStr | FunctionCall[Any]]:
        async for item in self._stream:
            yield item
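
A note on the `CachedIterable` wrapper: it suggests a `StreamedResponse` can be iterated more than once, with later passes replaying cached items, consistent with the caching behaviour described in the docs above. A hypothetical check, reusing `describe_weather` from the docstring and assuming `CachedIterable` replays consumed items:

```python
response = describe_weather(["Cape Town"])
first = list(response)   # consumes the underlying stream and fills the cache
second = list(response)  # replays the cached items
assert [type(item) for item in first] == [type(item) for item in second]
```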
43 changes: 22 additions & 21 deletions src/magentic/chat_model/anthropic_chat_model.py
@@ -14,6 +14,7 @@

import filetype

from magentic._parsing import contains_parallel_function_call_type, contains_string_type
from magentic.chat_model.base import (
ChatModel,
aparse_stream,
@@ -48,19 +49,18 @@
    ParallelFunctionCall,
    _create_unique_id,
)
from magentic.streaming import (
    AsyncStreamedStr,
    StreamedStr,
)
from magentic.typing import is_any_origin_subclass
from magentic.vision import UserImageMessage
try:
    import anthropic
    from anthropic.lib.streaming import MessageStreamEvent
    from anthropic.lib.streaming._messages import accumulate_event
    from anthropic.types import MessageParam, ToolParam
    from anthropic.types.message_create_params import ToolChoice
    from anthropic.types import (
        MessageParam,
        ToolChoiceParam,
        ToolChoiceToolParam,
        ToolParam,
    )
except ImportError as error:
    msg = "To use AnthropicChatModel you must install the `anthropic` package using `pip install 'magentic[anthropic]'`."
    raise ImportError(msg) from error
@@ -157,6 +157,8 @@ def _(message: AssistantMessage[Any]) -> MessageParam:
        ],
    }

    # TODO: Add support for StreamedResponse here

    function_schema = function_schema_for_type(type(message.content))
    return {
        "role": AnthropicMessageRole.ASSISTANT.value,
@@ -227,7 +229,7 @@ def to_dict(self) -> ToolParam:
            "input_schema": self._function_schema.parameters,
        }

    def as_tool_choice(self) -> ToolChoice:
    def as_tool_choice(self, *, disable_parallel_tool_use: bool) -> ToolChoiceToolParam:
        return {"type": "tool", "name": self._function_schema.name}

@@ -353,14 +355,19 @@ def temperature(self) -> float | None:
    def _get_tool_choice(
        *,
        tool_schemas: Sequence[BaseFunctionToolSchema[Any]],
        allow_string_output: bool,
    ) -> ToolChoice | anthropic.NotGiven:
        output_types: Iterable[type],
    ) -> ToolChoiceParam | anthropic.NotGiven:
        """Create the tool choice argument."""
        if allow_string_output:
        if contains_string_type(output_types):
            return anthropic.NOT_GIVEN
        disable_parallel_tool_use = not contains_parallel_function_call_type(
            output_types
        )
        if len(tool_schemas) == 1:
            return tool_schemas[0].as_tool_choice()
        return {"type": "any"}
            return tool_schemas[0].as_tool_choice(
                disable_parallel_tool_use=disable_parallel_tool_use
            )
        return {"type": "any", "disable_parallel_tool_use": disable_parallel_tool_use}

    @overload
    def complete(
@@ -397,8 +404,6 @@ def complete(
        function_schemas = get_function_schemas(functions, output_types)
        tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas]

        allow_string_output = is_any_origin_subclass(output_types, (str, StreamedStr))

        system, messages = _extract_system_message(messages)

        response: Iterator[MessageStreamEvent] = self._client.messages.stream(
@@ -412,7 +417,7 @@
            temperature=_if_given(self.temperature),
            tools=[schema.to_dict() for schema in tool_schemas] or anthropic.NOT_GIVEN,
            tool_choice=self._get_tool_choice(
                tool_schemas=tool_schemas, allow_string_output=allow_string_output
                tool_schemas=tool_schemas, output_types=output_types
            ),
        ).__enter__()
        stream = OutputStream(
@@ -460,10 +465,6 @@ async def acomplete(
        function_schemas = get_async_function_schemas(functions, output_types)
        tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas]

        allow_string_output = is_any_origin_subclass(
            output_types, (str, AsyncStreamedStr)
        )

        system, messages = _extract_system_message(messages)

        response: AsyncIterator[
@@ -479,7 +480,7 @@
            temperature=_if_given(self.temperature),
            tools=[schema.to_dict() for schema in tool_schemas] or anthropic.NOT_GIVEN,
            tool_choice=self._get_tool_choice(
                tool_schemas=tool_schemas, allow_string_output=allow_string_output
                tool_schemas=tool_schemas, output_types=output_types
            ),
        ).__aenter__()
        stream = AsyncOutputStream(
13 changes: 12 additions & 1 deletion src/magentic/chat_model/base.py
@@ -7,6 +7,7 @@

from pydantic import ValidationError

from magentic._streamed_response import AsyncStreamedResponse, StreamedResponse
from magentic.chat_model.message import AssistantMessage, Message
from magentic.function_call import (
    AsyncParallelFunctionCall,
@@ -22,6 +23,7 @@
)


# TODO: Export all exceptions from `magentic.exceptions`
# TODO: Parent class with `output_message` attribute ?
class StringNotAllowedError(Exception):
"""Raised when a string is returned by the LLM but not allowed."""
@@ -79,6 +81,7 @@ def __init__(self, output_message: Message[Any], tool_call_id: str, tool_name: s
        self.tool_call_id = tool_call_id


# TODO: Move this to same file where it is raised
class ToolSchemaParseError(Exception):
"""Raised when the LLM output could not be parsed by the tool schema."""

Expand All @@ -100,21 +103,25 @@ def __init__(
        self.validation_error = validation_error


# TODO: Move this into _parsing
# TODO: Make this a stream class with a close method and context management
def parse_stream(stream: Iterator[Any], output_types: Iterable[type[R]]) -> R:
    """Parse and validate the LLM output stream against the allowed output types."""
    output_type_origins = [get_origin(type_) or type_ for type_ in output_types]
    # TODO: option to error/warn/ignore extra objects
    # TODO: warn for degenerate output types ?
    obj = next(stream)
    # TODO: Add type for mixed StreamedStr and FunctionCalls
    if isinstance(obj, StreamedStr):
        if StreamedResponse in output_type_origins:
            return cast(R, StreamedResponse(chain([obj], stream)))
        if StreamedStr in output_type_origins:
            return cast(R, obj)
        if str in output_type_origins:
            return cast(R, str(obj))
        raise StringNotAllowedError(obj.truncate(100))
    if isinstance(obj, FunctionCall):
        if StreamedResponse in output_type_origins:
            return cast(R, StreamedResponse(chain([obj], stream)))
        if ParallelFunctionCall in output_type_origins:
            return cast(R, ParallelFunctionCall(chain([obj], stream)))
        if FunctionCall in output_type_origins:
@@ -133,12 +140,16 @@ async def aparse_stream(
    output_type_origins = [get_origin(type_) or type_ for type_ in output_types]
    obj = await anext(stream)
    if isinstance(obj, AsyncStreamedStr):
        if AsyncStreamedResponse in output_type_origins:
            return cast(R, AsyncStreamedResponse(achain(async_iter([obj]), stream)))
        if AsyncStreamedStr in output_type_origins:
            return cast(R, obj)
        if str in output_type_origins:
            return cast(R, await obj.to_string())
        raise StringNotAllowedError(await obj.truncate(100))
    if isinstance(obj, FunctionCall):
        if AsyncStreamedResponse in output_type_origins:
            return cast(R, AsyncStreamedResponse(achain(async_iter([obj]), stream)))
        if AsyncParallelFunctionCall in output_type_origins:
            return cast(R, AsyncParallelFunctionCall(achain(async_iter([obj]), stream)))
        if FunctionCall in output_type_origins:
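
To make the dispatch above concrete: the first item in the stream decides the return type, and when `StreamedResponse` is among the allowed output types the stream is wrapped rather than consumed. A hypothetical check (not part of the diff), assuming `StreamedStr` can wrap any iterable of string chunks:

```python
from magentic._streamed_response import StreamedResponse
from magentic.chat_model.base import parse_stream
from magentic.streaming import StreamedStr

# A stream whose first (and only) item is a StreamedStr
stream = iter([StreamedStr(["Hello ", "world"])])
result = parse_stream(stream, [StreamedResponse])
# The StreamedStr is chained back onto the stream and wrapped
assert isinstance(result, StreamedResponse)
```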
3 changes: 3 additions & 0 deletions src/magentic/chat_model/function_schema.py
@@ -9,6 +9,7 @@
from pydantic import BaseModel, TypeAdapter, create_model

from magentic._pydantic import ConfigDict, get_pydantic_config, json_schema
from magentic._streamed_response import AsyncStreamedResponse, StreamedResponse
from magentic.function_call import (
    AsyncParallelFunctionCall,
    FunctionCall,
@@ -461,6 +462,8 @@ def serialize_args(self, value: FunctionCall[T]) -> str:
    FunctionCall,
    ParallelFunctionCall,
    AsyncParallelFunctionCall,
    StreamedResponse,
    AsyncStreamedResponse,
)

