diff --git a/src/magentic/chat_model/anthropic_chat_model.py b/src/magentic/chat_model/anthropic_chat_model.py index f459c8a4..0f83bdd0 100644 --- a/src/magentic/chat_model/anthropic_chat_model.py +++ b/src/magentic/chat_model/anthropic_chat_model.py @@ -7,7 +7,12 @@ from pydantic import ValidationError from magentic import FunctionResultMessage -from magentic.chat_model.base import ChatModel, StructuredOutputError +from magentic.chat_model.base import ( + ChatModel, + StructuredOutputError, + avalidate_str_content, + validate_str_content, +) from magentic.chat_model.function_schema import ( AsyncFunctionSchema, BaseFunctionSchema, @@ -335,17 +340,12 @@ def complete( last_content = response.content[-1] if last_content.type == "text": - if not allow_string_output: - msg = ( - "String was returned by model but not expected. You may need to update" - " your prompt to encourage the model to return a specific type." - f" {response.model_dump_json()}" - ) - raise StructuredOutputError(msg) - streamed_str = StreamedStr(last_content.text) - if streamed_str_in_output_types: - return AssistantMessage(streamed_str) # type: ignore[return-value] - return AssistantMessage(str(streamed_str)) + str_content = validate_str_content( + StreamedStr(last_content.text), + allow_string_output=allow_string_output, + streamed=streamed_str_in_output_types, + ) + return AssistantMessage(str_content) # type: ignore[return-value] if last_content.type == "tool_use": try: @@ -434,17 +434,12 @@ async def acomplete( last_content = response.content[-1] if last_content.type == "text": - if not allow_string_output: - msg = ( - "String was returned by model but not expected. You may need to update" - " your prompt to encourage the model to return a specific type." - f" {response.model_dump_json()}" - ) - raise StructuredOutputError(msg) - streamed_str = AsyncStreamedStr(async_iter(last_content.text)) - if async_streamed_str_in_output_types: - return AssistantMessage(streamed_str) # type: ignore[return-value] - return AssistantMessage(str(streamed_str)) + str_content = await avalidate_str_content( + AsyncStreamedStr(async_iter(last_content.text)), + allow_string_output=allow_string_output, + streamed=async_streamed_str_in_output_types, + ) + return AssistantMessage(str_content) # type: ignore[return-value] if last_content.type == "tool_use": try: diff --git a/src/magentic/chat_model/base.py b/src/magentic/chat_model/base.py index 8a5ca9a0..c2c7de9e 100644 --- a/src/magentic/chat_model/base.py +++ b/src/magentic/chat_model/base.py @@ -8,6 +8,7 @@ AssistantMessage, Message, ) +from magentic.streaming import AsyncStreamedStr, StreamedStr R = TypeVar("R") @@ -20,6 +21,41 @@ class StructuredOutputError(Exception): """Raised when the LLM output could not be parsed.""" +_STRING_NOT_EXPECTED_ERROR_MESSAGE = ( + "String was returned by model but not expected. You may need to update" + " your prompt to encourage the model to return a specific type." + " Model output: {model_output!r}" +) + + +def validate_str_content( + streamed_str: StreamedStr, *, allow_string_output: bool, streamed: bool +) -> StreamedStr | str: + """Raise error if string output not expected. Otherwise return correct string type.""" + if not allow_string_output: + msg = _STRING_NOT_EXPECTED_ERROR_MESSAGE.format( + model_output=streamed_str.truncate(100) + ) + raise StructuredOutputError(msg) + if streamed: + return streamed_str + return str(streamed_str) + + +async def avalidate_str_content( + async_streamed_str: AsyncStreamedStr, *, allow_string_output: bool, streamed: bool +) -> AsyncStreamedStr | str: + """Async version of `validate_str_content`.""" + if not allow_string_output: + msg = _STRING_NOT_EXPECTED_ERROR_MESSAGE.format( + model_output=await async_streamed_str.truncate(100) + ) + raise StructuredOutputError(msg) + if streamed: + return async_streamed_str + return await async_streamed_str.to_string() + + class ChatModel(ABC): """An LLM chat model.""" diff --git a/src/magentic/chat_model/litellm_chat_model.py b/src/magentic/chat_model/litellm_chat_model.py index 2a2bf435..fa5e7fb9 100644 --- a/src/magentic/chat_model/litellm_chat_model.py +++ b/src/magentic/chat_model/litellm_chat_model.py @@ -4,7 +4,12 @@ from pydantic import ValidationError -from magentic.chat_model.base import ChatModel, StructuredOutputError +from magentic.chat_model.base import ( + ChatModel, + StructuredOutputError, + avalidate_str_content, + validate_str_content, +) from magentic.chat_model.function_schema import ( FunctionCallFunctionSchema, async_function_schema_for_type, @@ -178,20 +183,17 @@ def complete( raise StructuredOutputError(msg) from e if first_chunk.choices[0].delta.content is not None: - if not allow_string_output: - msg = ( - "String was returned by model but not expected. You may need to update" - " your prompt to encourage the model to return a specific type." - ) - raise StructuredOutputError(msg) streamed_str = StreamedStr( chunk.choices[0].delta.get("content", None) for chunk in response if chunk.choices[0].delta.get("content", None) is not None ) - if streamed_str_in_output_types: - return AssistantMessage(streamed_str) # type: ignore[return-value] - return AssistantMessage(str(streamed_str)) + str_content = validate_str_content( + streamed_str, + allow_string_output=allow_string_output, + streamed=streamed_str_in_output_types, + ) + return AssistantMessage(str_content) # type: ignore[return-value] msg = f"Could not determine response type for first chunk: {first_chunk.model_dump_json()}" raise ValueError(msg) @@ -293,20 +295,17 @@ async def acomplete( raise StructuredOutputError(msg) from e if first_chunk.choices[0].delta.content is not None: - if not allow_string_output: - msg = ( - "String was returned by model but not expected. You may need to update" - " your prompt to encourage the model to return a specific type." - ) - raise StructuredOutputError(msg) async_streamed_str = AsyncStreamedStr( chunk.choices[0].delta.get("content", None) async for chunk in response if chunk.choices[0].delta.get("content", None) is not None ) - if async_streamed_str_in_output_types: - return AssistantMessage(async_streamed_str) # type: ignore[return-value] - return AssistantMessage(await async_streamed_str.to_string()) + str_content = await avalidate_str_content( + async_streamed_str, + allow_string_output=allow_string_output, + streamed=async_streamed_str_in_output_types, + ) + return AssistantMessage(str_content) # type: ignore[return-value] msg = f"Could not determine response type for first chunk: {first_chunk.model_dump_json()}" raise ValueError(msg) diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index 27b08d9a..95bda609 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -14,7 +14,12 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall from pydantic import ValidationError -from magentic.chat_model.base import ChatModel, StructuredOutputError +from magentic.chat_model.base import ( + ChatModel, + StructuredOutputError, + avalidate_str_content, + validate_str_content, +) from magentic.chat_model.function_schema import ( AsyncFunctionSchema, BaseFunctionSchema, @@ -474,20 +479,17 @@ def complete( response = chain([first_chunk], response) if first_chunk.choices[0].delta.content: - if not allow_string_output: - msg = ( - "String was returned by model but not expected. You may need to update" - " your prompt to encourage the model to return a specific type." - ) - raise StructuredOutputError(msg) streamed_str = StreamedStr( chunk.choices[0].delta.content for chunk in response if chunk.choices[0].delta.content is not None ) - if streamed_str_in_output_types: - return AssistantMessage(streamed_str) # type: ignore[return-value] - return AssistantMessage(str(streamed_str)) + str_content = validate_str_content( + streamed_str, + allow_string_output=allow_string_output, + streamed=streamed_str_in_output_types, + ) + return AssistantMessage(str_content) # type: ignore[return-value] if first_chunk.choices[0].delta.tool_calls: try: @@ -586,20 +588,17 @@ async def acomplete( response = achain(async_iter([first_chunk]), response) if first_chunk.choices[0].delta.content: - if not allow_string_output: - msg = ( - "String was returned by model but not expected. You may need to update" - " your prompt to encourage the model to return a specific type." - ) - raise StructuredOutputError(msg) async_streamed_str = AsyncStreamedStr( chunk.choices[0].delta.content async for chunk in response if chunk.choices[0].delta.content is not None ) - if async_streamed_str_in_output_types: - return AssistantMessage(async_streamed_str) # type: ignore[return-value] - return AssistantMessage(await async_streamed_str.to_string()) + str_content = await avalidate_str_content( + async_streamed_str, + allow_string_output=allow_string_output, + streamed=async_streamed_str_in_output_types, + ) + return AssistantMessage(str_content) # type: ignore[return-value] if first_chunk.choices[0].delta.tool_calls: try: diff --git a/src/magentic/streaming.py b/src/magentic/streaming.py index 6ef95db6..809dbe0f 100644 --- a/src/magentic/streaming.py +++ b/src/magentic/streaming.py @@ -1,4 +1,5 @@ import asyncio +import textwrap from collections.abc import AsyncIterable, Iterable from dataclasses import dataclass from itertools import chain, dropwhile @@ -199,6 +200,17 @@ def to_string(self) -> str: """Convert the streamed string to a string.""" return str(self) + def truncate(self, length: int) -> str: + """Truncate the streamed string to the specified length.""" + chunks = [] + current_length = 0 + for chunk in self._chunks: + chunks.append(chunk) + current_length += len(chunk) + if current_length > length: + break + return textwrap.shorten("".join(chunks), width=length) + class AsyncStreamedStr(AsyncIterable[str]): """Async version of `StreamedStr`.""" @@ -213,3 +225,14 @@ async def __aiter__(self) -> AsyncIterator[str]: async def to_string(self) -> str: """Convert the streamed string to a string.""" return "".join([item async for item in self]) + + async def truncate(self, length: int) -> str: + """Truncate the streamed string to the specified length.""" + chunks = [] + current_length = 0 + async for chunk in self._chunks: + chunks.append(chunk) + current_length += len(chunk) + if current_length > length: + break + return textwrap.shorten("".join(chunks), width=length) diff --git a/tests/chat_model/test_litellm_chat_model.py b/tests/chat_model/test_litellm_chat_model.py index d755df9d..d4bf8b70 100644 --- a/tests/chat_model/test_litellm_chat_model.py +++ b/tests/chat_model/test_litellm_chat_model.py @@ -50,7 +50,9 @@ def test_litellm_chat_model_metadata(litellm_success_callback_calls): chat_model = LitellmChatModel("gpt-3.5-turbo", metadata={"foo": "bar"}) assert chat_model.metadata == {"foo": "bar"} chat_model.complete(messages=[UserMessage("Say hello!")]) - callback_call = litellm_success_callback_calls[0] + # There are multiple callback calls due to streaming + # Take the last one because the first is occasionally from another test + callback_call = litellm_success_callback_calls[-1] assert callback_call["kwargs"]["litellm_params"]["metadata"] == {"foo": "bar"} diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 0c07694e..ff5a1123 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -124,6 +124,12 @@ def test_streamed_str_str(): assert str(streamed_str) == "Hello World" +def test_streamed_str_truncate(): + streamed_str = StreamedStr(["First", " Second", " Third"]) + assert streamed_str.truncate(length=12) == "First [...]" + assert streamed_str.truncate(length=99) == "First Second Third" + + @pytest.mark.asyncio async def test_async_streamed_str_iter(): aiter_chunks = async_iter(["Hello", " World"]) @@ -137,3 +143,10 @@ async def test_async_streamed_str_iter(): async def test_async_streamed_str_to_string(): async_streamed_str = AsyncStreamedStr(async_iter(["Hello", " World"])) assert await async_streamed_str.to_string() == "Hello World" + + +@pytest.mark.asyncio +async def test_async_streamed_str_truncate(): + async_streamed_str = AsyncStreamedStr(async_iter(["First", " Second", " Third"])) + assert await async_streamed_str.truncate(length=12) == "First [...]" + assert await async_streamed_str.truncate(length=99) == "First Second Third"