-
Notifications
You must be signed in to change notification settings - Fork 102
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
(Async)StreamedResponse
for multi-part responses (#383)
* Add disable_parallel_tool_use to Anthropic tool_choice arg * Update cassettes for disable_parallel_tool_use arg * uv add --dev pydeps * Add (Async)StreamedResponse type * Implement StreamedResponse for AnthropicChatModel * Add make dep-diagram * Move parsing logic to _parsing. Use in openai, litellm models * Add TODOs for exceptions and parsing * Use contains_string_type in MistralChatModel * Fix pyright issue with Mistral staticmethod * Add docs for StreamedResponse * Handle StreamedResponse in message_to_openai_message * Load .env env vars first for testing * Add (a)consume. Fix agroupby to consume unconsumed group * Fix non-consumed streamed output groups. Test openai StreamedResponse * Add make test-vcr-once * Add tests for StreamedResponse with AnthropicChatModel * Fix mypy errors * Add missing comment re finishing OutputStream group
- Loading branch information
1 parent
c7a2fe3
commit 0cc46d9
Showing
28 changed files
with
1,320 additions
and
141 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
"""Functions for parsing and checking return types.""" | ||
|
||
from collections.abc import Iterable | ||
|
||
from magentic._streamed_response import AsyncStreamedResponse, StreamedResponse | ||
from magentic.function_call import AsyncParallelFunctionCall, ParallelFunctionCall | ||
from magentic.streaming import AsyncStreamedStr, StreamedStr | ||
from magentic.typing import is_any_origin_subclass | ||
|
||
|
||
def contains_string_type(types: Iterable[type]) -> bool: | ||
return is_any_origin_subclass( | ||
types, | ||
(str, StreamedStr, AsyncStreamedStr, StreamedResponse, AsyncStreamedResponse), | ||
) | ||
|
||
|
||
def contains_parallel_function_call_type(types: Iterable[type]) -> bool: | ||
return is_any_origin_subclass( | ||
types, | ||
( | ||
ParallelFunctionCall, | ||
AsyncParallelFunctionCall, | ||
StreamedResponse, | ||
AsyncStreamedResponse, | ||
), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator | ||
from typing import Any | ||
|
||
from magentic.function_call import FunctionCall | ||
from magentic.streaming import ( | ||
AsyncStreamedStr, | ||
CachedAsyncIterable, | ||
CachedIterable, | ||
StreamedStr, | ||
) | ||
|
||
|
||
class StreamedResponse: | ||
"""A streamed LLM response consisting of text output and tool calls. | ||
This is an iterable of StreamedStr and FunctionCall instances. | ||
Examples | ||
-------- | ||
>>> from magentic import prompt, StreamedResponse, StreamedStr, FunctionCall | ||
>>> | ||
>>> def get_weather(city: str) -> str: | ||
>>> return f"The weather in {city} is 20°C." | ||
>>> | ||
>>> @prompt( | ||
>>> "Say hello, then get the weather for: {cities}", | ||
>>> functions=[get_weather], | ||
>>> ) | ||
>>> def describe_weather(cities: list[str]) -> StreamedResponse: ... | ||
>>> | ||
>>> response = describe_weather(["Cape Town", "San Francisco"]) | ||
>>> | ||
>>> for item in response: | ||
>>> if isinstance(item, StreamedStr): | ||
>>> for chunk in item: | ||
>>> print(chunk, sep="", end="") | ||
>>> print() | ||
>>> if isinstance(item, FunctionCall): | ||
>>> print(item) | ||
>>> print(item()) | ||
Hello! I'll get the weather for Cape Town and San Francisco for you. | ||
FunctionCall(<function get_weather at 0x1109825c0>, 'Cape Town') | ||
The weather in Cape Town is 20°C. | ||
FunctionCall(<function get_weather at 0x1109825c0>, 'San Francisco') | ||
The weather in San Francisco is 20°C. | ||
""" | ||
|
||
def __init__(self, stream: Iterable[StreamedStr | FunctionCall[Any]]): | ||
self._stream = CachedIterable(stream) | ||
|
||
def __iter__(self) -> Iterator[StreamedStr | FunctionCall[Any]]: | ||
yield from self._stream | ||
|
||
|
||
class AsyncStreamedResponse: | ||
"""Async version of `StreamedResponse`.""" | ||
|
||
def __init__(self, stream: AsyncIterable[AsyncStreamedStr | FunctionCall[Any]]): | ||
self._stream = CachedAsyncIterable(stream) | ||
|
||
async def __aiter__(self) -> AsyncIterator[AsyncStreamedStr | FunctionCall[Any]]: | ||
async for item in self._stream: | ||
yield item |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.