Include generated text in error message for "string not expected" (#195)
* Add truncate method to StreamedStr

* Add validate_str_content and use to standardize error message

* Repr model output in error message. Truncate at 100 chars instead of 50

* Fix: anthropic passed streamed=True instead of the variable

* Fix: flaky litellm metadata test
jackmpcollins authored Apr 29, 2024
1 parent cf730e1 commit b3cff4d
Showing 7 changed files with 129 additions and 62 deletions.
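The net effect: when the model replies with prose but a structured type was expected, the StructuredOutputError message now includes a truncated repr of what the model actually said. A minimal sketch of the new behavior (the prompt and the model's reply are hypothetical; prompt and StructuredOutputError are magentic's actual names):

from magentic import prompt
from magentic.chat_model.base import StructuredOutputError


@prompt("Count to three.")  # hypothetical prompt that tempts the model into prose
def count() -> list[int]: ...


try:
    count()
except StructuredOutputError as e:
    # Previously the message gave no hint of the model's reply; now it ends with
    # e.g. " Model output: 'Sure! One, two, three.'"
    print(e)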
41 changes: 18 additions & 23 deletions src/magentic/chat_model/anthropic_chat_model.py
@@ -7,7 +7,12 @@
 from pydantic import ValidationError
 
 from magentic import FunctionResultMessage
-from magentic.chat_model.base import ChatModel, StructuredOutputError
+from magentic.chat_model.base import (
+    ChatModel,
+    StructuredOutputError,
+    avalidate_str_content,
+    validate_str_content,
+)
 from magentic.chat_model.function_schema import (
     AsyncFunctionSchema,
     BaseFunctionSchema,
@@ -335,17 +340,12 @@ def complete(
         last_content = response.content[-1]
 
         if last_content.type == "text":
-            if not allow_string_output:
-                msg = (
-                    "String was returned by model but not expected. You may need to update"
-                    " your prompt to encourage the model to return a specific type."
-                    f" {response.model_dump_json()}"
-                )
-                raise StructuredOutputError(msg)
-            streamed_str = StreamedStr(last_content.text)
-            if streamed_str_in_output_types:
-                return AssistantMessage(streamed_str)  # type: ignore[return-value]
-            return AssistantMessage(str(streamed_str))
+            str_content = validate_str_content(
+                StreamedStr(last_content.text),
+                allow_string_output=allow_string_output,
+                streamed=streamed_str_in_output_types,
+            )
+            return AssistantMessage(str_content)  # type: ignore[return-value]
 
         if last_content.type == "tool_use":
             try:
@@ -434,17 +434,12 @@ async def acomplete(
         last_content = response.content[-1]
 
         if last_content.type == "text":
-            if not allow_string_output:
-                msg = (
-                    "String was returned by model but not expected. You may need to update"
-                    " your prompt to encourage the model to return a specific type."
-                    f" {response.model_dump_json()}"
-                )
-                raise StructuredOutputError(msg)
-            streamed_str = AsyncStreamedStr(async_iter(last_content.text))
-            if async_streamed_str_in_output_types:
-                return AssistantMessage(streamed_str)  # type: ignore[return-value]
-            return AssistantMessage(str(streamed_str))
+            str_content = await avalidate_str_content(
+                AsyncStreamedStr(async_iter(last_content.text)),
+                allow_string_output=allow_string_output,
+                streamed=async_streamed_str_in_output_types,
+            )
+            return AssistantMessage(str_content)  # type: ignore[return-value]
 
         if last_content.type == "tool_use":
             try:
36 changes: 36 additions & 0 deletions src/magentic/chat_model/base.py
@@ -8,6 +8,7 @@
     AssistantMessage,
     Message,
 )
+from magentic.streaming import AsyncStreamedStr, StreamedStr
 
 R = TypeVar("R")
 
@@ -20,6 +21,41 @@ class StructuredOutputError(Exception):
     """Raised when the LLM output could not be parsed."""
 
 
+_STRING_NOT_EXPECTED_ERROR_MESSAGE = (
+    "String was returned by model but not expected. You may need to update"
+    " your prompt to encourage the model to return a specific type."
+    " Model output: {model_output!r}"
+)
+
+
+def validate_str_content(
+    streamed_str: StreamedStr, *, allow_string_output: bool, streamed: bool
+) -> StreamedStr | str:
+    """Raise error if string output not expected. Otherwise return correct string type."""
+    if not allow_string_output:
+        msg = _STRING_NOT_EXPECTED_ERROR_MESSAGE.format(
+            model_output=streamed_str.truncate(100)
+        )
+        raise StructuredOutputError(msg)
+    if streamed:
+        return streamed_str
+    return str(streamed_str)
+
+
+async def avalidate_str_content(
+    async_streamed_str: AsyncStreamedStr, *, allow_string_output: bool, streamed: bool
+) -> AsyncStreamedStr | str:
+    """Async version of `validate_str_content`."""
+    if not allow_string_output:
+        msg = _STRING_NOT_EXPECTED_ERROR_MESSAGE.format(
+            model_output=await async_streamed_str.truncate(100)
+        )
+        raise StructuredOutputError(msg)
+    if streamed:
+        return async_streamed_str
+    return await async_streamed_str.to_string()
+
+
 class ChatModel(ABC):
     """An LLM chat model."""
 
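The two helpers above centralize the string-handling branch that each backend previously duplicated. A quick sketch of the three behaviors of validate_str_content, runnable without any LLM call (the inputs are illustrative):

from magentic.chat_model.base import StructuredOutputError, validate_str_content
from magentic.streaming import StreamedStr

# StreamedStr requested: the stream is returned untouched, still lazy.
out = validate_str_content(
    StreamedStr(["Hello", " World"]), allow_string_output=True, streamed=True
)
assert isinstance(out, StreamedStr)

# Plain str requested: the stream is consumed eagerly.
out = validate_str_content(
    StreamedStr(["Hello", " World"]), allow_string_output=True, streamed=False
)
assert out == "Hello World"

# String output not expected: raises, embedding the truncated model output.
try:
    validate_str_content(
        StreamedStr(["Unexpected prose"]), allow_string_output=False, streamed=False
    )
except StructuredOutputError as e:
    assert "Unexpected prose" in str(e)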
37 changes: 18 additions & 19 deletions src/magentic/chat_model/litellm_chat_model.py
@@ -4,7 +4,12 @@
 
 from pydantic import ValidationError
 
-from magentic.chat_model.base import ChatModel, StructuredOutputError
+from magentic.chat_model.base import (
+    ChatModel,
+    StructuredOutputError,
+    avalidate_str_content,
+    validate_str_content,
+)
 from magentic.chat_model.function_schema import (
     FunctionCallFunctionSchema,
     async_function_schema_for_type,
@@ -178,20 +183,17 @@ def complete(
                 raise StructuredOutputError(msg) from e
 
         if first_chunk.choices[0].delta.content is not None:
-            if not allow_string_output:
-                msg = (
-                    "String was returned by model but not expected. You may need to update"
-                    " your prompt to encourage the model to return a specific type."
-                )
-                raise StructuredOutputError(msg)
             streamed_str = StreamedStr(
                 chunk.choices[0].delta.get("content", None)
                 for chunk in response
                 if chunk.choices[0].delta.get("content", None) is not None
             )
-            if streamed_str_in_output_types:
-                return AssistantMessage(streamed_str)  # type: ignore[return-value]
-            return AssistantMessage(str(streamed_str))
+            str_content = validate_str_content(
+                streamed_str,
+                allow_string_output=allow_string_output,
+                streamed=streamed_str_in_output_types,
+            )
+            return AssistantMessage(str_content)  # type: ignore[return-value]
 
         msg = f"Could not determine response type for first chunk: {first_chunk.model_dump_json()}"
         raise ValueError(msg)
@@ -293,20 +295,17 @@ async def acomplete(
                 raise StructuredOutputError(msg) from e
 
         if first_chunk.choices[0].delta.content is not None:
-            if not allow_string_output:
-                msg = (
-                    "String was returned by model but not expected. You may need to update"
-                    " your prompt to encourage the model to return a specific type."
-                )
-                raise StructuredOutputError(msg)
             async_streamed_str = AsyncStreamedStr(
                 chunk.choices[0].delta.get("content", None)
                 async for chunk in response
                 if chunk.choices[0].delta.get("content", None) is not None
             )
-            if async_streamed_str_in_output_types:
-                return AssistantMessage(async_streamed_str)  # type: ignore[return-value]
-            return AssistantMessage(await async_streamed_str.to_string())
+            str_content = await avalidate_str_content(
+                async_streamed_str,
+                allow_string_output=allow_string_output,
+                streamed=async_streamed_str_in_output_types,
+            )
+            return AssistantMessage(str_content)  # type: ignore[return-value]
 
         msg = f"Could not determine response type for first chunk: {first_chunk.model_dump_json()}"
         raise ValueError(msg)
37 changes: 18 additions & 19 deletions src/magentic/chat_model/openai_chat_model.py
@@ -14,7 +14,12 @@
 from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
 from pydantic import ValidationError
 
-from magentic.chat_model.base import ChatModel, StructuredOutputError
+from magentic.chat_model.base import (
+    ChatModel,
+    StructuredOutputError,
+    avalidate_str_content,
+    validate_str_content,
+)
 from magentic.chat_model.function_schema import (
     AsyncFunctionSchema,
     BaseFunctionSchema,
@@ -474,20 +479,17 @@ def complete(
         response = chain([first_chunk], response)
 
         if first_chunk.choices[0].delta.content:
-            if not allow_string_output:
-                msg = (
-                    "String was returned by model but not expected. You may need to update"
-                    " your prompt to encourage the model to return a specific type."
-                )
-                raise StructuredOutputError(msg)
             streamed_str = StreamedStr(
                 chunk.choices[0].delta.content
                 for chunk in response
                 if chunk.choices[0].delta.content is not None
             )
-            if streamed_str_in_output_types:
-                return AssistantMessage(streamed_str)  # type: ignore[return-value]
-            return AssistantMessage(str(streamed_str))
+            str_content = validate_str_content(
+                streamed_str,
+                allow_string_output=allow_string_output,
+                streamed=streamed_str_in_output_types,
+            )
+            return AssistantMessage(str_content)  # type: ignore[return-value]
 
         if first_chunk.choices[0].delta.tool_calls:
             try:
@@ -586,20 +588,17 @@ async def acomplete(
         response = achain(async_iter([first_chunk]), response)
 
         if first_chunk.choices[0].delta.content:
-            if not allow_string_output:
-                msg = (
-                    "String was returned by model but not expected. You may need to update"
-                    " your prompt to encourage the model to return a specific type."
-                )
-                raise StructuredOutputError(msg)
             async_streamed_str = AsyncStreamedStr(
                 chunk.choices[0].delta.content
                 async for chunk in response
                 if chunk.choices[0].delta.content is not None
            )
-            if async_streamed_str_in_output_types:
-                return AssistantMessage(async_streamed_str)  # type: ignore[return-value]
-            return AssistantMessage(await async_streamed_str.to_string())
+            str_content = await avalidate_str_content(
+                async_streamed_str,
+                allow_string_output=allow_string_output,
+                streamed=async_streamed_str_in_output_types,
+            )
+            return AssistantMessage(str_content)  # type: ignore[return-value]
 
         if first_chunk.choices[0].delta.tool_calls:
             try:
23 changes: 23 additions & 0 deletions src/magentic/streaming.py
@@ -1,4 +1,5 @@
 import asyncio
+import textwrap
 from collections.abc import AsyncIterable, Iterable
 from dataclasses import dataclass
 from itertools import chain, dropwhile
@@ -199,6 +200,17 @@ def to_string(self) -> str:
         """Convert the streamed string to a string."""
         return str(self)
 
+    def truncate(self, length: int) -> str:
+        """Truncate the streamed string to the specified length."""
+        chunks = []
+        current_length = 0
+        for chunk in self._chunks:
+            chunks.append(chunk)
+            current_length += len(chunk)
+            if current_length > length:
+                break
+        return textwrap.shorten("".join(chunks), width=length)
+
 
 class AsyncStreamedStr(AsyncIterable[str]):
     """Async version of `StreamedStr`."""
@@ -213,3 +225,14 @@ async def __aiter__(self) -> AsyncIterator[str]:
     async def to_string(self) -> str:
         """Convert the streamed string to a string."""
         return "".join([item async for item in self])
+
+    async def truncate(self, length: int) -> str:
+        """Truncate the streamed string to the specified length."""
+        chunks = []
+        current_length = 0
+        async for chunk in self._chunks:
+            chunks.append(chunk)
+            current_length += len(chunk)
+            if current_length > length:
+                break
+        return textwrap.shorten("".join(chunks), width=length)
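Both truncate methods stop pulling chunks as soon as enough characters have arrived, so formatting an error message never forces a long (or unbounded) stream to finish. A small sketch of that property, using an illustrative infinite generator:

from magentic.streaming import StreamedStr

pulled = []


def endless():
    # An endless stream; truncate should abandon it after a few chunks.
    n = 0
    while True:
        n += 1
        pulled.append(n)
        yield f"chunk{n} "


assert StreamedStr(endless()).truncate(15) == "chunk1 [...]"
assert len(pulled) < 10  # only the first few chunks were ever consumed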
4 changes: 3 additions & 1 deletion tests/chat_model/test_litellm_chat_model.py
@@ -50,7 +50,9 @@ def test_litellm_chat_model_metadata(litellm_success_callback_calls):
     chat_model = LitellmChatModel("gpt-3.5-turbo", metadata={"foo": "bar"})
     assert chat_model.metadata == {"foo": "bar"}
     chat_model.complete(messages=[UserMessage("Say hello!")])
-    callback_call = litellm_success_callback_calls[0]
+    # There are multiple callback calls due to streaming
+    # Take the last one because the first is occasionally from another test
+    callback_call = litellm_success_callback_calls[-1]
     assert callback_call["kwargs"]["litellm_params"]["metadata"] == {"foo": "bar"}
 
 
13 changes: 13 additions & 0 deletions tests/test_streaming.py
@@ -124,6 +124,12 @@ def test_streamed_str_str():
     assert str(streamed_str) == "Hello World"
 
 
+def test_streamed_str_truncate():
+    streamed_str = StreamedStr(["First", " Second", " Third"])
+    assert streamed_str.truncate(length=12) == "First [...]"
+    assert streamed_str.truncate(length=99) == "First Second Third"
+
+
 @pytest.mark.asyncio
 async def test_async_streamed_str_iter():
     aiter_chunks = async_iter(["Hello", " World"])
@@ -137,3 +143,10 @@ async def test_async_streamed_str_iter():
 async def test_async_streamed_str_to_string():
     async_streamed_str = AsyncStreamedStr(async_iter(["Hello", " World"]))
     assert await async_streamed_str.to_string() == "Hello World"
+
+
+@pytest.mark.asyncio
+async def test_async_streamed_str_truncate():
+    async_streamed_str = AsyncStreamedStr(async_iter(["First", " Second", " Third"]))
+    assert await async_streamed_str.truncate(length=12) == "First [...]"
+    assert await async_streamed_str.truncate(length=99) == "First Second Third"
