From 9ad1c7f4e29a249a87e5be1c32933bb6e244adb1 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Mon, 18 Nov 2024 00:04:19 -0800 Subject: [PATCH 01/40] Copy discard_none_arguments to LitellmChatModel --- src/magentic/chat_model/litellm_chat_model.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/magentic/chat_model/litellm_chat_model.py b/src/magentic/chat_model/litellm_chat_model.py index 68ea72f5..827c28e3 100644 --- a/src/magentic/chat_model/litellm_chat_model.py +++ b/src/magentic/chat_model/litellm_chat_model.py @@ -1,6 +1,7 @@ from collections.abc import Callable, Iterable, Sequence +from functools import wraps from itertools import chain -from typing import Any, TypeVar, cast, overload +from typing import Any, ParamSpec, TypeVar, cast, overload from openai.types.chat import ChatCompletionToolChoiceOptionParam @@ -25,7 +26,6 @@ FunctionToolSchema, _aparse_streamed_tool_calls, _parse_streamed_tool_calls, - discard_none_arguments, message_to_openai_message, ) from magentic.function_call import ( @@ -48,9 +48,23 @@ raise ImportError(msg) from error +P = ParamSpec("P") R = TypeVar("R") +def discard_none_arguments(func: Callable[P, R]) -> Callable[P, R]: + """Decorator to discard function arguments with value `None`""" + + @wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> R: + non_none_kwargs = { + key: value for key, value in kwargs.items() if value is not None + } + return func(*args, **non_none_kwargs) # type: ignore[arg-type] + + return wrapped + + class LitellmChatModel(ChatModel): """An LLM chat model that uses the `litellm` python package.""" From 93d6cfdc2632987d0788d67a03e7f52e875d763a Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Mon, 18 Nov 2024 23:14:00 -0800 Subject: [PATCH 02/40] Switch OpenaiChatModel to use Stream approach --- src/magentic/chat_model/openai_chat_model.py | 528 ++++++++----------- 1 file changed, 212 insertions(+), 316 deletions(-) diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index c2eaf146..29ac8ebc 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -1,28 +1,31 @@ import base64 from collections.abc import ( - AsyncIterable, AsyncIterator, + Awaitable, Callable, Iterable, Iterator, Sequence, ) from enum import Enum -from functools import singledispatch, wraps -from itertools import chain, groupby -from typing import Any, Generic, Literal, ParamSpec, TypeVar, cast, overload +from functools import singledispatch +from itertools import chain +from typing import Any, Generic, Literal, TypeVar, cast, overload import filetype import openai +from openai.lib.streaming.chat import ( + AsyncChatCompletionStream, + AsyncChatCompletionStreamManager, + ChatCompletionStream, +) from openai.types.chat import ( ChatCompletionChunk, ChatCompletionMessageParam, - ChatCompletionMessageToolCallParam, ChatCompletionStreamOptionsParam, ChatCompletionToolChoiceOptionParam, ChatCompletionToolParam, ) -from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall from pydantic import ValidationError from magentic.chat_model.base import ( @@ -57,13 +60,8 @@ from magentic.streaming import ( AsyncStreamedStr, StreamedStr, - aapply, achain, - agroupby, - apeek, - apply, async_iter, - peek, ) from magentic.typing import is_any_origin_subclass, is_origin_subclass from magentic.vision import 
UserImageMessage @@ -238,233 +236,160 @@ def as_tool_choice(self) -> ChatCompletionToolChoiceOptionParam: def to_dict(self) -> ChatCompletionToolParam: return {"type": "function", "function": self._function_schema.dict()} - def matches(self, tool_call: ChoiceDeltaToolCall) -> bool: - return bool( - # TODO: Add back tool_call.type == "function" when LiteLLM Mistral fixed - # https://github.com/BerriAI/litellm/issues/2645 - tool_call.function and self._function_schema.name == tool_call.function.name - ) - - -# TODO: Generalize this to BaseToolSchema when that is created -BeseToolSchemaT = TypeVar("BeseToolSchemaT", bound=BaseFunctionToolSchema[Any]) +class OpenaiStream: + """Converts a stream of openai events into a stream of magentic objects.""" -def select_tool_schema( - tool_call: ChoiceDeltaToolCall, tools_schemas: list[BeseToolSchemaT] -) -> BeseToolSchemaT: - """Select the tool schema based on the response chunk.""" - for tool_schema in tools_schemas: - if tool_schema.matches(tool_call): - return tool_schema - - msg = f"Unknown tool call: {tool_call.model_dump_json()}" - raise ValueError(msg) # TODO: Create `UnknownToolCallError` for this - - -class FunctionToolSchema(BaseFunctionToolSchema[FunctionSchema[T]]): - def parse_tool_call(self, chunks: Iterable[ChoiceDeltaToolCall]) -> T: - return self._function_schema.parse_args( - chunk.function.arguments - for chunk in chunks - if chunk.function and chunk.function.arguments is not None - ) - - -class AsyncFunctionToolSchema(BaseFunctionToolSchema[AsyncFunctionSchema[T]]): - async def aparse_tool_call(self, chunks: AsyncIterable[ChoiceDeltaToolCall]) -> T: - return await self._function_schema.aparse_args( - chunk.function.arguments - async for chunk in chunks - if chunk.function and chunk.function.arguments is not None - ) - - -def _get_tool_call_id_for_chunk(tool_call: ChoiceDeltaToolCall) -> Any: - """Returns an id that is consistent for chunks from the same tool_call.""" - # openai keeps index consistent for chunks from the same tool_call, but id is null - # mistral has null index, but keeps id consistent - return tool_call.index if tool_call.index is not None else tool_call.id - - -def _iter_streamed_tool_calls( - response: Iterable[ChatCompletionChunk], -) -> Iterator[Iterator[ChoiceDeltaToolCall]]: - """Group tool_call chunks into separate iterators.""" - all_tool_call_chunks = ( - tool_call - for chunk in response - if chunk.choices and chunk.choices[0].delta.tool_calls - for tool_call in chunk.choices[0].delta.tool_calls - ) - for _, tool_call_chunks in groupby( - all_tool_call_chunks, _get_tool_call_id_for_chunk - ): - yield tool_call_chunks - - -async def _aiter_streamed_tool_calls( - response: AsyncIterable[ChatCompletionChunk], -) -> AsyncIterator[AsyncIterator[ChoiceDeltaToolCall]]: - """Async version of `_iter_streamed_tool_calls`.""" - all_tool_call_chunks = ( - tool_call - async for chunk in response - if chunk.choices and chunk.choices[0].delta.tool_calls - for tool_call in chunk.choices[0].delta.tool_calls - ) - async for _, tool_call_chunks in agroupby( - all_tool_call_chunks, _get_tool_call_id_for_chunk + def __init__( + self, stream: ChatCompletionStream, function_schemas: list[FunctionSchema[Any]] ): - yield tool_call_chunks - - -def _parse_streamed_tool_calls( - response: Iterable[ChatCompletionChunk], - tool_schemas: list[FunctionToolSchema[T]], -) -> Iterator[T]: - cached_response: list[ChatCompletionChunk] = [] - response = apply(cached_response.append, response) - try: - for tool_call_chunks in 
_iter_streamed_tool_calls(response): - first_chunk, tool_call_chunks = peek(tool_call_chunks) - tool_schema = select_tool_schema(first_chunk, tool_schemas) - tool_call = tool_schema.parse_tool_call(tool_call_chunks) - yield tool_call - # TODO: Catch/raise unknown tool call error here - except ValidationError as e: - raw_message = _join_streamed_tool_calls_to_message(cached_response) - raise ToolSchemaParseError( - output_message=raw_message, - tool_call_id=raw_message.content["tool_calls"][0]["id"], # type: ignore[index,unused-ignore] - validation_error=e, - ) from e - - -async def _aparse_streamed_tool_calls( - response: AsyncIterable[ChatCompletionChunk], - tool_schemas: list[AsyncFunctionToolSchema[T]], -) -> AsyncIterator[T]: - cached_response: list[ChatCompletionChunk] = [] - response = aapply(cached_response.append, response) - try: - async for tool_call_chunks in _aiter_streamed_tool_calls(response): - first_chunk, tool_call_chunks = await apeek(tool_call_chunks) - tool_schema = select_tool_schema(first_chunk, tool_schemas) - tool_call = await tool_schema.aparse_tool_call(tool_call_chunks) - yield tool_call - # TODO: Catch/raise unknown tool call error here - except ValidationError as e: - raw_message = _join_streamed_tool_calls_to_message(cached_response) - raise ToolSchemaParseError( - output_message=raw_message, - tool_call_id=raw_message.content["tool_calls"][0]["id"], # type: ignore[index,unused-ignore] - validation_error=e, - ) from e - - -def _join_streamed_tool_call( - tool_call_deltas: Iterable[ChoiceDeltaToolCall], -) -> ChatCompletionMessageToolCallParam: - """Join chunks from a single streamed tool call into an OpenAI tool call dict.""" - tool_id: str | None = None - tool_type: Literal["function"] = "function" - function_name: list[str] = [] - function_arguments: list[str] = [] - for tool_call_delta in tool_call_deltas: - if tool_call_delta.id: - tool_id = tool_call_delta.id - if tool_call_delta.type: - tool_type = tool_call_delta.type - if tool_call_delta.function: - if tool_call_delta.function.name: - function_name.append(tool_call_delta.function.name) - if tool_call_delta.function.arguments: - function_arguments.append(tool_call_delta.function.arguments) - return { - "id": tool_id or _create_unique_id(), - "type": tool_type, - "function": { - "name": "".join(function_name), - "arguments": "".join(function_arguments), - }, - } - - -def _join_streamed_tool_calls_to_message( - response: Iterable[ChatCompletionChunk], - # TODO: Type as ChatCompletionAssistantMessageParam. 
Issue: https://github.com/pydantic/pydantic/issues/10105 -) -> _RawMessage[Any]: - """Join streamed tool calls into an OpenAI chat completion message.""" - return _RawMessage( - { - "role": OpenaiMessageRole.ASSISTANT.value, - "content": None, - "tool_calls": [ - _join_streamed_tool_call(tool_call_chunks) - for tool_call_chunks in _iter_streamed_tool_calls(response) - ], - } - ) - - -def _create_usage_ref( - response: Iterable[ChatCompletionChunk], -) -> tuple[list[Usage], Iterator[ChatCompletionChunk]]: - """Returns a pointer to a Usage object that is created at the end of the response.""" - usage_ref: list[Usage] = [] - - def generator( - response: Iterable[ChatCompletionChunk], - ) -> Iterator[ChatCompletionChunk]: - for chunk in response: - if chunk.usage: - usage = Usage( - input_tokens=chunk.usage.prompt_tokens, - output_tokens=chunk.usage.completion_tokens, + self._stream = stream + self._function_schemas = function_schemas + self._iterator = self.__stream__() + self.usage: Usage | None = None + + def __next__(self) -> StreamedStr | FunctionCall: + return self._iterator.__next__() + + def __iter__(self) -> Iterator[StreamedStr | FunctionCall]: + yield from self._iterator + + def __stream__(self) -> Iterator[StreamedStr | FunctionCall]: + transition = [next(self._stream)] + + def _streamed_str(stream: Iterator) -> StreamedStr: + def _group(stream: Iterator) -> Iterator: + for event in stream: + if event.type == "content.delta": + yield event.delta + elif event.type == "content.done": + transition.append(event) + return + + return StreamedStr(_group(stream)) + + def _function_call(transition_item, stream: Iterator) -> FunctionCall: + def _group(stream: Iterator) -> Iterator: + for event in stream: + if event.type == "tool_calls.function.arguments.delta": + yield event.arguments_delta + elif event.type == "tool_calls.function.arguments.done": + transition.append(event) + return + + # TODO: Tidy matching function schema. 
Include Mistral fix + for function_schema in self._function_schemas: + if function_schema.name == transition_item.name: + break + # TODO: Catch/raise unknown tool call error here + try: # TODO: Tidy catching of error here to DRY with async + return function_schema.parse_args(_group(stream)) + except ValidationError as e: + raw_message = self._stream.current_completion_snapshot.choices[ + 0 + ].message.model_dump() + raise ToolSchemaParseError( + output_message=_RawMessage(raw_message), + tool_call_id=raw_message.content["tool_calls"][0]["id"], # type: ignore[index,unused-ignore] + validation_error=e, + ) from e + + while transition: + transition_item = transition.pop() + if transition_item.type == "content.delta": + yield _streamed_str(self._stream) + elif transition_item.type == "tool_calls.function.arguments.delta": + yield _function_call(transition_item, self._stream) + elif transition_item.type == "chunk" and transition_item.chunk.usage: + self.usage = Usage( + input_tokens=transition_item.chunk.usage.prompt_tokens, + output_tokens=transition_item.chunk.usage.completion_tokens, ) - usage_ref.append(usage) - yield chunk - - return usage_ref, generator(response) - - -def _create_usage_ref_async( - response: AsyncIterable[ChatCompletionChunk], -) -> tuple[list[Usage], AsyncIterator[ChatCompletionChunk]]: - """Async version of `_create_usage_ref`.""" - usage_ref: list[Usage] = [] - - async def agenerator( - response: AsyncIterable[ChatCompletionChunk], - ) -> AsyncIterator[ChatCompletionChunk]: - async for chunk in response: - if chunk.usage: - usage = Usage( - input_tokens=chunk.usage.prompt_tokens, - output_tokens=chunk.usage.completion_tokens, - ) - usage_ref.append(usage) - yield chunk + elif new_transition_item := next(self._stream, None): + transition.append(new_transition_item) - return usage_ref, agenerator(response) + def close(self): + self._stream.close() -P = ParamSpec("P") -R = TypeVar("R") +class OpenaiAsyncStream: + """Converts an async stream of openai events into an async stream of magentic objects.""" + def __init__( + self, + stream: AsyncChatCompletionStream, + function_schemas: list[AsyncFunctionSchema[Any]], + ): + self._stream = stream + self._function_schemas = function_schemas + self._aiterator = self.__stream__() + self.usage: Usage | None = None + + async def __anext__(self) -> AsyncStreamedStr | FunctionCall: + return await self._aiterator.__anext__() + + async def __aiter__(self) -> AsyncIterator[AsyncStreamedStr | FunctionCall]: + async for item in self._aiterator: + yield item + + async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | FunctionCall]: + transition = [await anext(self._stream)] + + def _streamed_str(stream: AsyncIterator) -> AsyncStreamedStr: + async def _group(stream: AsyncIterator) -> AsyncIterator: + async for event in stream: + if event.type == "content.delta": + yield event.delta + elif event.type == "content.done": + transition.append(event) + return + + return AsyncStreamedStr(_group(stream)) + + async def _function_call( + transition_item, stream: AsyncIterator + ) -> FunctionCall: + async def _group(stream: AsyncIterator) -> AsyncIterator: + async for event in stream: + if event.type == "tool_calls.function.arguments.delta": + yield event.arguments_delta + elif event.type == "tool_calls.function.arguments.done": + transition.append(event) + return + + # TODO: Tidy matching function schema. 
Include Mistral fix + for function_schema in self._function_schemas: + if function_schema.name == transition_item.name: + break + # TODO: Catch/raise unknown tool call error here + try: # TODO: Tidy catching of error here to DRY with async + return await function_schema.aparse_args(_group(stream)) + except ValidationError as e: + raw_message = self._stream.current_completion_snapshot.choices[ + 0 + ].message.model_dump() + raise ToolSchemaParseError( + output_message=_RawMessage(raw_message), + tool_call_id=raw_message.content["tool_calls"][0]["id"], # type: ignore[index,unused-ignore] + validation_error=e, + ) from e + + while transition: + transition_item = transition.pop() + if transition_item.type == "content.delta": + yield _streamed_str(self._stream) + elif transition_item.type == "tool_calls.function.arguments.delta": + yield await _function_call(transition_item, self._stream) + elif transition_item.type == "chunk" and transition_item.chunk.usage: + self.usage = Usage( + input_tokens=transition_item.chunk.usage.prompt_tokens, + output_tokens=transition_item.chunk.usage.completion_tokens, + ) + elif new_transition_item := await anext(self._stream, None): + transition.append(new_transition_item) -def discard_none_arguments(func: Callable[P, R]) -> Callable[P, R]: - """Decorator to discard function arguments with value `None`""" - - @wraps(func) - def wrapped(*args: P.args, **kwargs: P.kwargs) -> R: - non_none_kwargs = { - key: value for key, value in kwargs.items() if value is not None - } - return func(*args, **non_none_kwargs) # type: ignore[arg-type] - - return wrapped + async def close(self): + await self._stream.close() STR_OR_FUNCTIONCALL_TYPE = ( @@ -476,6 +401,8 @@ def wrapped(*args: P.args, **kwargs: P.kwargs) -> R: AsyncParallelFunctionCall, ) +R = TypeVar("R") + class OpenaiChatModel(ChatModel): """An LLM chat model that uses the `openai` python package.""" @@ -612,15 +539,13 @@ def complete( for type_ in output_types if not is_origin_subclass(type_, STR_OR_FUNCTIONCALL_TYPE) ] - tool_schemas = [FunctionToolSchema(schema) for schema in function_schemas] + tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas] str_in_output_types = is_any_origin_subclass(output_types, str) streamed_str_in_output_types = is_any_origin_subclass(output_types, StreamedStr) allow_string_output = str_in_output_types or streamed_str_in_output_types - response: Iterator[ChatCompletionChunk] = discard_none_arguments( - self._client.chat.completions.create - )( + _stream = self._client.beta.chat.completions.stream( model=self.model, messages=_add_missing_tool_calls_responses( [message_to_openai_message(m) for m in messages] @@ -628,7 +553,6 @@ def complete( max_tokens=self.max_tokens, seed=self.seed, stop=stop, - stream=True, stream_options=self._get_stream_options(), temperature=self.temperature, tools=[schema.to_dict() for schema in tool_schemas] or openai.NOT_GIVEN, @@ -638,45 +562,27 @@ def complete( parallel_tool_calls=self._get_parallel_tool_calls( tools_specified=bool(tool_schemas), output_types=output_types ), - ) - usage_ref, response = _create_usage_ref(response) - - first_chunk = next(response) - # Azure OpenAI sends a chunk with empty choices first - if len(first_chunk.choices) == 0: - first_chunk = next(response) - if ( - # Mistral tool call first chunk has content "" - not first_chunk.choices[0].delta.content - and not first_chunk.choices[0].delta.tool_calls - ): - first_chunk = next(response) - response = chain([first_chunk], response) - - if 
first_chunk.choices[0].delta.content: - streamed_str = StreamedStr( - chunk.choices[0].delta.content - for chunk in response - if chunk.choices and chunk.choices[0].delta.content is not None - ) + ).__enter__() # Get stream directly, without context manager + stream = OpenaiStream(_stream, function_schemas=function_schemas) + + # TODO: Function to validate LLM output against prompt-function return type + first_response_obj = next(stream) + if isinstance(first_response_obj, StreamedStr): str_content = validate_str_content( - streamed_str, + first_response_obj, allow_string_output=allow_string_output, streamed=streamed_str_in_output_types, ) - return AssistantMessage._with_usage(str_content, usage_ref) # type: ignore[return-value] + return AssistantMessage(str_content) # type: ignore[return-value] - if first_chunk.choices[0].delta.tool_calls: - tool_calls = _parse_streamed_tool_calls(response, tool_schemas) + if isinstance(first_response_obj, FunctionCall): if is_any_origin_subclass(output_types, ParallelFunctionCall): - content = ParallelFunctionCall(tool_calls) - return AssistantMessage._with_usage(content, usage_ref) # type: ignore[return-value] + content = ParallelFunctionCall(chain([first_response_obj], stream)) + return AssistantMessage(content) # type: ignore[return-value] # Take only the first tool_call, silently ignore extra chunks - content = next(tool_calls) - return AssistantMessage._with_usage(content, usage_ref) # type: ignore[return-value] + return AssistantMessage(first_response_obj) # type: ignore[return-value] - msg = f"Could not determine response type for first chunk: {first_chunk.model_dump_json()}" - raise ValueError(msg) + return AssistantMessage(first_response_obj) @overload async def acomplete( @@ -715,7 +621,7 @@ async def acomplete( for type_ in output_types if not is_origin_subclass(type_, STR_OR_FUNCTIONCALL_TYPE) ] - tool_schemas = [AsyncFunctionToolSchema(schema) for schema in function_schemas] + tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas] str_in_output_types = is_any_origin_subclass(output_types, str) async_streamed_str_in_output_types = is_any_origin_subclass( @@ -723,62 +629,52 @@ async def acomplete( ) allow_string_output = str_in_output_types or async_streamed_str_in_output_types - response: AsyncIterator[ChatCompletionChunk] = await discard_none_arguments( - self._async_client.chat.completions.create - )( - model=self.model, - messages=_add_missing_tool_calls_responses( - [message_to_openai_message(m) for m in messages] - ), - max_tokens=self.max_tokens, - seed=self.seed, - stop=stop, - stream=True, - stream_options=self._get_stream_options(), - temperature=self.temperature, - tools=[schema.to_dict() for schema in tool_schemas] or openai.NOT_GIVEN, - tool_choice=self._get_tool_choice( - tool_schemas=tool_schemas, allow_string_output=allow_string_output - ), - parallel_tool_calls=self._get_parallel_tool_calls( - tools_specified=bool(tool_schemas), output_types=output_types - ), - ) - usage_ref, response = _create_usage_ref_async(response) - - first_chunk = await anext(response) - # Azure OpenAI sends a chunk with empty choices first - if len(first_chunk.choices) == 0: - first_chunk = await anext(response) - if ( - # Mistral tool call first chunk has content "" - not first_chunk.choices[0].delta.content - and not first_chunk.choices[0].delta.tool_calls - ): - first_chunk = await anext(response) - response = achain(async_iter([first_chunk]), response) - - if first_chunk.choices[0].delta.content: - async_streamed_str = 
AsyncStreamedStr( - chunk.choices[0].delta.content - async for chunk in response - if chunk.choices and chunk.choices[0].delta.content is not None + response: Awaitable[AsyncIterator[ChatCompletionChunk]] = ( + self._async_client.chat.completions.create( + model=self.model, + messages=_add_missing_tool_calls_responses( + [message_to_openai_message(m) for m in messages] + ), + max_tokens=self.max_tokens, + seed=self.seed, + stop=stop, + stream=True, + stream_options=self._get_stream_options(), + temperature=self.temperature, + tools=[schema.to_dict() for schema in tool_schemas] or openai.NOT_GIVEN, + tool_choice=self._get_tool_choice( + tool_schemas=tool_schemas, allow_string_output=allow_string_output + ), + parallel_tool_calls=self._get_parallel_tool_calls( + tools_specified=bool(tool_schemas), output_types=output_types + ), ) + ) + _stream = await AsyncChatCompletionStreamManager( + response, + response_format=openai.NOT_GIVEN, + input_tools=[schema.to_dict() for schema in tool_schemas] + or openai.NOT_GIVEN, + ).__aenter__() # Get stream directly, without context manager + stream = OpenaiAsyncStream(_stream, function_schemas=function_schemas) + + # TODO: Function to validate LLM output against prompt-function return type + first_response_obj = await anext(stream) + if isinstance(first_response_obj, AsyncStreamedStr): str_content = await avalidate_str_content( - async_streamed_str, + first_response_obj, allow_string_output=allow_string_output, streamed=async_streamed_str_in_output_types, ) - return AssistantMessage._with_usage(str_content, usage_ref) # type: ignore[return-value] + return AssistantMessage(str_content) # type: ignore[return-value] - if first_chunk.choices[0].delta.tool_calls: - tool_calls = _aparse_streamed_tool_calls(response, tool_schemas) + if isinstance(first_response_obj, FunctionCall): if is_any_origin_subclass(output_types, AsyncParallelFunctionCall): - content = AsyncParallelFunctionCall(tool_calls) - return AssistantMessage._with_usage(content, usage_ref) # type: ignore[return-value] + content = AsyncParallelFunctionCall( + achain(async_iter([first_response_obj]), stream) + ) + return AssistantMessage(content) # type: ignore[return-value] # Take only the first tool_call, silently ignore extra chunks - content = await anext(tool_calls) - return AssistantMessage._with_usage(content, usage_ref) # type: ignore[return-value] + return AssistantMessage(first_response_obj) # type: ignore[return-value] - msg = f"Could not determine response type for first chunk: {first_chunk.model_dump_json()}" - raise ValueError(msg) + return AssistantMessage(first_response_obj) From a7ec635ff1af90c54bf50607dd1bda1f6539a63d Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Tue, 19 Nov 2024 21:38:14 -0800 Subject: [PATCH 03/40] Add TODO to test retry logic --- src/magentic/chat_model/openai_chat_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index 29ac8ebc..8de135b9 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -362,6 +362,7 @@ async def _group(stream: AsyncIterator) -> AsyncIterator: if function_schema.name == transition_item.name: break # TODO: Catch/raise unknown tool call error here + # TODO: Test that retry logic still works try: # TODO: Tidy catching of error here to DRY with async return await function_schema.aparse_args(_group(stream)) except ValidationError as e: 
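
The patches above replace manual chunk grouping with a transition-buffer loop: the event that begins the next logical group is stashed in a one-item list so the outer loop can decide whether it starts streamed text, a tool call, or usage info. As a rough standalone sketch of that control flow only — using made-up event classes and a plain generator, not the real openai SDK event types or the OpenaiStream class from these patches — the grouping step could be illustrated like this:

    # Illustrative only -- the event classes below are stand-ins, not openai SDK types.
    from collections.abc import Iterator
    from dataclasses import dataclass


    @dataclass
    class ContentDelta:
        delta: str


    @dataclass
    class ContentDone:
        pass


    @dataclass
    class ToolArgsDelta:
        name: str
        arguments_delta: str


    @dataclass
    class ToolArgsDone:
        pass


    Event = ContentDelta | ContentDone | ToolArgsDelta | ToolArgsDone


    def group_events(events: Iterator[Event]) -> Iterator[str | tuple[str, str]]:
        """Group a flat event stream into text / (tool name, arguments) items."""
        # The single-item `transition` buffer holds the event that starts the next group
        transition: list[Event] = [next(events)]

        def _content(first: ContentDelta) -> str:
            parts = [first.delta]
            for event in events:
                if isinstance(event, ContentDelta):
                    parts.append(event.delta)
                elif isinstance(event, ContentDone):
                    break
            return "".join(parts)

        def _tool_call(first: ToolArgsDelta) -> tuple[str, str]:
            parts = [first.arguments_delta]
            for event in events:
                if isinstance(event, ToolArgsDelta):
                    parts.append(event.arguments_delta)
                elif isinstance(event, ToolArgsDone):
                    break
            return first.name, "".join(parts)

        while transition:
            item = transition.pop()
            if isinstance(item, ContentDelta):
                yield _content(item)
            elif isinstance(item, ToolArgsDelta):
                yield _tool_call(item)
            # Pull the event that starts the next group, if any remain
            if (next_item := next(events, None)) is not None:
                transition.append(next_item)


    events: Iterator[Event] = iter(
        [
            ContentDelta("Hello "),
            ContentDelta("world"),
            ContentDone(),
            ToolArgsDelta("get_weather", '{"city": '),
            ToolArgsDelta("get_weather", '"Dublin"}'),
            ToolArgsDone(),
        ]
    )
    assert list(group_events(events)) == [
        "Hello world",
        ("get_weather", '{"city": "Dublin"}'),
    ]

The real implementation in these patches additionally resolves the matching function schema by tool name, wraps the grouped text in StreamedStr/AsyncStreamedStr, and records token usage from the final chunk, but the transition-buffer handoff sketched here is the core of how one flat event stream becomes a sequence of higher-level objects.
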
From 02a05ec2e963af5fd33c1ddb73881e282650bc77 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Tue, 19 Nov 2024 23:17:10 -0800 Subject: [PATCH 04/40] Add StreamParser to DRY openai streaming code --- src/magentic/chat_model/openai_chat_model.py | 241 +++++++++++-------- 1 file changed, 145 insertions(+), 96 deletions(-) diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index 8de135b9..7e83485f 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -1,4 +1,5 @@ import base64 +from abc import ABC, abstractmethod from collections.abc import ( AsyncIterator, Awaitable, @@ -10,7 +11,7 @@ from enum import Enum from functools import singledispatch from itertools import chain -from typing import Any, Generic, Literal, TypeVar, cast, overload +from typing import Any, Generic, Literal, TypeGuard, TypeVar, cast, overload import filetype import openai @@ -18,6 +19,12 @@ AsyncChatCompletionStream, AsyncChatCompletionStreamManager, ChatCompletionStream, + ChatCompletionStreamEvent, + ChunkEvent, + ContentDeltaEvent, + ContentDoneEvent, + FunctionToolCallArgumentsDeltaEvent, + FunctionToolCallArgumentsDoneEvent, ) from openai.types.chat import ( ChatCompletionChunk, @@ -26,11 +33,9 @@ ChatCompletionToolChoiceOptionParam, ChatCompletionToolParam, ) -from pydantic import ValidationError from magentic.chat_model.base import ( ChatModel, - ToolSchemaParseError, avalidate_str_content, validate_str_content, ) @@ -237,6 +242,97 @@ def to_dict(self) -> ChatCompletionToolParam: return {"type": "function", "function": self._function_schema.dict()} +ItemT = TypeVar("ItemT") +OutputT = TypeVar("OutputT") + + +class StreamParser(ABC, Generic[ItemT, OutputT]): + """Filters and transforms items from an iterator until the end condition is met.""" + + def is_member(self, item: ItemT) -> bool: + return True + + @abstractmethod + def is_end(self, item: ItemT) -> bool: ... + + @abstractmethod + def transform(self, item: ItemT) -> OutputT: ... 
+ + def iter( + self, iterator: Iterator[ItemT], transition: list[ItemT] + ) -> Iterator[OutputT]: + for item in iterator: + if self.is_member(item): + yield self.transform(item) + if self.is_end(item): + assert not transition # noqa: S101 + transition.append(item) + return + + async def aiter( + self, aiterator: AsyncIterator[ItemT], transition: list[ItemT] + ) -> AsyncIterator[OutputT]: + async for item in aiterator: + if self.is_member(item): + yield self.transform(item) + if self.is_end(item): + assert not transition # noqa: S101 + transition.append(item) + return + + +class OpenaiContentStreamParser(StreamParser[ChatCompletionStreamEvent, str]): + """Filters and transforms OpenAI content events from a stream.""" + + def is_member( + self, item: ChatCompletionStreamEvent + ) -> TypeGuard[ContentDeltaEvent]: + return item.type == "content.delta" + + def is_end(self, item: ChatCompletionStreamEvent) -> TypeGuard[ContentDoneEvent]: + return item.type == "content.done" + + def transform(self, item: ChatCompletionStreamEvent) -> str: + assert self.is_member(item) # noqa: S101 + return item.delta + + +class OpenaiToolStreamParser(StreamParser[ChatCompletionStreamEvent, str]): + """Filters and transforms OpenAI tool events from a stream.""" + + def is_member( + self, item: ChatCompletionStreamEvent + ) -> TypeGuard[FunctionToolCallArgumentsDeltaEvent]: + return item.type == "tool_calls.function.arguments.delta" + + def is_end( + self, item: ChatCompletionStreamEvent + ) -> TypeGuard[FunctionToolCallArgumentsDoneEvent]: + return item.type == "tool_calls.function.arguments.done" + + def transform(self, item: ChatCompletionStreamEvent) -> str: + assert self.is_member(item) # noqa: S101 + return item.arguments_delta + + +class OpenaiUsageStreamParser(StreamParser[ChatCompletionStreamEvent, Usage]): + """Filters and transforms OpenAI usage events from a stream.""" + + def is_member(self, item: ChatCompletionStreamEvent) -> TypeGuard[ChunkEvent]: + return item.type == "chunk" and bool(item.chunk.usage) + + def is_end(self, item: ChatCompletionStreamEvent) -> Literal[True]: + return True # Single event so immediately end + + def transform(self, item: ChatCompletionStreamEvent) -> Usage: + assert self.is_member(item) # noqa: S101 + assert item.chunk.usage # noqa: S101 + return Usage( + input_tokens=item.chunk.usage.prompt_tokens, + output_tokens=item.chunk.usage.completion_tokens, + ) + + class OpenaiStream: """Converts a stream of openai events into a stream of magentic objects.""" @@ -256,55 +352,33 @@ def __iter__(self) -> Iterator[StreamedStr | FunctionCall]: def __stream__(self) -> Iterator[StreamedStr | FunctionCall]: transition = [next(self._stream)] - - def _streamed_str(stream: Iterator) -> StreamedStr: - def _group(stream: Iterator) -> Iterator: - for event in stream: - if event.type == "content.delta": - yield event.delta - elif event.type == "content.done": - transition.append(event) - return - - return StreamedStr(_group(stream)) - - def _function_call(transition_item, stream: Iterator) -> FunctionCall: - def _group(stream: Iterator) -> Iterator: - for event in stream: - if event.type == "tool_calls.function.arguments.delta": - yield event.arguments_delta - elif event.type == "tool_calls.function.arguments.done": - transition.append(event) - return - - # TODO: Tidy matching function schema. 
Include Mistral fix - for function_schema in self._function_schemas: - if function_schema.name == transition_item.name: - break - # TODO: Catch/raise unknown tool call error here - try: # TODO: Tidy catching of error here to DRY with async - return function_schema.parse_args(_group(stream)) - except ValidationError as e: - raw_message = self._stream.current_completion_snapshot.choices[ - 0 - ].message.model_dump() - raise ToolSchemaParseError( - output_message=_RawMessage(raw_message), - tool_call_id=raw_message.content["tool_calls"][0]["id"], # type: ignore[index,unused-ignore] - validation_error=e, - ) from e + content_parser = OpenaiContentStreamParser() + tool_parser = OpenaiToolStreamParser() + usage_parser = OpenaiUsageStreamParser() while transition: transition_item = transition.pop() - if transition_item.type == "content.delta": - yield _streamed_str(self._stream) - elif transition_item.type == "tool_calls.function.arguments.delta": - yield _function_call(transition_item, self._stream) - elif transition_item.type == "chunk" and transition_item.chunk.usage: - self.usage = Usage( - input_tokens=transition_item.chunk.usage.prompt_tokens, - output_tokens=transition_item.chunk.usage.completion_tokens, + if content_parser.is_member(transition_item): + yield StreamedStr(content_parser.iter(self._stream, transition)) + elif tool_parser.is_member(transition_item): + # TODO: Tidy matching function schema. Include Mistral fix + # tool_parser.select_function_schema() ? + function_schema = next( + ( + function_schema + for function_schema in self._function_schemas + if function_schema.name == transition_item.name + ), + None, ) + # TODO: Catch/raise unknown tool call error here + assert function_schema is not None # noqa: S101 + # TODO: Catch/raise ToolSchemaParseError here for retry logic + yield function_schema.parse_args( + tool_parser.iter(self._stream, transition) + ) + elif usage_parser.is_member(transition_item): + self.usage = usage_parser.transform(transition_item) elif new_transition_item := next(self._stream, None): transition.append(new_transition_item) @@ -334,58 +408,33 @@ async def __aiter__(self) -> AsyncIterator[AsyncStreamedStr | FunctionCall]: async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | FunctionCall]: transition = [await anext(self._stream)] - - def _streamed_str(stream: AsyncIterator) -> AsyncStreamedStr: - async def _group(stream: AsyncIterator) -> AsyncIterator: - async for event in stream: - if event.type == "content.delta": - yield event.delta - elif event.type == "content.done": - transition.append(event) - return - - return AsyncStreamedStr(_group(stream)) - - async def _function_call( - transition_item, stream: AsyncIterator - ) -> FunctionCall: - async def _group(stream: AsyncIterator) -> AsyncIterator: - async for event in stream: - if event.type == "tool_calls.function.arguments.delta": - yield event.arguments_delta - elif event.type == "tool_calls.function.arguments.done": - transition.append(event) - return - - # TODO: Tidy matching function schema. 
Include Mistral fix - for function_schema in self._function_schemas: - if function_schema.name == transition_item.name: - break - # TODO: Catch/raise unknown tool call error here - # TODO: Test that retry logic still works - try: # TODO: Tidy catching of error here to DRY with async - return await function_schema.aparse_args(_group(stream)) - except ValidationError as e: - raw_message = self._stream.current_completion_snapshot.choices[ - 0 - ].message.model_dump() - raise ToolSchemaParseError( - output_message=_RawMessage(raw_message), - tool_call_id=raw_message.content["tool_calls"][0]["id"], # type: ignore[index,unused-ignore] - validation_error=e, - ) from e + content_parser = OpenaiContentStreamParser() + tool_parser = OpenaiToolStreamParser() + usage_parser = OpenaiUsageStreamParser() while transition: transition_item = transition.pop() - if transition_item.type == "content.delta": - yield _streamed_str(self._stream) - elif transition_item.type == "tool_calls.function.arguments.delta": - yield await _function_call(transition_item, self._stream) - elif transition_item.type == "chunk" and transition_item.chunk.usage: - self.usage = Usage( - input_tokens=transition_item.chunk.usage.prompt_tokens, - output_tokens=transition_item.chunk.usage.completion_tokens, + if content_parser.is_member(transition_item): + yield AsyncStreamedStr(content_parser.aiter(self._stream, transition)) + elif tool_parser.is_member(transition_item): + # TODO: Tidy matching function schema. Include Mistral fix + # tool_parser.select_function_schema() ? + function_schema = next( + ( + function_schema + for function_schema in self._function_schemas + if function_schema.name == transition_item.name + ), + None, + ) + # TODO: Catch/raise unknown tool call error here + assert function_schema is not None # noqa: S101 + # TODO: Catch/raise ToolSchemaParseError here for retry logic + yield await function_schema.aparse_args( + tool_parser.aiter(self._stream, transition) ) + elif usage_parser.is_member(transition_item): + self.usage = usage_parser.transform(transition_item) elif new_transition_item := await anext(self._stream, None): transition.append(new_transition_item) From a097c2a1b641dcf2aabff9e8bdbc532afde3cb4e Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Wed, 20 Nov 2024 21:53:18 -0800 Subject: [PATCH 05/40] Add _if_given --- src/magentic/chat_model/openai_chat_model.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index 7e83485f..ec3eaeb9 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -442,6 +442,10 @@ async def close(self): await self._stream.close() +def _if_given(value: T | None) -> T | openai.NotGiven: + return value if value is not None else openai.NOT_GIVEN + + STR_OR_FUNCTIONCALL_TYPE = ( str, StreamedStr, @@ -600,11 +604,11 @@ def complete( messages=_add_missing_tool_calls_responses( [message_to_openai_message(m) for m in messages] ), - max_tokens=self.max_tokens, - seed=self.seed, - stop=stop, + max_tokens=_if_given(self.max_tokens), + seed=_if_given(self.seed), + stop=_if_given(stop), stream_options=self._get_stream_options(), - temperature=self.temperature, + temperature=_if_given(self.temperature), tools=[schema.to_dict() for schema in tool_schemas] or openai.NOT_GIVEN, tool_choice=self._get_tool_choice( tool_schemas=tool_schemas, 
allow_string_output=allow_string_output @@ -685,12 +689,12 @@ async def acomplete( messages=_add_missing_tool_calls_responses( [message_to_openai_message(m) for m in messages] ), - max_tokens=self.max_tokens, - seed=self.seed, - stop=stop, + max_tokens=_if_given(self.max_tokens), + seed=_if_given(self.seed), + stop=_if_given(stop), stream=True, stream_options=self._get_stream_options(), - temperature=self.temperature, + temperature=_if_given(self.temperature), tools=[schema.to_dict() for schema in tool_schemas] or openai.NOT_GIVEN, tool_choice=self._get_tool_choice( tool_schemas=tool_schemas, allow_string_output=allow_string_output From f4bb4c1f2bbc55776122e2b65d7ed67cc26f016a Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Fri, 22 Nov 2024 22:29:33 -0800 Subject: [PATCH 06/40] Fix typing for openai Stream classes --- src/magentic/chat_model/openai_chat_model.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index ec3eaeb9..2b9f3342 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -333,24 +333,24 @@ def transform(self, item: ChatCompletionStreamEvent) -> Usage: ) -class OpenaiStream: +class OpenaiStream(Generic[T]): """Converts a stream of openai events into a stream of magentic objects.""" def __init__( - self, stream: ChatCompletionStream, function_schemas: list[FunctionSchema[Any]] + self, stream: ChatCompletionStream, function_schemas: list[FunctionSchema[T]] ): self._stream = stream self._function_schemas = function_schemas self._iterator = self.__stream__() self.usage: Usage | None = None - def __next__(self) -> StreamedStr | FunctionCall: + def __next__(self) -> StreamedStr | T: return self._iterator.__next__() - def __iter__(self) -> Iterator[StreamedStr | FunctionCall]: + def __iter__(self) -> Iterator[StreamedStr | T]: yield from self._iterator - def __stream__(self) -> Iterator[StreamedStr | FunctionCall]: + def __stream__(self) -> Iterator[StreamedStr | T]: transition = [next(self._stream)] content_parser = OpenaiContentStreamParser() tool_parser = OpenaiToolStreamParser() @@ -386,27 +386,27 @@ def close(self): self._stream.close() -class OpenaiAsyncStream: +class OpenaiAsyncStream(Generic[T]): """Converts an async stream of openai events into an async stream of magentic objects.""" def __init__( self, stream: AsyncChatCompletionStream, - function_schemas: list[AsyncFunctionSchema[Any]], + function_schemas: list[AsyncFunctionSchema[T]], ): self._stream = stream self._function_schemas = function_schemas self._aiterator = self.__stream__() self.usage: Usage | None = None - async def __anext__(self) -> AsyncStreamedStr | FunctionCall: + async def __anext__(self) -> AsyncStreamedStr | T: return await self._aiterator.__anext__() - async def __aiter__(self) -> AsyncIterator[AsyncStreamedStr | FunctionCall]: + async def __aiter__(self) -> AsyncIterator[AsyncStreamedStr | T]: async for item in self._aiterator: yield item - async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | FunctionCall]: + async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | T]: transition = [await anext(self._stream)] content_parser = OpenaiContentStreamParser() tool_parser = OpenaiToolStreamParser() From 8ccd016646cc7a565624b2b1b15331e943a30b8b Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Fri, 22 
Nov 2024 23:35:09 -0800 Subject: [PATCH 07/40] Add parse_stream and use in OpenaiChatModel --- src/magentic/chat_model/base.py | 58 +++++++++++++++++++- src/magentic/chat_model/openai_chat_model.py | 49 ++--------------- src/magentic/typing.py | 13 +++++ 3 files changed, 73 insertions(+), 47 deletions(-) diff --git a/src/magentic/chat_model/base.py b/src/magentic/chat_model/base.py index 8998e76c..94f955b3 100644 --- a/src/magentic/chat_model/base.py +++ b/src/magentic/chat_model/base.py @@ -2,12 +2,19 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Iterable from contextvars import ContextVar -from typing import Any, TypeVar, overload +from itertools import chain +from typing import Any, AsyncIterator, Iterator, TypeVar, cast, overload from pydantic import ValidationError from magentic.chat_model.message import AssistantMessage, Message -from magentic.streaming import AsyncStreamedStr, StreamedStr +from magentic.function_call import ( + AsyncParallelFunctionCall, + FunctionCall, + ParallelFunctionCall, +) +from magentic.streaming import AsyncStreamedStr, StreamedStr, achain, async_iter +from magentic.typing import is_instance_origin R = TypeVar("R") @@ -76,6 +83,53 @@ async def avalidate_str_content( return await async_streamed_str.to_string() +# TODO: Make this a stream class with a close method and context management +def parse_stream(stream: Iterator[Any], output_types: list[type[R]]) -> R: + """Parse and validate the LLM output stream against the allowed output types.""" + # TODO: option to error/warn/ignore extra objects + # TODO: warn for degenerate output types ? + obj = next(stream) + # TODO: Add type for mixed StreamedStr and FunctionCalls + if isinstance(obj, StreamedStr): + if StreamedStr in output_types: + return cast(R, obj) + if str in output_types: + return cast(R, str(obj)) + model_output = obj.truncate(100) + raise StringNotAllowedError(AssistantMessage(model_output)) + if isinstance(obj, FunctionCall): + if ParallelFunctionCall in output_types: + return cast(R, ParallelFunctionCall(chain([obj], stream))) + if FunctionCall in output_types: + # TODO: Check that FunctionCall type matches ? 
+ return cast(R, obj) + raise ValueError("FunctionCall not allowed") + if is_instance_origin(obj, tuple(output_types)): + return obj + raise ValueError(f"Unexpected output type: {type(obj)}") + + +async def aparse_stream(stream: AsyncIterator[Any], output_types: list[type[R]]) -> R: + """Async version of `parse_stream`.""" + obj = await anext(stream) + if isinstance(obj, AsyncStreamedStr): + if AsyncStreamedStr in output_types: + return cast(R, obj) + if str in output_types: + return cast(R, await obj.to_string()) + model_output = await obj.truncate(100) + raise StringNotAllowedError(AssistantMessage(model_output)) + if isinstance(obj, FunctionCall): + if AsyncParallelFunctionCall in output_types: + return cast(R, AsyncParallelFunctionCall(achain(async_iter([obj]), stream))) + if FunctionCall in output_types: + return cast(R, obj) + raise ValueError("FunctionCall not allowed") + if is_instance_origin(obj, tuple(output_types)): + return obj + raise ValueError(f"Unexpected output type: {type(obj)}") + + class ChatModel(ABC): """An LLM chat model.""" diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index 2b9f3342..5842ecd1 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -10,7 +10,6 @@ ) from enum import Enum from functools import singledispatch -from itertools import chain from typing import Any, Generic, Literal, TypeGuard, TypeVar, cast, overload import filetype @@ -36,8 +35,8 @@ from magentic.chat_model.base import ( ChatModel, - avalidate_str_content, - validate_str_content, + aparse_stream, + parse_stream, ) from magentic.chat_model.function_schema import ( AsyncFunctionSchema, @@ -65,8 +64,6 @@ from magentic.streaming import ( AsyncStreamedStr, StreamedStr, - achain, - async_iter, ) from magentic.typing import is_any_origin_subclass, is_origin_subclass from magentic.vision import UserImageMessage @@ -618,25 +615,7 @@ def complete( ), ).__enter__() # Get stream directly, without context manager stream = OpenaiStream(_stream, function_schemas=function_schemas) - - # TODO: Function to validate LLM output against prompt-function return type - first_response_obj = next(stream) - if isinstance(first_response_obj, StreamedStr): - str_content = validate_str_content( - first_response_obj, - allow_string_output=allow_string_output, - streamed=streamed_str_in_output_types, - ) - return AssistantMessage(str_content) # type: ignore[return-value] - - if isinstance(first_response_obj, FunctionCall): - if is_any_origin_subclass(output_types, ParallelFunctionCall): - content = ParallelFunctionCall(chain([first_response_obj], stream)) - return AssistantMessage(content) # type: ignore[return-value] - # Take only the first tool_call, silently ignore extra chunks - return AssistantMessage(first_response_obj) # type: ignore[return-value] - - return AssistantMessage(first_response_obj) + return AssistantMessage(parse_stream(stream, output_types)) # type: ignore @overload async def acomplete( @@ -711,24 +690,4 @@ async def acomplete( or openai.NOT_GIVEN, ).__aenter__() # Get stream directly, without context manager stream = OpenaiAsyncStream(_stream, function_schemas=function_schemas) - - # TODO: Function to validate LLM output against prompt-function return type - first_response_obj = await anext(stream) - if isinstance(first_response_obj, AsyncStreamedStr): - str_content = await avalidate_str_content( - first_response_obj, - allow_string_output=allow_string_output, - 
streamed=async_streamed_str_in_output_types, - ) - return AssistantMessage(str_content) # type: ignore[return-value] - - if isinstance(first_response_obj, FunctionCall): - if is_any_origin_subclass(output_types, AsyncParallelFunctionCall): - content = AsyncParallelFunctionCall( - achain(async_iter([first_response_obj]), stream) - ) - return AssistantMessage(content) # type: ignore[return-value] - # Take only the first tool_call, silently ignore extra chunks - return AssistantMessage(first_response_obj) # type: ignore[return-value] - - return AssistantMessage(first_response_obj) + return AssistantMessage(aparse_stream(stream, output_types)) # type: ignore diff --git a/src/magentic/typing.py b/src/magentic/typing.py index 21045e92..c14400fa 100644 --- a/src/magentic/typing.py +++ b/src/magentic/typing.py @@ -17,6 +17,7 @@ def is_union_type(type_: type) -> bool: return type_ is Union or type_ is types.UnionType +T = TypeVar("T") TypeT = TypeVar("TypeT", bound=type) @@ -39,6 +40,18 @@ def is_origin_subclass( return issubclass(get_origin(type_) or type_, cls_or_tuple) +def is_instance_origin( + obj: Any, cls_or_tuple: type[T] | tuple[type[T], ...] +) -> TypeGuard[T]: + cls_or_tuple_origin = ( + tuple(get_origin(cls) or cls for cls in cls_or_tuple) + if isinstance(cls_or_tuple, tuple) + else get_origin(cls_or_tuple) or cls_or_tuple + ) + return isinstance(obj, cls_or_tuple_origin) + + +# TODO: Remove once unused def is_any_origin_subclass( types: Iterable[type], cls_or_tuple: TypeT | tuple[TypeT, ...] ) -> bool: From eeb765e627e43ca3ab1d2e0c66f1dde1fccf9599 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Sat, 23 Nov 2024 00:14:07 -0800 Subject: [PATCH 08/40] tidy function schema matching --- src/magentic/chat_model/function_schema.py | 14 ++++++++ src/magentic/chat_model/openai_chat_model.py | 34 +++++++------------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/src/magentic/chat_model/function_schema.py b/src/magentic/chat_model/function_schema.py index df74efc3..40e27785 100644 --- a/src/magentic/chat_model/function_schema.py +++ b/src/magentic/chat_model/function_schema.py @@ -59,6 +59,20 @@ def dict(self) -> FunctionDefinition: return schema +BaseFunctionSchemaT = TypeVar("BaseFunctionSchemaT", bound=BaseFunctionSchema[Any]) + + +def select_function_schema( + function_schemas: Iterable[BaseFunctionSchemaT], name: str +) -> BaseFunctionSchemaT: + """Select the function schema with the given name.""" + for schema in function_schemas: + if schema.name == name: + return schema + # TODO: Catch/raise unknown tool call error here + raise ValueError(f"No function schema found for name {name}") + + class AsyncFunctionSchema(BaseFunctionSchema[T], Generic[T]): @abstractmethod async def aparse_args(self, chunks: AsyncIterable[str]) -> T: diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index 5842ecd1..d31fb03c 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -45,6 +45,7 @@ FunctionSchema, async_function_schema_for_type, function_schema_for_type, + select_function_schema, ) from magentic.chat_model.message import ( AssistantMessage, @@ -311,6 +312,10 @@ def transform(self, item: ChatCompletionStreamEvent) -> str: assert self.is_member(item) # noqa: S101 return item.arguments_delta + def get_tool_name(self, item: ChatCompletionStreamEvent) -> str: + assert self.is_member(item) + return item.name + class 
OpenaiUsageStreamParser(StreamParser[ChatCompletionStreamEvent, Usage]): """Filters and transforms OpenAI usage events from a stream.""" @@ -358,18 +363,10 @@ def __stream__(self) -> Iterator[StreamedStr | T]: if content_parser.is_member(transition_item): yield StreamedStr(content_parser.iter(self._stream, transition)) elif tool_parser.is_member(transition_item): - # TODO: Tidy matching function schema. Include Mistral fix - # tool_parser.select_function_schema() ? - function_schema = next( - ( - function_schema - for function_schema in self._function_schemas - if function_schema.name == transition_item.name - ), - None, + tool_name = tool_parser.get_tool_name(transition_item) + function_schema = select_function_schema( + self._function_schemas, tool_name ) - # TODO: Catch/raise unknown tool call error here - assert function_schema is not None # noqa: S101 # TODO: Catch/raise ToolSchemaParseError here for retry logic yield function_schema.parse_args( tool_parser.iter(self._stream, transition) @@ -414,19 +411,10 @@ async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | T]: if content_parser.is_member(transition_item): yield AsyncStreamedStr(content_parser.aiter(self._stream, transition)) elif tool_parser.is_member(transition_item): - # TODO: Tidy matching function schema. Include Mistral fix - # tool_parser.select_function_schema() ? - function_schema = next( - ( - function_schema - for function_schema in self._function_schemas - if function_schema.name == transition_item.name - ), - None, + tool_name = tool_parser.get_tool_name(transition_item) + function_schema = select_function_schema( + self._function_schemas, tool_name ) - # TODO: Catch/raise unknown tool call error here - assert function_schema is not None # noqa: S101 - # TODO: Catch/raise ToolSchemaParseError here for retry logic yield await function_schema.aparse_args( tool_parser.aiter(self._stream, transition) ) From 360211eda624ea2794085db458b214103750a66c Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Sat, 23 Nov 2024 01:18:46 -0800 Subject: [PATCH 09/40] Copy Stream classes into stream.py --- src/magentic/chat_model/stream.py | 161 ++++++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) create mode 100644 src/magentic/chat_model/stream.py diff --git a/src/magentic/chat_model/stream.py b/src/magentic/chat_model/stream.py new file mode 100644 index 00000000..e56f7415 --- /dev/null +++ b/src/magentic/chat_model/stream.py @@ -0,0 +1,161 @@ +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator, Iterator +from itertools import chain +from typing import TYPE_CHECKING, Generic, TypeVar + +from magentic.chat_model.function_schema import FunctionSchema, select_function_schema +from magentic.streaming import AsyncStreamedStr, StreamedStr, achain, async_iter + +if TYPE_CHECKING: + from magentic.chat_model.message import Usage + + +T = TypeVar("T") +ItemT = TypeVar("ItemT") +OutputT = TypeVar("OutputT") + + +class StreamParser(ABC, Generic[ItemT, OutputT]): + """Filters and transforms items from an iterator until the end condition is met.""" + + def is_member(self, item: ItemT) -> bool: + return True + + @abstractmethod + def is_end(self, item: ItemT) -> bool: ... + + @abstractmethod + def transform(self, item: ItemT) -> OutputT: ... 
+ + def iter( + self, iterator: Iterator[ItemT], transition: list[ItemT] + ) -> Iterator[OutputT]: + for item in iterator: + if self.is_member(item): + yield self.transform(item) + if self.is_end(item): + assert not transition # noqa: S101 + transition.append(item) + return + + async def aiter( + self, aiterator: AsyncIterator[ItemT], transition: list[ItemT] + ) -> AsyncIterator[OutputT]: + async for item in aiterator: + if self.is_member(item): + yield self.transform(item) + if self.is_end(item): + assert not transition # noqa: S101 + transition.append(item) + return + + +class OutputStream(Generic[T]): + """Converts streamed LLM output into a stream of magentic objects.""" + + def __init__( + self, + stream: Iterator, # TODO: Fix typing + function_schemas: list[FunctionSchema[T]], + content_parser: StreamParser, + tool_parser: StreamParser, + usage_parser: StreamParser, + ): + self._stream = stream + self._function_schemas = function_schemas + self._iterator = self.__stream__() + + self._content_parser = content_parser + self._tool_parser = tool_parser + self._usage_parser = usage_parser + + self.usage: Usage | None = None + + def __next__(self) -> StreamedStr | T: + return self._iterator.__next__() + + def __iter__(self) -> Iterator[StreamedStr | T]: + yield from self._iterator + + def __stream__(self) -> Iterator[StreamedStr | T]: + transition = [next(self._stream)] + while transition: + transition_item = transition.pop() + stream_with_transition = chain([transition_item], self._stream) + if self._content_parser.is_member(transition_item): + yield StreamedStr( + self._content_parser.iter(stream_with_transition, transition) + ) + elif self._tool_parser.is_member(transition_item): + # TODO: Add new base class for tool parser + tool_name = self._tool_parser.get_tool_name(transition_item) + function_schema = select_function_schema( + self._function_schemas, tool_name + ) + # TODO: Catch/raise ToolSchemaParseError here for retry logic + yield function_schema.parse_args( + self._tool_parser.iter(stream_with_transition, transition) + ) + elif self._usage_parser.is_member(transition_item): + self.usage = self._usage_parser.transform(transition_item) + elif new_transition_item := next(self._stream, None): + transition.append(new_transition_item) + + def close(self): + self._stream.close() + + +class AsyncOutputStream(Generic[T]): + """Async version of `OutputStream`.""" + + def __init__( + self, + stream: AsyncIterator, # TODO: Fix typing + function_schemas: list[FunctionSchema[T]], + content_parser: StreamParser, + tool_parser: StreamParser, + usage_parser: StreamParser, + ): + self._stream = stream + self._function_schemas = function_schemas + self._iterator = self.__stream__() + + self._content_parser = content_parser + self._tool_parser = tool_parser + self._usage_parser = usage_parser + + self.usage: Usage | None = None + + async def __anext__(self) -> AsyncStreamedStr | T: + return await self._iterator.__anext__() + + async def __aiter__(self) -> AsyncIterator[AsyncStreamedStr | T]: + async for item in self._iterator: + yield item + + async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | T]: + transition = [await anext(self._stream)] + while transition: + transition_item = transition.pop() + stream_with_transition = achain(async_iter([transition_item]), self._stream) + if self._content_parser.is_member(transition_item): + yield AsyncStreamedStr( + self._content_parser.aiter(stream_with_transition, transition) + ) + elif self._tool_parser.is_member(transition_item): + # TODO: Add 
new base class for tool parser + tool_name = self._tool_parser.get_tool_name(transition_item) + function_schema = select_function_schema( + self._function_schemas, tool_name + ) + # TODO: Catch/raise ToolSchemaParseError here for retry logic + yield await function_schema.aparse_args( + self._tool_parser.aiter(stream_with_transition, transition) + ) + elif self._usage_parser.is_member(transition_item): + self.usage = self._usage_parser.transform(transition_item) + elif new_transition_item := await anext(self._stream, None): + transition.append(new_transition_item) + + async def close(self): + await self._stream.close() From 6fa7cdd720cd056f25361ecf427c5cae41f77568 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Sat, 23 Nov 2024 01:20:20 -0800 Subject: [PATCH 10/40] Switch LitellmChatModel to use stream parsers --- src/magentic/chat_model/litellm_chat_model.py | 191 ++++++++---------- 1 file changed, 82 insertions(+), 109 deletions(-) diff --git a/src/magentic/chat_model/litellm_chat_model.py b/src/magentic/chat_model/litellm_chat_model.py index 827c28e3..ad80e3af 100644 --- a/src/magentic/chat_model/litellm_chat_model.py +++ b/src/magentic/chat_model/litellm_chat_model.py @@ -1,14 +1,12 @@ from collections.abc import Callable, Iterable, Sequence -from functools import wraps -from itertools import chain -from typing import Any, ParamSpec, TypeVar, cast, overload +from typing import Any, Literal, TypeVar, cast, overload -from openai.types.chat import ChatCompletionToolChoiceOptionParam +from litellm.litellm_core_utils.streaming_handler import StreamingChoices from magentic.chat_model.base import ( ChatModel, - avalidate_str_content, - validate_str_content, + aparse_stream, + parse_stream, ) from magentic.chat_model.function_schema import ( FunctionCallFunctionSchema, @@ -18,25 +16,17 @@ from magentic.chat_model.message import ( AssistantMessage, Message, + Usage, ) from magentic.chat_model.openai_chat_model import ( STR_OR_FUNCTIONCALL_TYPE, - AsyncFunctionToolSchema, BaseFunctionToolSchema, - FunctionToolSchema, - _aparse_streamed_tool_calls, - _parse_streamed_tool_calls, message_to_openai_message, ) -from magentic.function_call import ( - AsyncParallelFunctionCall, - ParallelFunctionCall, -) +from magentic.chat_model.stream import AsyncOutputStream, OutputStream, StreamParser from magentic.streaming import ( AsyncStreamedStr, StreamedStr, - achain, - async_iter, ) from magentic.typing import is_any_origin_subclass, is_origin_subclass @@ -48,21 +38,63 @@ raise ImportError(msg) from error -P = ParamSpec("P") -R = TypeVar("R") +class LitellmContentStreamParser(StreamParser[ModelResponse, str]): + """Filters and transforms LiteLLM content chunks from a stream.""" + + def is_member(self, item: ModelResponse) -> bool: + assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 + return bool(item.choices[0].delta.content) + + def is_end(self, item: ModelResponse) -> bool: + assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 + return bool(item.choices[0].delta.content is None) + + def transform(self, item: ModelResponse) -> str: + assert self.is_member(item) # noqa: S101 + assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 + return item.choices[0].delta.content or "" + + +class LitellmToolStreamParser(StreamParser[ModelResponse, str]): + """Filters and transforms LiteLLM tool chunks from a stream.""" + + def is_member(self, item: ModelResponse) -> bool: + assert isinstance(item.choices[0], StreamingChoices) # 
noqa: S101 + return bool(item.choices[0].delta.tool_calls is not None) + + def is_end(self, item: ModelResponse) -> bool: + assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 + return item.choices[0].delta.tool_calls is None + + def transform(self, item: ModelResponse) -> str: + assert self.is_member(item) # noqa: S101 + assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 + assert item.choices[0].delta.tool_calls is not None # noqa: S101 + return item.choices[0].delta.tool_calls[0].function.arguments + def get_tool_name(self, item: ModelResponse) -> str: + assert self.is_member(item) # noqa: S101 + assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 + assert item.choices[0].delta.tool_calls is not None # noqa: S101 + assert item.choices[0].delta.tool_calls[0].function.name # noqa: S101 + return item.choices[0].delta.tool_calls[0].function.name -def discard_none_arguments(func: Callable[P, R]) -> Callable[P, R]: - """Decorator to discard function arguments with value `None`""" - @wraps(func) - def wrapped(*args: P.args, **kwargs: P.kwargs) -> R: - non_none_kwargs = { - key: value for key, value in kwargs.items() if value is not None - } - return func(*args, **non_none_kwargs) # type: ignore[arg-type] +# TODO: Implement LitellmToolStreamParser +class LitellmUsageStreamParser(StreamParser[ModelResponse, Usage]): + """Filters and transforms LiteLLM tool chunks from a stream.""" - return wrapped + def is_member(self, item: ModelResponse) -> bool: + return False + + def is_end(self, item: ModelResponse) -> bool: + return True + + def transform(self, item: ModelResponse) -> Usage: + return Usage(input_tokens=0, output_tokens=0) + + +R = TypeVar("R") class LitellmChatModel(ChatModel): @@ -114,12 +146,12 @@ def _get_tool_choice( *, tool_schemas: Sequence[BaseFunctionToolSchema[Any]], allow_string_output: bool, - ) -> ChatCompletionToolChoiceOptionParam | None: + ) -> dict | Literal["none", "auto", "required"] | None: """Create the tool choice argument.""" if allow_string_output: return None if len(tool_schemas) == 1: - return tool_schemas[0].as_tool_choice() + return tool_schemas[0].as_tool_choice() # type: ignore[return-value] return "required" @overload @@ -159,13 +191,13 @@ def complete( for type_ in output_types if not is_origin_subclass(type_, STR_OR_FUNCTIONCALL_TYPE) ] - tool_schemas = [FunctionToolSchema(schema) for schema in function_schemas] + tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas] str_in_output_types = is_any_origin_subclass(output_types, str) streamed_str_in_output_types = is_any_origin_subclass(output_types, StreamedStr) allow_string_output = str_in_output_types or streamed_str_in_output_types - response = discard_none_arguments(litellm.completion)( + response = litellm.completion( model=self.model, messages=[message_to_openai_message(m) for m in messages], api_base=self.api_base, @@ -181,44 +213,14 @@ def complete( ), ) assert not isinstance(response, ModelResponse) # noqa: S101 - - first_chunk = next(response) - # Azure OpenAI sends a chunk with empty choices first - if len(first_chunk.choices) == 0: - first_chunk = next(response) - if ( - first_chunk.choices[0].delta.content is None - and first_chunk.choices[0].delta.tool_calls is None - ): - first_chunk = next(response) - response = chain([first_chunk], response) - - # Check tool calls before content because both might be present - if first_chunk.choices[0].delta.tool_calls is not None: - tool_calls = _parse_streamed_tool_calls(response, 
tool_schemas) - if is_any_origin_subclass(output_types, ParallelFunctionCall): - content = ParallelFunctionCall(tool_calls) - return AssistantMessage(content) # type: ignore[return-value] - # Take only the first tool_call, silently ignore extra chunks - # TODO: Create generator here that raises error or warns if multiple tool_calls - content = next(tool_calls) - return AssistantMessage(content) # type: ignore[return-value] - - if first_chunk.choices[0].delta.content is not None: - streamed_str = StreamedStr( - chunk.choices[0].delta.get("content", None) - for chunk in response - if chunk.choices[0].delta.get("content", None) is not None - ) - str_content = validate_str_content( - streamed_str, - allow_string_output=allow_string_output, - streamed=streamed_str_in_output_types, - ) - return AssistantMessage(str_content) # type: ignore[return-value] - - msg = f"Could not determine response type for first chunk: {first_chunk.model_dump_json()}" - raise ValueError(msg) + stream = OutputStream( + stream=response, + function_schemas=function_schemas, + content_parser=LitellmContentStreamParser(), + tool_parser=LitellmToolStreamParser(), + usage_parser=LitellmUsageStreamParser(), + ) + return AssistantMessage(parse_stream(stream, output_types)) # type: ignore[return-value] @overload async def acomplete( @@ -257,7 +259,7 @@ async def acomplete( for type_ in output_types if not is_origin_subclass(type_, STR_OR_FUNCTIONCALL_TYPE) ] - tool_schemas = [AsyncFunctionToolSchema(schema) for schema in function_schemas] + tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas] str_in_output_types = is_any_origin_subclass(output_types, str) async_streamed_str_in_output_types = is_any_origin_subclass( @@ -265,7 +267,7 @@ async def acomplete( ) allow_string_output = str_in_output_types or async_streamed_str_in_output_types - response = await discard_none_arguments(litellm.acompletion)( + response = await litellm.acompletion( model=self.model, messages=[message_to_openai_message(m) for m in messages], api_base=self.api_base, @@ -278,43 +280,14 @@ async def acomplete( tools=[schema.to_dict() for schema in tool_schemas] or None, tool_choice=self._get_tool_choice( tool_schemas=tool_schemas, allow_string_output=allow_string_output - ), + ), # type: ignore[arg-type] ) assert not isinstance(response, ModelResponse) # noqa: S101 - - first_chunk = await anext(response) - # Azure OpenAI sends a chunk with empty choices first - if len(first_chunk.choices) == 0: - first_chunk = await anext(response) - if ( - first_chunk.choices[0].delta.content is None - and first_chunk.choices[0].delta.tool_calls is None - ): - first_chunk = await anext(response) - response = achain(async_iter([first_chunk]), response) - - # Check tool calls before content because both might be present - if first_chunk.choices[0].delta.tool_calls is not None: - tool_calls = _aparse_streamed_tool_calls(response, tool_schemas) - if is_any_origin_subclass(output_types, AsyncParallelFunctionCall): - content = AsyncParallelFunctionCall(tool_calls) - return AssistantMessage(content) # type: ignore[return-value] - # Take only the first tool_call, silently ignore extra chunks - content = await anext(tool_calls) - return AssistantMessage(content) # type: ignore[return-value] - - if first_chunk.choices[0].delta.content is not None: - async_streamed_str = AsyncStreamedStr( - chunk.choices[0].delta.get("content", None) - async for chunk in response - if chunk.choices[0].delta.get("content", None) is not None - ) - str_content = await 
avalidate_str_content( - async_streamed_str, - allow_string_output=allow_string_output, - streamed=async_streamed_str_in_output_types, - ) - return AssistantMessage(str_content) # type: ignore[return-value] - - msg = f"Could not determine response type for first chunk: {first_chunk.model_dump_json()}" - raise ValueError(msg) + stream = AsyncOutputStream( + stream=response, + function_schemas=function_schemas, + content_parser=LitellmContentStreamParser(), + tool_parser=LitellmToolStreamParser(), + usage_parser=LitellmUsageStreamParser(), + ) + return AssistantMessage(aparse_stream(stream, output_types)) # type: ignore[return-value] From f75853696859e2ff02818d5c228e90f2df48ce59 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Sat, 23 Nov 2024 01:29:41 -0800 Subject: [PATCH 11/40] Switch OpenaiChatModel to use shared streaming classes --- src/magentic/chat_model/openai_chat_model.py | 160 +++---------------- 1 file changed, 18 insertions(+), 142 deletions(-) diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index d31fb03c..c62e2c17 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -1,11 +1,9 @@ import base64 -from abc import ABC, abstractmethod from collections.abc import ( AsyncIterator, Awaitable, Callable, Iterable, - Iterator, Sequence, ) from enum import Enum @@ -15,9 +13,7 @@ import filetype import openai from openai.lib.streaming.chat import ( - AsyncChatCompletionStream, AsyncChatCompletionStreamManager, - ChatCompletionStream, ChatCompletionStreamEvent, ChunkEvent, ContentDeltaEvent, @@ -39,13 +35,11 @@ parse_stream, ) from magentic.chat_model.function_schema import ( - AsyncFunctionSchema, BaseFunctionSchema, FunctionCallFunctionSchema, FunctionSchema, async_function_schema_for_type, function_schema_for_type, - select_function_schema, ) from magentic.chat_model.message import ( AssistantMessage, @@ -56,6 +50,7 @@ UserMessage, _RawMessage, ) +from magentic.chat_model.stream import AsyncOutputStream, OutputStream, StreamParser from magentic.function_call import ( AsyncParallelFunctionCall, FunctionCall, @@ -240,45 +235,6 @@ def to_dict(self) -> ChatCompletionToolParam: return {"type": "function", "function": self._function_schema.dict()} -ItemT = TypeVar("ItemT") -OutputT = TypeVar("OutputT") - - -class StreamParser(ABC, Generic[ItemT, OutputT]): - """Filters and transforms items from an iterator until the end condition is met.""" - - def is_member(self, item: ItemT) -> bool: - return True - - @abstractmethod - def is_end(self, item: ItemT) -> bool: ... - - @abstractmethod - def transform(self, item: ItemT) -> OutputT: ... 
- - def iter( - self, iterator: Iterator[ItemT], transition: list[ItemT] - ) -> Iterator[OutputT]: - for item in iterator: - if self.is_member(item): - yield self.transform(item) - if self.is_end(item): - assert not transition # noqa: S101 - transition.append(item) - return - - async def aiter( - self, aiterator: AsyncIterator[ItemT], transition: list[ItemT] - ) -> AsyncIterator[OutputT]: - async for item in aiterator: - if self.is_member(item): - yield self.transform(item) - if self.is_end(item): - assert not transition # noqa: S101 - transition.append(item) - return - - class OpenaiContentStreamParser(StreamParser[ChatCompletionStreamEvent, str]): """Filters and transforms OpenAI content events from a stream.""" @@ -313,7 +269,7 @@ def transform(self, item: ChatCompletionStreamEvent) -> str: return item.arguments_delta def get_tool_name(self, item: ChatCompletionStreamEvent) -> str: - assert self.is_member(item) + assert self.is_member(item) # noqa: S101 return item.name @@ -335,98 +291,6 @@ def transform(self, item: ChatCompletionStreamEvent) -> Usage: ) -class OpenaiStream(Generic[T]): - """Converts a stream of openai events into a stream of magentic objects.""" - - def __init__( - self, stream: ChatCompletionStream, function_schemas: list[FunctionSchema[T]] - ): - self._stream = stream - self._function_schemas = function_schemas - self._iterator = self.__stream__() - self.usage: Usage | None = None - - def __next__(self) -> StreamedStr | T: - return self._iterator.__next__() - - def __iter__(self) -> Iterator[StreamedStr | T]: - yield from self._iterator - - def __stream__(self) -> Iterator[StreamedStr | T]: - transition = [next(self._stream)] - content_parser = OpenaiContentStreamParser() - tool_parser = OpenaiToolStreamParser() - usage_parser = OpenaiUsageStreamParser() - - while transition: - transition_item = transition.pop() - if content_parser.is_member(transition_item): - yield StreamedStr(content_parser.iter(self._stream, transition)) - elif tool_parser.is_member(transition_item): - tool_name = tool_parser.get_tool_name(transition_item) - function_schema = select_function_schema( - self._function_schemas, tool_name - ) - # TODO: Catch/raise ToolSchemaParseError here for retry logic - yield function_schema.parse_args( - tool_parser.iter(self._stream, transition) - ) - elif usage_parser.is_member(transition_item): - self.usage = usage_parser.transform(transition_item) - elif new_transition_item := next(self._stream, None): - transition.append(new_transition_item) - - def close(self): - self._stream.close() - - -class OpenaiAsyncStream(Generic[T]): - """Converts an async stream of openai events into an async stream of magentic objects.""" - - def __init__( - self, - stream: AsyncChatCompletionStream, - function_schemas: list[AsyncFunctionSchema[T]], - ): - self._stream = stream - self._function_schemas = function_schemas - self._aiterator = self.__stream__() - self.usage: Usage | None = None - - async def __anext__(self) -> AsyncStreamedStr | T: - return await self._aiterator.__anext__() - - async def __aiter__(self) -> AsyncIterator[AsyncStreamedStr | T]: - async for item in self._aiterator: - yield item - - async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | T]: - transition = [await anext(self._stream)] - content_parser = OpenaiContentStreamParser() - tool_parser = OpenaiToolStreamParser() - usage_parser = OpenaiUsageStreamParser() - - while transition: - transition_item = transition.pop() - if content_parser.is_member(transition_item): - yield 
AsyncStreamedStr(content_parser.aiter(self._stream, transition)) - elif tool_parser.is_member(transition_item): - tool_name = tool_parser.get_tool_name(transition_item) - function_schema = select_function_schema( - self._function_schemas, tool_name - ) - yield await function_schema.aparse_args( - tool_parser.aiter(self._stream, transition) - ) - elif usage_parser.is_member(transition_item): - self.usage = usage_parser.transform(transition_item) - elif new_transition_item := await anext(self._stream, None): - transition.append(new_transition_item) - - async def close(self): - await self._stream.close() - - def _if_given(value: T | None) -> T | openai.NotGiven: return value if value is not None else openai.NOT_GIVEN @@ -602,8 +466,14 @@ def complete( tools_specified=bool(tool_schemas), output_types=output_types ), ).__enter__() # Get stream directly, without context manager - stream = OpenaiStream(_stream, function_schemas=function_schemas) - return AssistantMessage(parse_stream(stream, output_types)) # type: ignore + stream = OutputStream( + _stream, + function_schemas=function_schemas, + content_parser=OpenaiContentStreamParser(), + tool_parser=OpenaiToolStreamParser(), + usage_parser=OpenaiUsageStreamParser(), + ) + return AssistantMessage(parse_stream(stream, output_types)) # type: ignore[return-type] @overload async def acomplete( @@ -677,5 +547,11 @@ async def acomplete( input_tools=[schema.to_dict() for schema in tool_schemas] or openai.NOT_GIVEN, ).__aenter__() # Get stream directly, without context manager - stream = OpenaiAsyncStream(_stream, function_schemas=function_schemas) - return AssistantMessage(aparse_stream(stream, output_types)) # type: ignore + stream = AsyncOutputStream( + _stream, + function_schemas=function_schemas, + content_parser=OpenaiContentStreamParser(), + tool_parser=OpenaiToolStreamParser(), + usage_parser=OpenaiUsageStreamParser(), + ) + return AssistantMessage(await aparse_stream(stream, output_types)) # type: ignore[return-type] From 45074a59f90ca5ac5eee06470b66a6cd737e67eb Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Sat, 23 Nov 2024 01:37:12 -0800 Subject: [PATCH 12/40] Add docstring for is_instance_origin --- src/magentic/typing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/magentic/typing.py b/src/magentic/typing.py index c14400fa..efec30c7 100644 --- a/src/magentic/typing.py +++ b/src/magentic/typing.py @@ -43,6 +43,7 @@ def is_origin_subclass( def is_instance_origin( obj: Any, cls_or_tuple: type[T] | tuple[type[T], ...] 
) -> TypeGuard[T]: + """Check if the object is an instance of the origin(s) of the given type(s).""" cls_or_tuple_origin = ( tuple(get_origin(cls) or cls for cls in cls_or_tuple) if isinstance(cls_or_tuple, tuple) From 4f37190159ccddcda6e84eec66c105a3558a70bc Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Sat, 23 Nov 2024 20:37:27 -0800 Subject: [PATCH 13/40] Use type origins in parse_stream --- src/magentic/chat_model/base.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/magentic/chat_model/base.py b/src/magentic/chat_model/base.py index 94f955b3..48f795c8 100644 --- a/src/magentic/chat_model/base.py +++ b/src/magentic/chat_model/base.py @@ -3,7 +3,7 @@ from collections.abc import Callable, Iterable from contextvars import ContextVar from itertools import chain -from typing import Any, AsyncIterator, Iterator, TypeVar, cast, overload +from typing import Any, AsyncIterator, Iterator, TypeVar, cast, get_origin, overload from pydantic import ValidationError @@ -14,7 +14,6 @@ ParallelFunctionCall, ) from magentic.streaming import AsyncStreamedStr, StreamedStr, achain, async_iter -from magentic.typing import is_instance_origin R = TypeVar("R") @@ -86,46 +85,48 @@ async def avalidate_str_content( # TODO: Make this a stream class with a close method and context management def parse_stream(stream: Iterator[Any], output_types: list[type[R]]) -> R: """Parse and validate the LLM output stream against the allowed output types.""" + output_type_origins = [get_origin(type_) or type_ for type_ in output_types] # TODO: option to error/warn/ignore extra objects # TODO: warn for degenerate output types ? obj = next(stream) # TODO: Add type for mixed StreamedStr and FunctionCalls if isinstance(obj, StreamedStr): - if StreamedStr in output_types: + if StreamedStr in output_type_origins: return cast(R, obj) - if str in output_types: + if str in output_type_origins: return cast(R, str(obj)) model_output = obj.truncate(100) raise StringNotAllowedError(AssistantMessage(model_output)) if isinstance(obj, FunctionCall): - if ParallelFunctionCall in output_types: + if ParallelFunctionCall in output_type_origins: return cast(R, ParallelFunctionCall(chain([obj], stream))) - if FunctionCall in output_types: + if FunctionCall in output_type_origins: # TODO: Check that FunctionCall type matches ? 
return cast(R, obj) raise ValueError("FunctionCall not allowed") - if is_instance_origin(obj, tuple(output_types)): + if isinstance(obj, tuple(output_type_origins)): return obj raise ValueError(f"Unexpected output type: {type(obj)}") async def aparse_stream(stream: AsyncIterator[Any], output_types: list[type[R]]) -> R: """Async version of `parse_stream`.""" + output_type_origins = [get_origin(type_) or type_ for type_ in output_types] obj = await anext(stream) if isinstance(obj, AsyncStreamedStr): - if AsyncStreamedStr in output_types: + if AsyncStreamedStr in output_type_origins: return cast(R, obj) - if str in output_types: + if str in output_type_origins: return cast(R, await obj.to_string()) model_output = await obj.truncate(100) raise StringNotAllowedError(AssistantMessage(model_output)) if isinstance(obj, FunctionCall): - if AsyncParallelFunctionCall in output_types: + if AsyncParallelFunctionCall in output_type_origins: return cast(R, AsyncParallelFunctionCall(achain(async_iter([obj]), stream))) - if FunctionCall in output_types: + if FunctionCall in output_type_origins: return cast(R, obj) raise ValueError("FunctionCall not allowed") - if is_instance_origin(obj, tuple(output_types)): + if isinstance(obj, tuple(output_type_origins)): return obj raise ValueError(f"Unexpected output type: {type(obj)}") From 45b879919a5df83afd1a0ff37497b2366a401b9d Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Sat, 23 Nov 2024 21:04:23 -0800 Subject: [PATCH 14/40] Fix complete type hints for LitellmChatModel --- src/magentic/chat_model/base.py | 7 +++++-- src/magentic/chat_model/litellm_chat_model.py | 8 ++++---- src/magentic/chat_model/openai_chat_model.py | 2 +- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/magentic/chat_model/base.py b/src/magentic/chat_model/base.py index 48f795c8..5dfa5d8a 100644 --- a/src/magentic/chat_model/base.py +++ b/src/magentic/chat_model/base.py @@ -83,7 +83,7 @@ async def avalidate_str_content( # TODO: Make this a stream class with a close method and context management -def parse_stream(stream: Iterator[Any], output_types: list[type[R]]) -> R: +def parse_stream(stream: Iterator[Any], output_types: Iterable[type[R]]) -> R: """Parse and validate the LLM output stream against the allowed output types.""" output_type_origins = [get_origin(type_) or type_ for type_ in output_types] # TODO: option to error/warn/ignore extra objects @@ -109,7 +109,9 @@ def parse_stream(stream: Iterator[Any], output_types: list[type[R]]) -> R: raise ValueError(f"Unexpected output type: {type(obj)}") -async def aparse_stream(stream: AsyncIterator[Any], output_types: list[type[R]]) -> R: +async def aparse_stream( + stream: AsyncIterator[Any], output_types: Iterable[type[R]] +) -> R: """Async version of `parse_stream`.""" output_type_origins = [get_origin(type_) or type_ for type_ in output_types] obj = await anext(stream) @@ -161,6 +163,7 @@ def complete( self, messages: Iterable[Message[Any]], functions: Iterable[Callable[..., Any]] | None = None, + # TODO: Set default of R to str in Python 3.13 output_types: Iterable[type[R | str]] | None = None, *, stop: list[str] | None = None, diff --git a/src/magentic/chat_model/litellm_chat_model.py b/src/magentic/chat_model/litellm_chat_model.py index ad80e3af..137ec3ae 100644 --- a/src/magentic/chat_model/litellm_chat_model.py +++ b/src/magentic/chat_model/litellm_chat_model.py @@ -184,7 +184,7 @@ def complete( ) -> AssistantMessage[str] | AssistantMessage[R]: """Request an LLM 
message.""" if output_types is None: - output_types = [] if functions else cast(list[type[R]], [str]) + output_types = cast(Iterable[type[R]], [] if functions else [str]) function_schemas = [FunctionCallFunctionSchema(f) for f in functions or []] + [ function_schema_for_type(type_) @@ -220,7 +220,7 @@ def complete( tool_parser=LitellmToolStreamParser(), usage_parser=LitellmUsageStreamParser(), ) - return AssistantMessage(parse_stream(stream, output_types)) # type: ignore[return-value] + return AssistantMessage(parse_stream(stream, output_types)) @overload async def acomplete( @@ -252,7 +252,7 @@ async def acomplete( ) -> AssistantMessage[str] | AssistantMessage[R]: """Async version of `complete`.""" if output_types is None: - output_types = [] if functions else cast(list[type[R]], [str]) + output_types = cast(Iterable[type[R]], [] if functions else [str]) function_schemas = [FunctionCallFunctionSchema(f) for f in functions or []] + [ async_function_schema_for_type(type_) @@ -290,4 +290,4 @@ async def acomplete( tool_parser=LitellmToolStreamParser(), usage_parser=LitellmUsageStreamParser(), ) - return AssistantMessage(aparse_stream(stream, output_types)) # type: ignore[return-value] + return AssistantMessage(await aparse_stream(stream, output_types)) diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index c62e2c17..ae57cf6f 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -434,7 +434,7 @@ def complete( ) -> AssistantMessage[str] | AssistantMessage[R]: """Request an LLM message.""" if output_types is None: - output_types = [] if functions else cast(list[type[R]], [str]) + output_types = cast(Iterable[type[R]], [] if functions else [str]) # TODO: Check that Function calls types match functions function_schemas = [FunctionCallFunctionSchema(f) for f in functions or []] + [ From f8b92a4b87722688e05576890efda0b6c6ec8799 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Sat, 23 Nov 2024 21:19:29 -0800 Subject: [PATCH 15/40] Change function_schemas type list -> Iterable --- src/magentic/chat_model/stream.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/magentic/chat_model/stream.py b/src/magentic/chat_model/stream.py index e56f7415..3993f2ef 100644 --- a/src/magentic/chat_model/stream.py +++ b/src/magentic/chat_model/stream.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Iterator +from collections.abc import AsyncIterator, Iterable, Iterator from itertools import chain from typing import TYPE_CHECKING, Generic, TypeVar @@ -56,7 +56,7 @@ class OutputStream(Generic[T]): def __init__( self, stream: Iterator, # TODO: Fix typing - function_schemas: list[FunctionSchema[T]], + function_schemas: Iterable[FunctionSchema[T]], content_parser: StreamParser, tool_parser: StreamParser, usage_parser: StreamParser, @@ -111,7 +111,7 @@ class AsyncOutputStream(Generic[T]): def __init__( self, stream: AsyncIterator, # TODO: Fix typing - function_schemas: list[FunctionSchema[T]], + function_schemas: Iterable[FunctionSchema[T]], content_parser: StreamParser, tool_parser: StreamParser, usage_parser: StreamParser, From 6b464dd76de396dcb5eb55581aa5e62ed23fc6e3 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Sun, 24 Nov 2024 20:16:01 -0800 Subject: [PATCH 16/40] Add StreamState and OpenaiStreamState --- 
src/magentic/chat_model/openai_chat_model.py | 51 ++++++++- src/magentic/chat_model/stream.py | 107 ++++++++++++++----- 2 files changed, 129 insertions(+), 29 deletions(-) diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index ae57cf6f..eb4f2b5f 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -21,6 +21,7 @@ FunctionToolCallArgumentsDeltaEvent, FunctionToolCallArgumentsDoneEvent, ) +from openai.lib.streaming.chat._completions import ChatCompletionStreamState from openai.types.chat import ( ChatCompletionChunk, ChatCompletionMessageParam, @@ -50,7 +51,12 @@ UserMessage, _RawMessage, ) -from magentic.chat_model.stream import AsyncOutputStream, OutputStream, StreamParser +from magentic.chat_model.stream import ( + AsyncOutputStream, + OutputStream, + StreamParser, + StreamState, +) from magentic.function_call import ( AsyncParallelFunctionCall, FunctionCall, @@ -273,6 +279,7 @@ def get_tool_name(self, item: ChatCompletionStreamEvent) -> str: return item.name +# TODO: Move usage tracking into OpenaiStreamState class OpenaiUsageStreamParser(StreamParser[ChatCompletionStreamEvent, Usage]): """Filters and transforms OpenAI usage events from a stream.""" @@ -291,6 +298,45 @@ def transform(self, item: ChatCompletionStreamEvent) -> Usage: ) +class OpenaiStreamState(StreamState): + def __init__(self, function_schemas): + self._function_schemas = function_schemas + + self._chat_completion_stream_state = ChatCompletionStreamState( + input_tools=openai.NOT_GIVEN, + response_format=openai.NOT_GIVEN, + ) + self._current_tool_call_id: str | None = None + + def update(self, item: ChatCompletionStreamEvent) -> None: + if item.type == "chunk": + self._chat_completion_stream_state.handle_chunk(item.chunk) + if ( + item.type == "chunk" + and item.chunk.choices + and item.chunk.choices[0].delta.tool_calls + and item.chunk.choices[0].delta.tool_calls[0].id + ): + # TODO: Mistral fix here ? + # openai keeps index consistent for chunks from the same tool_call, but id is null + # mistral has null index, but keeps id consistent + self._current_tool_call_id = item.chunk.choices[0].delta.tool_calls[0].id + + @property + def current_tool_call_id(self) -> str | None: + return self._current_tool_call_id + + @property + def current_message_snapshot(self) -> Message: + message = ( + self._chat_completion_stream_state.current_completion_snapshot.choices[ + 0 + ].message + ) + # TODO: Possible to return AssistantMessage here? + return _RawMessage(message.model_dump()) + + def _if_given(value: T | None) -> T | openai.NotGiven: return value if value is not None else openai.NOT_GIVEN @@ -469,9 +515,11 @@ def complete( stream = OutputStream( _stream, function_schemas=function_schemas, + # TODO: Consoldate these into a single parser / state object? 
content_parser=OpenaiContentStreamParser(), tool_parser=OpenaiToolStreamParser(), usage_parser=OpenaiUsageStreamParser(), + state=OpenaiStreamState(function_schemas=function_schemas), ) return AssistantMessage(parse_stream(stream, output_types)) # type: ignore[return-type] @@ -553,5 +601,6 @@ async def acomplete( content_parser=OpenaiContentStreamParser(), tool_parser=OpenaiToolStreamParser(), usage_parser=OpenaiUsageStreamParser(), + state=OpenaiStreamState(function_schemas=function_schemas), ) return AssistantMessage(await aparse_stream(stream, output_types)) # type: ignore[return-type] diff --git a/src/magentic/chat_model/stream.py b/src/magentic/chat_model/stream.py index 3993f2ef..ca07ed1b 100644 --- a/src/magentic/chat_model/stream.py +++ b/src/magentic/chat_model/stream.py @@ -3,14 +3,25 @@ from itertools import chain from typing import TYPE_CHECKING, Generic, TypeVar +from litellm.llms.files_apis.azure import Any +from pydantic import ValidationError + +from magentic.chat_model.base import ToolSchemaParseError from magentic.chat_model.function_schema import FunctionSchema, select_function_schema -from magentic.streaming import AsyncStreamedStr, StreamedStr, achain, async_iter +from magentic.chat_model.message import Message +from magentic.streaming import ( + AsyncStreamedStr, + StreamedStr, + aapply, + achain, + apply, + async_iter, +) if TYPE_CHECKING: from magentic.chat_model.message import Usage -T = TypeVar("T") ItemT = TypeVar("ItemT") OutputT = TypeVar("OutputT") @@ -50,16 +61,30 @@ async def aiter( return -class OutputStream(Generic[T]): +class StreamState(ABC, Generic[ItemT]): + @abstractmethod + def update(self, item: ItemT) -> None: ... + + @property + @abstractmethod + def current_tool_call_id(self) -> str | None: ... + + @property + @abstractmethod + def current_message_snapshot(self) -> Message[Any]: ... 
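+    # The wrapping OutputStream/AsyncOutputStream pass every raw provider chunk
+    # through `update` (via `apply`/`aapply`) before the parsers see it, so the
+    # id of the tool call currently being streamed and a snapshot of the partial
+    # assistant message stay available when a ToolSchemaParseError is raised.
+    #
+    # Rough usage sketch, assuming the OpenAI-backed implementation from
+    # openai_chat_model.py:
+    #     state = OpenaiStreamState(function_schemas=function_schemas)
+    #     for event in raw_stream:
+    #         state.update(event)
+    #     snapshot = state.current_message_snapshot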
+ + +class OutputStream(Generic[ItemT, OutputT]): """Converts streamed LLM output into a stream of magentic objects.""" def __init__( self, - stream: Iterator, # TODO: Fix typing - function_schemas: Iterable[FunctionSchema[T]], + stream: Iterator[ItemT], + function_schemas: Iterable[FunctionSchema[OutputT]], content_parser: StreamParser, tool_parser: StreamParser, usage_parser: StreamParser, + state: StreamState[ItemT], ): self._stream = stream self._function_schemas = function_schemas @@ -68,20 +93,23 @@ def __init__( self._content_parser = content_parser self._tool_parser = tool_parser self._usage_parser = usage_parser + self._state = state + + self._wrapped_stream = apply(self._state.update, stream) self.usage: Usage | None = None - def __next__(self) -> StreamedStr | T: + def __next__(self) -> StreamedStr | OutputT: return self._iterator.__next__() - def __iter__(self) -> Iterator[StreamedStr | T]: + def __iter__(self) -> Iterator[StreamedStr | OutputT]: yield from self._iterator - def __stream__(self) -> Iterator[StreamedStr | T]: - transition = [next(self._stream)] + def __stream__(self) -> Iterator[StreamedStr | OutputT]: + transition = [next(self._wrapped_stream)] while transition: transition_item = transition.pop() - stream_with_transition = chain([transition_item], self._stream) + stream_with_transition = chain([transition_item], self._wrapped_stream) if self._content_parser.is_member(transition_item): yield StreamedStr( self._content_parser.iter(stream_with_transition, transition) @@ -92,29 +120,39 @@ def __stream__(self) -> Iterator[StreamedStr | T]: function_schema = select_function_schema( self._function_schemas, tool_name ) - # TODO: Catch/raise ToolSchemaParseError here for retry logic - yield function_schema.parse_args( - self._tool_parser.iter(stream_with_transition, transition) - ) + try: + yield function_schema.parse_args( + self._tool_parser.iter(stream_with_transition, transition) + ) + # TODO: Catch/raise unknown tool call error here + except ValidationError as e: + assert self._state.current_tool_call_id is not None # noqa: S101 + raise ToolSchemaParseError( + output_message=self._state.current_message_snapshot, + tool_call_id=self._state.current_tool_call_id, + validation_error=e, + ) from e + # TODO: Move usage tracking into StreamState elif self._usage_parser.is_member(transition_item): self.usage = self._usage_parser.transform(transition_item) - elif new_transition_item := next(self._stream, None): + elif new_transition_item := next(self._wrapped_stream, None): transition.append(new_transition_item) def close(self): self._stream.close() -class AsyncOutputStream(Generic[T]): +class AsyncOutputStream(Generic[ItemT, OutputT]): """Async version of `OutputStream`.""" def __init__( self, - stream: AsyncIterator, # TODO: Fix typing - function_schemas: Iterable[FunctionSchema[T]], + stream: AsyncIterator[ItemT], + function_schemas: Iterable[FunctionSchema[OutputT]], content_parser: StreamParser, tool_parser: StreamParser, usage_parser: StreamParser, + state: StreamState[ItemT], ): self._stream = stream self._function_schemas = function_schemas @@ -123,21 +161,26 @@ def __init__( self._content_parser = content_parser self._tool_parser = tool_parser self._usage_parser = usage_parser + self._state = state + + self._wrapped_stream = aapply(self._state.update, stream) self.usage: Usage | None = None - async def __anext__(self) -> AsyncStreamedStr | T: + async def __anext__(self) -> AsyncStreamedStr | OutputT: return await self._iterator.__anext__() - async def 
__aiter__(self) -> AsyncIterator[AsyncStreamedStr | T]: + async def __aiter__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]: async for item in self._iterator: yield item - async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | T]: - transition = [await anext(self._stream)] + async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]: + transition = [await anext(self._wrapped_stream)] while transition: transition_item = transition.pop() - stream_with_transition = achain(async_iter([transition_item]), self._stream) + stream_with_transition = achain( + async_iter([transition_item]), self._wrapped_stream + ) if self._content_parser.is_member(transition_item): yield AsyncStreamedStr( self._content_parser.aiter(stream_with_transition, transition) @@ -148,13 +191,21 @@ async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | T]: function_schema = select_function_schema( self._function_schemas, tool_name ) - # TODO: Catch/raise ToolSchemaParseError here for retry logic - yield await function_schema.aparse_args( - self._tool_parser.aiter(stream_with_transition, transition) - ) + try: + yield await function_schema.aparse_args( + self._tool_parser.aiter(stream_with_transition, transition) + ) + # TODO: Catch/raise unknown tool call error here + except ValidationError as e: + assert self._state.current_tool_call_id is not None # noqa: S101 + raise ToolSchemaParseError( + output_message=self._state.current_message_snapshot, + tool_call_id=self._state.current_tool_call_id, + validation_error=e, + ) from e elif self._usage_parser.is_member(transition_item): self.usage = self._usage_parser.transform(transition_item) - elif new_transition_item := await anext(self._stream, None): + elif new_transition_item := await anext(self._wrapped_stream, None): transition.append(new_transition_item) async def close(self): From b073258ffb3b65104306f29982c2c9fd43fa7cff Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Sun, 24 Nov 2024 20:17:57 -0800 Subject: [PATCH 17/40] Add TODOs --- src/magentic/chat_model/base.py | 2 ++ src/magentic/chat_model/openai_chat_model.py | 3 +++ tests/chat_model/test_openai_chat_model.py | 4 ++++ tests/test_prompt_chain.py | 2 ++ 4 files changed, 11 insertions(+) diff --git a/src/magentic/chat_model/base.py b/src/magentic/chat_model/base.py index 5dfa5d8a..8d153753 100644 --- a/src/magentic/chat_model/base.py +++ b/src/magentic/chat_model/base.py @@ -58,6 +58,7 @@ def __init__( self.validation_error = validation_error +# TODO: Delete this function def validate_str_content( streamed_str: StreamedStr, *, allow_string_output: bool, streamed: bool ) -> StreamedStr | str: @@ -70,6 +71,7 @@ def validate_str_content( return str(streamed_str) +# TODO: Delete this function async def avalidate_str_content( async_streamed_str: AsyncStreamedStr, *, allow_string_output: bool, streamed: bool ) -> AsyncStreamedStr | str: diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index eb4f2b5f..13baecaf 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -477,6 +477,7 @@ def complete( output_types: Iterable[type[R]] | None = None, *, stop: list[str] | None = None, + # TODO: Add type hint for function call ? 
) -> AssistantMessage[str] | AssistantMessage[R]: """Request an LLM message.""" if output_types is None: @@ -490,10 +491,12 @@ def complete( ] tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas] + # TODO: pass output_types to _get_tool_choice directly and remove these str_in_output_types = is_any_origin_subclass(output_types, str) streamed_str_in_output_types = is_any_origin_subclass(output_types, StreamedStr) allow_string_output = str_in_output_types or streamed_str_in_output_types + # TODO: Switch to the create method to avoid possible validation addition _stream = self._client.beta.chat.completions.stream( model=self.model, messages=_add_missing_tool_calls_responses( diff --git a/tests/chat_model/test_openai_chat_model.py b/tests/chat_model/test_openai_chat_model.py index 26f06efe..51e33ce7 100644 --- a/tests/chat_model/test_openai_chat_model.py +++ b/tests/chat_model/test_openai_chat_model.py @@ -186,6 +186,7 @@ def plus(a: int, b: int) -> int: assert isinstance(message.content, FunctionCall) +@pytest.mark.skip("TODO: implement usage") @pytest.mark.openai def test_openai_chat_model_complete_usage(): chat_model = OpenaiChatModel("gpt-4o") @@ -198,6 +199,7 @@ def test_openai_chat_model_complete_usage(): assert message.usage.output_tokens > 0 +@pytest.mark.skip("TODO: implement usage") @pytest.mark.openai def test_openai_chat_model_complete_usage_structured_output(): chat_model = OpenaiChatModel("gpt-4o") @@ -238,6 +240,7 @@ class Test(BaseModel): ) +@pytest.mark.skip("TODO: implement usage") @pytest.mark.openai async def test_openai_chat_model_acomplete_usage(): chat_model = OpenaiChatModel("gpt-4o") @@ -250,6 +253,7 @@ async def test_openai_chat_model_acomplete_usage(): assert message.usage.output_tokens > 0 +@pytest.mark.skip("TODO: implement usage") @pytest.mark.openai async def test_openai_chat_model_acomplete_usage_structured_output(): chat_model = OpenaiChatModel("gpt-4o") diff --git a/tests/test_prompt_chain.py b/tests/test_prompt_chain.py index 63796f41..69d85665 100644 --- a/tests/test_prompt_chain.py +++ b/tests/test_prompt_chain.py @@ -7,6 +7,7 @@ from magentic.prompt_chain import MaxFunctionCallsError, prompt_chain +@pytest.mark.skip("TODO: Add FunctionCall to output_types internal to prompt_chain") @pytest.mark.openai def test_prompt_chain(): def get_current_weather(location, unit="fahrenheit"): @@ -50,6 +51,7 @@ def make_function_call() -> str: ... 
assert mock_function.call_count == 1 +@pytest.mark.skip("TODO: Add FunctionCall to output_types internal to prompt_chain") @pytest.mark.openai async def test_async_prompt_chain(): async def get_current_weather(location, unit="fahrenheit"): From f6f55c33f5f8ce7ad0344a29b21b4ef19477f043 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Sun, 24 Nov 2024 21:44:13 -0800 Subject: [PATCH 18/40] Update openai retry test cassettes --- .../test_async_decorator_max_retries.yaml | 80 +++++++++-------- .../test_decorator_max_retries.yaml | 86 ++++++++++--------- ...est_retry_chat_model_acomplete_openai.yaml | 76 ++++++++-------- ...test_retry_chat_model_complete_openai.yaml | 80 ++++++++--------- 4 files changed, 161 insertions(+), 161 deletions(-) diff --git a/tests/cassettes/test_prompt_function/test_async_decorator_max_retries.yaml b/tests/cassettes/test_prompt_function/test_async_decorator_max_retries.yaml index a20fa538..6776a4e1 100644 --- a/tests/cassettes/test_prompt_function/test_async_decorator_max_retries.yaml +++ b/tests/cassettes/test_prompt_function/test_async_decorator_max_retries.yaml @@ -41,28 +41,28 @@ interactions: uri: https://api.openai.com/v1/chat/completions response: body: - string: 'data: {"id":"chatcmpl-AU9UbB4ld1146U4QhHhxP4F02h3eV","object":"chat.completion.chunk","created":1731749693,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_Uiz13O1iMw5tob2yIDXaRHDH","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} + string: 'data: {"id":"chatcmpl-AXM5kjh54ArE89qKZKNsCik5EZIEt","object":"chat.completion.chunk","created":1732513108,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_6Dk7MgfrU5OMIvVdcxiJ9TLE","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UbB4ld1146U4QhHhxP4F02h3eV","object":"chat.completion.chunk","created":1731749693,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5kjh54ArE89qKZKNsCik5EZIEt","object":"chat.completion.chunk","created":1732513108,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UbB4ld1146U4QhHhxP4F02h3eV","object":"chat.completion.chunk","created":1731749693,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5kjh54ArE89qKZKNsCik5EZIEt","object":"chat.completion.chunk","created":1732513108,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: 
{"id":"chatcmpl-AU9UbB4ld1146U4QhHhxP4F02h3eV","object":"chat.completion.chunk","created":1731749693,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5kjh54ArE89qKZKNsCik5EZIEt","object":"chat.completion.chunk","created":1732513108,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UbB4ld1146U4QhHhxP4F02h3eV","object":"chat.completion.chunk","created":1731749693,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Japan"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5kjh54ArE89qKZKNsCik5EZIEt","object":"chat.completion.chunk","created":1732513108,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Australia"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UbB4ld1146U4QhHhxP4F02h3eV","object":"chat.completion.chunk","created":1731749693,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5kjh54ArE89qKZKNsCik5EZIEt","object":"chat.completion.chunk","created":1732513108,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UbB4ld1146U4QhHhxP4F02h3eV","object":"chat.completion.chunk","created":1731749693,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} + data: {"id":"chatcmpl-AXM5kjh54ArE89qKZKNsCik5EZIEt","object":"chat.completion.chunk","created":1732513108,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} - data: {"id":"chatcmpl-AU9UbB4ld1146U4QhHhxP4F02h3eV","object":"chat.completion.chunk","created":1731749693,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":5,"total_tokens":58,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} + data: {"id":"chatcmpl-AXM5kjh54ArE89qKZKNsCik5EZIEt","object":"chat.completion.chunk","created":1732513108,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":5,"total_tokens":58,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} data: [DONE] @@ -73,13 +73,13 @@ interactions: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8e367cdc792b2523-SJC + - 8e7f4af0edbf238d-SJC Connection: - keep-alive Content-Type: - text/event-stream; charset=utf-8 Date: - - Sat, 16 Nov 
2024 09:34:53 GMT + - Mon, 25 Nov 2024 05:38:28 GMT Server: - cloudflare Transfer-Encoding: @@ -91,7 +91,7 @@ interactions: alt-svc: - h3=":443"; ma=86400 openai-processing-ms: - - '231' + - '205' openai-version: - '2020-10-01' strict-transport-security: @@ -109,23 +109,24 @@ interactions: x-ratelimit-reset-tokens: - 44ms x-request-id: - - req_36b6885837304c95255dc0e243ddd0bc + - req_5df2a7cf8b9d982084c0c588bcce800d status: code: 200 message: OK - request: - body: '{"messages": [{"role": "user", "content": "Return a country."}, {"role": - "assistant", "content": null, "tool_calls": [{"id": "call_Uiz13O1iMw5tob2yIDXaRHDH", - "type": "function", "function": {"name": "return_country", "arguments": "{\"name\":\"Japan\"}"}}]}, - {"role": "tool", "tool_call_id": "call_Uiz13O1iMw5tob2yIDXaRHDH", "content": - "1 validation error for Country\nname\n Value error, Country must be Ireland. - [type=value_error, input_value=''Japan'', input_type=str]\n For further information - visit https://errors.pydantic.dev/2.9/v/value_error"}], "model": "gpt-4o", "parallel_tool_calls": - false, "stream": true, "stream_options": {"include_usage": true}, "tool_choice": - {"type": "function", "function": {"name": "return_country"}}, "tools": [{"type": - "function", "function": {"name": "return_country", "parameters": {"properties": - {"name": {"title": "Name", "type": "string"}}, "required": ["name"], "type": - "object"}}}]}' + body: '{"messages": [{"role": "user", "content": "Return a country."}, {"content": + null, "refusal": null, "role": "assistant", "audio": null, "function_call": + null, "tool_calls": [{"id": "call_6Dk7MgfrU5OMIvVdcxiJ9TLE", "function": {"arguments": + "{\"name\":\"Australia\"}", "name": "return_country", "parsed_arguments": null}, + "type": "function", "index": 0}], "parsed": null}, {"role": "tool", "tool_call_id": + "call_6Dk7MgfrU5OMIvVdcxiJ9TLE", "content": "1 validation error for Country\nname\n Value + error, Country must be Ireland. 
[type=value_error, input_value=''Australia'', + input_type=str]\n For further information visit https://errors.pydantic.dev/2.9/v/value_error"}], + "model": "gpt-4o", "parallel_tool_calls": false, "stream": true, "stream_options": + {"include_usage": true}, "tool_choice": {"type": "function", "function": {"name": + "return_country"}}, "tools": [{"type": "function", "function": {"name": "return_country", + "parameters": {"properties": {"name": {"title": "Name", "type": "string"}}, + "required": ["name"], "type": "object"}}}]}' headers: accept: - application/json @@ -134,12 +135,9 @@ interactions: connection: - keep-alive content-length: - - '929' + - '1046' content-type: - application/json - cookie: - - __cf_bm=wEsL6ULd.xp6gCoeJPmKaXwe3SOIEpJ2OD7N7AW5Uws-1731749693-1.0.1.1-ub_GAbKgRsGOITlv_88OsV4ivL8UZhOb2zb2IxLxtM4Z6zT3bROk5hV_1ZcPL.d.KcwDZg9Ns_qofiVGmrMiuQ; - _cfuvid=_7gaoOsQ2EivHbO2o7pDXxdovG6ZmrHdC3szn.mJVww-1731749693141-0.0.1.1-604800000 host: - api.openai.com user-agent: @@ -164,28 +162,28 @@ interactions: uri: https://api.openai.com/v1/chat/completions response: body: - string: 'data: {"id":"chatcmpl-AU9UbpuvQHeG32Dr7WUawf4oshZpp","object":"chat.completion.chunk","created":1731749693,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_ifXs1dcqtJy3GNJ0UpW7AR8C","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} + string: 'data: {"id":"chatcmpl-AXM5lf2YZmhdFaEv938fU8xSwn5ug","object":"chat.completion.chunk","created":1732513109,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_9492M1JOz1OVPrONfWjISFxN","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UbpuvQHeG32Dr7WUawf4oshZpp","object":"chat.completion.chunk","created":1731749693,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5lf2YZmhdFaEv938fU8xSwn5ug","object":"chat.completion.chunk","created":1732513109,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UbpuvQHeG32Dr7WUawf4oshZpp","object":"chat.completion.chunk","created":1731749693,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5lf2YZmhdFaEv938fU8xSwn5ug","object":"chat.completion.chunk","created":1732513109,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: 
{"id":"chatcmpl-AU9UbpuvQHeG32Dr7WUawf4oshZpp","object":"chat.completion.chunk","created":1731749693,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5lf2YZmhdFaEv938fU8xSwn5ug","object":"chat.completion.chunk","created":1732513109,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UbpuvQHeG32Dr7WUawf4oshZpp","object":"chat.completion.chunk","created":1731749693,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Ireland"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5lf2YZmhdFaEv938fU8xSwn5ug","object":"chat.completion.chunk","created":1732513109,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Ireland"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UbpuvQHeG32Dr7WUawf4oshZpp","object":"chat.completion.chunk","created":1731749693,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5lf2YZmhdFaEv938fU8xSwn5ug","object":"chat.completion.chunk","created":1732513109,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UbpuvQHeG32Dr7WUawf4oshZpp","object":"chat.completion.chunk","created":1731749693,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} + data: {"id":"chatcmpl-AXM5lf2YZmhdFaEv938fU8xSwn5ug","object":"chat.completion.chunk","created":1732513109,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} - data: {"id":"chatcmpl-AU9UbpuvQHeG32Dr7WUawf4oshZpp","object":"chat.completion.chunk","created":1731749693,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[],"usage":{"prompt_tokens":125,"completion_tokens":5,"total_tokens":130,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} + data: {"id":"chatcmpl-AXM5lf2YZmhdFaEv938fU8xSwn5ug","object":"chat.completion.chunk","created":1732513109,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[],"usage":{"prompt_tokens":125,"completion_tokens":5,"total_tokens":130,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} data: [DONE] @@ -196,13 +194,13 @@ interactions: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8e367cdf0aa72523-SJC + - 8e7f4af438da238d-SJC Connection: - keep-alive Content-Type: - text/event-stream; charset=utf-8 Date: - - Sat, 
16 Nov 2024 09:34:53 GMT + - Mon, 25 Nov 2024 05:38:29 GMT Server: - cloudflare Transfer-Encoding: @@ -214,7 +212,7 @@ interactions: alt-svc: - h3=":443"; ma=86400 openai-processing-ms: - - '209' + - '276' openai-version: - '2020-10-01' strict-transport-security: @@ -230,9 +228,9 @@ interactions: x-ratelimit-reset-requests: - 120ms x-ratelimit-reset-tokens: - - 152ms + - 154ms x-request-id: - - req_dc27be5835d770a6829f394a4c68c56c + - req_1c8ba63d5419f12f535bf10ab37d2ddc status: code: 200 message: OK diff --git a/tests/cassettes/test_prompt_function/test_decorator_max_retries.yaml b/tests/cassettes/test_prompt_function/test_decorator_max_retries.yaml index 025b6cbd..1dd495d4 100644 --- a/tests/cassettes/test_prompt_function/test_decorator_max_retries.yaml +++ b/tests/cassettes/test_prompt_function/test_decorator_max_retries.yaml @@ -25,6 +25,8 @@ interactions: - arm64 x-stainless-async: - 'false' + x-stainless-helper-method: + - beta.chat.completions.stream x-stainless-lang: - python x-stainless-os: @@ -41,28 +43,28 @@ interactions: uri: https://api.openai.com/v1/chat/completions response: body: - string: 'data: {"id":"chatcmpl-AU9UZipZ8xDJl7GFgAuLyFkFGnJ9p","object":"chat.completion.chunk","created":1731749691,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_wdEg17NjgwLbqVdKjq6MKLJ1","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} + string: 'data: {"id":"chatcmpl-AXM5jAijInCrqEmlZs6ylqF3hpKe9","object":"chat.completion.chunk","created":1732513107,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_a7d06e42a7","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_Cf3PFZFWlA0He0ZOgneuubd1","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UZipZ8xDJl7GFgAuLyFkFGnJ9p","object":"chat.completion.chunk","created":1731749691,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5jAijInCrqEmlZs6ylqF3hpKe9","object":"chat.completion.chunk","created":1732513107,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_a7d06e42a7","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UZipZ8xDJl7GFgAuLyFkFGnJ9p","object":"chat.completion.chunk","created":1731749691,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5jAijInCrqEmlZs6ylqF3hpKe9","object":"chat.completion.chunk","created":1732513107,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_a7d06e42a7","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: 
{"id":"chatcmpl-AU9UZipZ8xDJl7GFgAuLyFkFGnJ9p","object":"chat.completion.chunk","created":1731749691,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5jAijInCrqEmlZs6ylqF3hpKe9","object":"chat.completion.chunk","created":1732513107,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_a7d06e42a7","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UZipZ8xDJl7GFgAuLyFkFGnJ9p","object":"chat.completion.chunk","created":1731749691,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Canada"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5jAijInCrqEmlZs6ylqF3hpKe9","object":"chat.completion.chunk","created":1732513107,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_a7d06e42a7","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"country"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UZipZ8xDJl7GFgAuLyFkFGnJ9p","object":"chat.completion.chunk","created":1731749691,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5jAijInCrqEmlZs6ylqF3hpKe9","object":"chat.completion.chunk","created":1732513107,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_a7d06e42a7","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UZipZ8xDJl7GFgAuLyFkFGnJ9p","object":"chat.completion.chunk","created":1731749691,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} + data: {"id":"chatcmpl-AXM5jAijInCrqEmlZs6ylqF3hpKe9","object":"chat.completion.chunk","created":1732513107,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_a7d06e42a7","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} - data: {"id":"chatcmpl-AU9UZipZ8xDJl7GFgAuLyFkFGnJ9p","object":"chat.completion.chunk","created":1731749691,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_159d8341cc","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":5,"total_tokens":58,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} + data: {"id":"chatcmpl-AXM5jAijInCrqEmlZs6ylqF3hpKe9","object":"chat.completion.chunk","created":1732513107,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_a7d06e42a7","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":5,"total_tokens":58,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} data: [DONE] @@ -73,13 +75,13 @@ interactions: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8e367cd11df226b0-SJC + - 8e7f4ae92cdafa62-SJC Connection: - keep-alive Content-Type: - text/event-stream; charset=utf-8 Date: - - Sat, 16 Nov 
2024 09:34:51 GMT + - Mon, 25 Nov 2024 05:38:27 GMT Server: - cloudflare Transfer-Encoding: @@ -91,7 +93,7 @@ interactions: alt-svc: - h3=":443"; ma=86400 openai-processing-ms: - - '238' + - '357' openai-version: - '2020-10-01' strict-transport-security: @@ -109,23 +111,24 @@ interactions: x-ratelimit-reset-tokens: - 44ms x-request-id: - - req_55c15dbef4c2f0a4026078c4d3051d04 + - req_ad4531da079aa84b23a02fcd6fcb5d45 status: code: 200 message: OK - request: - body: '{"messages": [{"role": "user", "content": "Return a country."}, {"role": - "assistant", "content": null, "tool_calls": [{"id": "call_wdEg17NjgwLbqVdKjq6MKLJ1", - "type": "function", "function": {"name": "return_country", "arguments": "{\"name\":\"Canada\"}"}}]}, - {"role": "tool", "tool_call_id": "call_wdEg17NjgwLbqVdKjq6MKLJ1", "content": - "1 validation error for Country\nname\n Value error, Country must be Ireland. - [type=value_error, input_value=''Canada'', input_type=str]\n For further - information visit https://errors.pydantic.dev/2.9/v/value_error"}], "model": - "gpt-4o", "parallel_tool_calls": false, "stream": true, "stream_options": {"include_usage": - true}, "tool_choice": {"type": "function", "function": {"name": "return_country"}}, - "tools": [{"type": "function", "function": {"name": "return_country", "parameters": - {"properties": {"name": {"title": "Name", "type": "string"}}, "required": ["name"], - "type": "object"}}}]}' + body: '{"messages": [{"role": "user", "content": "Return a country."}, {"content": + null, "refusal": null, "role": "assistant", "audio": null, "function_call": + null, "tool_calls": [{"id": "call_Cf3PFZFWlA0He0ZOgneuubd1", "function": {"arguments": + "{\"name\":\"country\"}", "name": "return_country", "parsed_arguments": null}, + "type": "function", "index": 0}], "parsed": null}, {"role": "tool", "tool_call_id": + "call_Cf3PFZFWlA0He0ZOgneuubd1", "content": "1 validation error for Country\nname\n Value + error, Country must be Ireland. 
[type=value_error, input_value=''country'', + input_type=str]\n For further information visit https://errors.pydantic.dev/2.9/v/value_error"}], + "model": "gpt-4o", "parallel_tool_calls": false, "stream": true, "stream_options": + {"include_usage": true}, "tool_choice": {"type": "function", "function": {"name": + "return_country"}}, "tools": [{"type": "function", "function": {"name": "return_country", + "parameters": {"properties": {"name": {"title": "Name", "type": "string"}}, + "required": ["name"], "type": "object"}}}]}' headers: accept: - application/json @@ -134,12 +137,9 @@ interactions: connection: - keep-alive content-length: - - '931' + - '1042' content-type: - application/json - cookie: - - __cf_bm=sPO4iMakhVSeN2qInDys83OtESskcGEQtUulfF25yU8-1731749691-1.0.1.1-bndGTvCeKjcqqi0MUQq4bEF3.atEdhQbF.8VGPbpjf1qoXcYhD5.v59eEOEpo6RpfegQYZGaUOPOT_fbY7Au1A; - _cfuvid=sCef7lr9GC3CdWBEwnTFFR7tRfAdsuJ0FQJjudsAOpQ-1731749691329-0.0.1.1-604800000 host: - api.openai.com user-agent: @@ -148,6 +148,8 @@ interactions: - arm64 x-stainless-async: - 'false' + x-stainless-helper-method: + - beta.chat.completions.stream x-stainless-lang: - python x-stainless-os: @@ -164,28 +166,28 @@ interactions: uri: https://api.openai.com/v1/chat/completions response: body: - string: 'data: {"id":"chatcmpl-AU9UZl5mf807mUw8xigORAVUueic3","object":"chat.completion.chunk","created":1731749691,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_bb84311112","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_cH1CgAuhZ0YHubyQGGfxl42Q","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} + string: 'data: {"id":"chatcmpl-AXM5k8UrAU68NVOIrhTyvKl5egZ6g","object":"chat.completion.chunk","created":1732513108,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_D3l38Wx0zyqPSyWngIhjMWVd","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UZl5mf807mUw8xigORAVUueic3","object":"chat.completion.chunk","created":1731749691,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_bb84311112","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5k8UrAU68NVOIrhTyvKl5egZ6g","object":"chat.completion.chunk","created":1732513108,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UZl5mf807mUw8xigORAVUueic3","object":"chat.completion.chunk","created":1731749691,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_bb84311112","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5k8UrAU68NVOIrhTyvKl5egZ6g","object":"chat.completion.chunk","created":1732513108,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: 
{"id":"chatcmpl-AU9UZl5mf807mUw8xigORAVUueic3","object":"chat.completion.chunk","created":1731749691,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_bb84311112","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5k8UrAU68NVOIrhTyvKl5egZ6g","object":"chat.completion.chunk","created":1732513108,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UZl5mf807mUw8xigORAVUueic3","object":"chat.completion.chunk","created":1731749691,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_bb84311112","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Ireland"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5k8UrAU68NVOIrhTyvKl5egZ6g","object":"chat.completion.chunk","created":1732513108,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Ireland"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UZl5mf807mUw8xigORAVUueic3","object":"chat.completion.chunk","created":1731749691,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_bb84311112","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5k8UrAU68NVOIrhTyvKl5egZ6g","object":"chat.completion.chunk","created":1732513108,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UZl5mf807mUw8xigORAVUueic3","object":"chat.completion.chunk","created":1731749691,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_bb84311112","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} + data: {"id":"chatcmpl-AXM5k8UrAU68NVOIrhTyvKl5egZ6g","object":"chat.completion.chunk","created":1732513108,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} - data: {"id":"chatcmpl-AU9UZl5mf807mUw8xigORAVUueic3","object":"chat.completion.chunk","created":1731749691,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_bb84311112","choices":[],"usage":{"prompt_tokens":125,"completion_tokens":5,"total_tokens":130,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} + data: {"id":"chatcmpl-AXM5k8UrAU68NVOIrhTyvKl5egZ6g","object":"chat.completion.chunk","created":1732513108,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[],"usage":{"prompt_tokens":125,"completion_tokens":5,"total_tokens":130,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} data: [DONE] @@ -196,13 +198,13 @@ interactions: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8e367cd34ff326b0-SJC + - 8e7f4aed6859fa62-SJC Connection: - keep-alive Content-Type: - text/event-stream; charset=utf-8 Date: - - Sat, 
16 Nov 2024 09:34:51 GMT + - Mon, 25 Nov 2024 05:38:28 GMT Server: - cloudflare Transfer-Encoding: @@ -214,7 +216,7 @@ interactions: alt-svc: - h3=":443"; ma=86400 openai-processing-ms: - - '356' + - '301' openai-version: - '2020-10-01' strict-transport-security: @@ -226,13 +228,13 @@ interactions: x-ratelimit-remaining-requests: - '499' x-ratelimit-remaining-tokens: - - '29913' + - '29923' x-ratelimit-reset-requests: - 120ms x-ratelimit-reset-tokens: - - 172ms + - 154ms x-request-id: - - req_d1d6dffe6036d31153d701cfd53dead1 + - req_611308496b6eaa46484a1c2ccfe4dfa7 status: code: 200 message: OK diff --git a/tests/chat_model/cassettes/test_retry_chat_model/test_retry_chat_model_acomplete_openai.yaml b/tests/chat_model/cassettes/test_retry_chat_model/test_retry_chat_model_acomplete_openai.yaml index 2cb44e1e..302682eb 100644 --- a/tests/chat_model/cassettes/test_retry_chat_model/test_retry_chat_model_acomplete_openai.yaml +++ b/tests/chat_model/cassettes/test_retry_chat_model/test_retry_chat_model_acomplete_openai.yaml @@ -41,28 +41,28 @@ interactions: uri: https://api.openai.com/v1/chat/completions response: body: - string: 'data: {"id":"chatcmpl-AU9UMOiC2xzUTxC9coAuiEVSmvmZM","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_lecShqIgNwQziEZxdH7XFzD3","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} + string: 'data: {"id":"chatcmpl-AXM5iW6cFqElNxqivX4X88E8o7mfV","object":"chat.completion.chunk","created":1732513106,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_utLSDP4yL9s8N9dApjMxwrHA","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UMOiC2xzUTxC9coAuiEVSmvmZM","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5iW6cFqElNxqivX4X88E8o7mfV","object":"chat.completion.chunk","created":1732513106,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UMOiC2xzUTxC9coAuiEVSmvmZM","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5iW6cFqElNxqivX4X88E8o7mfV","object":"chat.completion.chunk","created":1732513106,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: 
{"id":"chatcmpl-AU9UMOiC2xzUTxC9coAuiEVSmvmZM","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5iW6cFqElNxqivX4X88E8o7mfV","object":"chat.completion.chunk","created":1732513106,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UMOiC2xzUTxC9coAuiEVSmvmZM","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Country"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5iW6cFqElNxqivX4X88E8o7mfV","object":"chat.completion.chunk","created":1732513106,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Canada"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UMOiC2xzUTxC9coAuiEVSmvmZM","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5iW6cFqElNxqivX4X88E8o7mfV","object":"chat.completion.chunk","created":1732513106,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UMOiC2xzUTxC9coAuiEVSmvmZM","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} + data: {"id":"chatcmpl-AXM5iW6cFqElNxqivX4X88E8o7mfV","object":"chat.completion.chunk","created":1732513106,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} - data: {"id":"chatcmpl-AU9UMOiC2xzUTxC9coAuiEVSmvmZM","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":5,"total_tokens":58,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} + data: {"id":"chatcmpl-AXM5iW6cFqElNxqivX4X88E8o7mfV","object":"chat.completion.chunk","created":1732513106,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":5,"total_tokens":58,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} data: [DONE] @@ -73,13 +73,13 @@ interactions: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8e367c81cd98f94b-SJC + - 8e7f4ae52842968c-SJC Connection: - keep-alive Content-Type: - 
text/event-stream; charset=utf-8 Date: - - Sat, 16 Nov 2024 09:34:38 GMT + - Mon, 25 Nov 2024 05:38:26 GMT Server: - cloudflare Transfer-Encoding: @@ -91,7 +91,7 @@ interactions: alt-svc: - h3=":443"; ma=86400 openai-processing-ms: - - '174' + - '132' openai-version: - '2020-10-01' strict-transport-security: @@ -105,23 +105,24 @@ interactions: x-ratelimit-remaining-tokens: - '199977' x-ratelimit-reset-requests: - - 25.227s + - 24.928s x-ratelimit-reset-tokens: - 6ms x-request-id: - - req_f5af46a001996c7a4c9db71df1c139eb + - req_5824fdf4a2e6d4b97abe71629d854c93 status: code: 200 message: OK - request: - body: '{"messages": [{"role": "user", "content": "Return a country."}, {"role": - "assistant", "content": null, "tool_calls": [{"id": "call_lecShqIgNwQziEZxdH7XFzD3", - "type": "function", "function": {"name": "return_country", "arguments": "{\"name\":\"Country\"}"}}]}, - {"role": "tool", "tool_call_id": "call_lecShqIgNwQziEZxdH7XFzD3", "content": - "1 validation error for Country\nname\n Value error, Country must be Ireland. - [type=value_error, input_value=''Country'', input_type=str]\n For further - information visit https://errors.pydantic.dev/2.9/v/value_error"}], "model": - "gpt-4o-mini", "parallel_tool_calls": false, "stream": true, "stream_options": + body: '{"messages": [{"role": "user", "content": "Return a country."}, {"content": + null, "refusal": null, "role": "assistant", "audio": null, "function_call": + null, "tool_calls": [{"id": "call_utLSDP4yL9s8N9dApjMxwrHA", "function": {"arguments": + "{\"name\":\"Canada\"}", "name": "return_country", "parsed_arguments": null}, + "type": "function", "index": 0}], "parsed": null}, {"role": "tool", "tool_call_id": + "call_utLSDP4yL9s8N9dApjMxwrHA", "content": "1 validation error for Country\nname\n Value + error, Country must be Ireland. 
[type=value_error, input_value=''Canada'', input_type=str]\n For + further information visit https://errors.pydantic.dev/2.9/v/value_error"}], + "model": "gpt-4o-mini", "parallel_tool_calls": false, "stream": true, "stream_options": {"include_usage": true}, "tool_choice": {"type": "function", "function": {"name": "return_country"}}, "tools": [{"type": "function", "function": {"name": "return_country", "parameters": {"properties": {"name": {"title": "Name", "type": "string"}}, @@ -134,12 +135,9 @@ interactions: connection: - keep-alive content-length: - - '938' + - '1045' content-type: - application/json - cookie: - - __cf_bm=uHzlI0jqu.pSi2pAubFTNeq6fPmJz5tfvnN5J3Womqk-1731749678-1.0.1.1-u5hCbyxg3fWVYgwkfVYOGMCgd7Cm2OlCCQL98Har0RraLlM_vn7gSoMmFUltlt9qfAzivRRupv6_o5u_OPY9DQ; - _cfuvid=Z2QTUH8.iKU7gXBDaD1tU2oRe6RBI5J31Rmexwn3jgk-1731749678578-0.0.1.1-604800000 host: - api.openai.com user-agent: @@ -164,28 +162,28 @@ interactions: uri: https://api.openai.com/v1/chat/completions response: body: - string: 'data: {"id":"chatcmpl-AU9UMp4Sg8yg8VBLhdhUOO2XJh0ay","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_MWI0aMZXWYCYRPgKpVjpxnZu","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} + string: 'data: {"id":"chatcmpl-AXM5jAXSOgYSkMYvw31lTlkWPjYBa","object":"chat.completion.chunk","created":1732513107,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_7tzq2Puv36M9dOdFVXqrJl4y","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UMp4Sg8yg8VBLhdhUOO2XJh0ay","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5jAXSOgYSkMYvw31lTlkWPjYBa","object":"chat.completion.chunk","created":1732513107,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UMp4Sg8yg8VBLhdhUOO2XJh0ay","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5jAXSOgYSkMYvw31lTlkWPjYBa","object":"chat.completion.chunk","created":1732513107,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UMp4Sg8yg8VBLhdhUOO2XJh0ay","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: 
{"id":"chatcmpl-AXM5jAXSOgYSkMYvw31lTlkWPjYBa","object":"chat.completion.chunk","created":1732513107,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UMp4Sg8yg8VBLhdhUOO2XJh0ay","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Ireland"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5jAXSOgYSkMYvw31lTlkWPjYBa","object":"chat.completion.chunk","created":1732513107,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Ireland"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UMp4Sg8yg8VBLhdhUOO2XJh0ay","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5jAXSOgYSkMYvw31lTlkWPjYBa","object":"chat.completion.chunk","created":1732513107,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UMp4Sg8yg8VBLhdhUOO2XJh0ay","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} + data: {"id":"chatcmpl-AXM5jAXSOgYSkMYvw31lTlkWPjYBa","object":"chat.completion.chunk","created":1732513107,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} - data: {"id":"chatcmpl-AU9UMp4Sg8yg8VBLhdhUOO2XJh0ay","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[],"usage":{"prompt_tokens":125,"completion_tokens":5,"total_tokens":130,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} + data: {"id":"chatcmpl-AXM5jAXSOgYSkMYvw31lTlkWPjYBa","object":"chat.completion.chunk","created":1732513107,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[],"usage":{"prompt_tokens":125,"completion_tokens":5,"total_tokens":130,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} data: [DONE] @@ -196,13 +194,13 @@ interactions: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8e367c839ed0f94b-SJC + - 8e7f4ae6e980968c-SJC Connection: - keep-alive Content-Type: - text/event-stream; charset=utf-8 Date: - - Sat, 16 Nov 2024 09:34:38 GMT + - Mon, 25 Nov 2024 05:38:27 GMT Server: - cloudflare Transfer-Encoding: @@ -214,7 +212,7 @@ interactions: alt-svc: - h3=":443"; ma=86400 openai-processing-ms: - - '133' + - '130' openai-version: - '2020-10-01' strict-transport-security: @@ -226,13 
+224,13 @@ interactions: x-ratelimit-remaining-requests: - '9996' x-ratelimit-remaining-tokens: - - '199922' + - '199923' x-ratelimit-reset-requests: - - 33.576s + - 33.287s x-ratelimit-reset-tokens: - 23ms x-request-id: - - req_d2f9147adb2646dfd4f47a12389e0e1e + - req_a9ce5c6dc41085cadaa97544e11e2bb5 status: code: 200 message: OK diff --git a/tests/chat_model/cassettes/test_retry_chat_model/test_retry_chat_model_complete_openai.yaml b/tests/chat_model/cassettes/test_retry_chat_model/test_retry_chat_model_complete_openai.yaml index eca411bf..6d2e7782 100644 --- a/tests/chat_model/cassettes/test_retry_chat_model/test_retry_chat_model_complete_openai.yaml +++ b/tests/chat_model/cassettes/test_retry_chat_model/test_retry_chat_model_complete_openai.yaml @@ -25,6 +25,8 @@ interactions: - arm64 x-stainless-async: - 'false' + x-stainless-helper-method: + - beta.chat.completions.stream x-stainless-lang: - python x-stainless-os: @@ -41,28 +43,28 @@ interactions: uri: https://api.openai.com/v1/chat/completions response: body: - string: 'data: {"id":"chatcmpl-AU9ULww3MRjPquqNLMfEm0wRtvXCp","object":"chat.completion.chunk","created":1731749677,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_ugvIEdaAkA3EwKchrQL6sCtt","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} + string: 'data: {"id":"chatcmpl-AXM5h5kzhX8FtXccw9UdtcvBLAS6Z","object":"chat.completion.chunk","created":1732513105,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_2AGPVxgioaHgJ0ysmjvWE4a9","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9ULww3MRjPquqNLMfEm0wRtvXCp","object":"chat.completion.chunk","created":1731749677,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5h5kzhX8FtXccw9UdtcvBLAS6Z","object":"chat.completion.chunk","created":1732513105,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9ULww3MRjPquqNLMfEm0wRtvXCp","object":"chat.completion.chunk","created":1731749677,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5h5kzhX8FtXccw9UdtcvBLAS6Z","object":"chat.completion.chunk","created":1732513105,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: 
{"id":"chatcmpl-AU9ULww3MRjPquqNLMfEm0wRtvXCp","object":"chat.completion.chunk","created":1731749677,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5h5kzhX8FtXccw9UdtcvBLAS6Z","object":"chat.completion.chunk","created":1732513105,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9ULww3MRjPquqNLMfEm0wRtvXCp","object":"chat.completion.chunk","created":1731749677,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Canada"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5h5kzhX8FtXccw9UdtcvBLAS6Z","object":"chat.completion.chunk","created":1732513105,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Canada"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9ULww3MRjPquqNLMfEm0wRtvXCp","object":"chat.completion.chunk","created":1731749677,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5h5kzhX8FtXccw9UdtcvBLAS6Z","object":"chat.completion.chunk","created":1732513105,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9ULww3MRjPquqNLMfEm0wRtvXCp","object":"chat.completion.chunk","created":1731749677,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} + data: {"id":"chatcmpl-AXM5h5kzhX8FtXccw9UdtcvBLAS6Z","object":"chat.completion.chunk","created":1732513105,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} - data: {"id":"chatcmpl-AU9ULww3MRjPquqNLMfEm0wRtvXCp","object":"chat.completion.chunk","created":1731749677,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":5,"total_tokens":58,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} + data: {"id":"chatcmpl-AXM5h5kzhX8FtXccw9UdtcvBLAS6Z","object":"chat.completion.chunk","created":1732513105,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":5,"total_tokens":58,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} data: [DONE] @@ -73,13 +75,13 @@ interactions: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8e367c7d7c6fcf9f-SJC + - 8e7f4adefa511742-SJC Connection: - keep-alive Content-Type: - 
text/event-stream; charset=utf-8 Date: - - Sat, 16 Nov 2024 09:34:37 GMT + - Mon, 25 Nov 2024 05:38:25 GMT Server: - cloudflare Transfer-Encoding: @@ -91,7 +93,7 @@ interactions: alt-svc: - h3=":443"; ma=86400 openai-processing-ms: - - '161' + - '142' openai-version: - '2020-10-01' strict-transport-security: @@ -103,25 +105,26 @@ interactions: x-ratelimit-remaining-requests: - '9999' x-ratelimit-remaining-tokens: - - '199978' + - '199977' x-ratelimit-reset-requests: - 8.64s x-ratelimit-reset-tokens: - 6ms x-request-id: - - req_5b1cdde6bb1c37d9615f2a7f6322264f + - req_39f5a155d1bb5b52ff4716ed53c116bb status: code: 200 message: OK - request: - body: '{"messages": [{"role": "user", "content": "Return a country."}, {"role": - "assistant", "content": null, "tool_calls": [{"id": "call_ugvIEdaAkA3EwKchrQL6sCtt", - "type": "function", "function": {"name": "return_country", "arguments": "{\"name\":\"Canada\"}"}}]}, - {"role": "tool", "tool_call_id": "call_ugvIEdaAkA3EwKchrQL6sCtt", "content": - "1 validation error for Country\nname\n Value error, Country must be Ireland. - [type=value_error, input_value=''Canada'', input_type=str]\n For further - information visit https://errors.pydantic.dev/2.9/v/value_error"}], "model": - "gpt-4o-mini", "parallel_tool_calls": false, "stream": true, "stream_options": + body: '{"messages": [{"role": "user", "content": "Return a country."}, {"content": + null, "refusal": null, "role": "assistant", "audio": null, "function_call": + null, "tool_calls": [{"id": "call_2AGPVxgioaHgJ0ysmjvWE4a9", "function": {"arguments": + "{\"name\":\"Canada\"}", "name": "return_country", "parsed_arguments": null}, + "type": "function", "index": 0}], "parsed": null}, {"role": "tool", "tool_call_id": + "call_2AGPVxgioaHgJ0ysmjvWE4a9", "content": "1 validation error for Country\nname\n Value + error, Country must be Ireland. 
[type=value_error, input_value=''Canada'', input_type=str]\n For + further information visit https://errors.pydantic.dev/2.9/v/value_error"}], + "model": "gpt-4o-mini", "parallel_tool_calls": false, "stream": true, "stream_options": {"include_usage": true}, "tool_choice": {"type": "function", "function": {"name": "return_country"}}, "tools": [{"type": "function", "function": {"name": "return_country", "parameters": {"properties": {"name": {"title": "Name", "type": "string"}}, @@ -134,12 +137,9 @@ interactions: connection: - keep-alive content-length: - - '936' + - '1045' content-type: - application/json - cookie: - - __cf_bm=IongOs.UaKxgqpxpk2cJio1t2ANoC2tmOxc3A6LF69Q-1731749677-1.0.1.1-l.m7XlTBhSnwT2YxmLrxCxFdERRQVe2bjNJTKpj.cnlBWlYXqfutVA7.lfFBsRZ5hIkqC92SBz2BmluvkOYDEw; - _cfuvid=xilQfC5pUNDO3wH0u9UbngFg3782S5FE7fn8T37Dhds-1731749677873-0.0.1.1-604800000 host: - api.openai.com user-agent: @@ -148,6 +148,8 @@ interactions: - arm64 x-stainless-async: - 'false' + x-stainless-helper-method: + - beta.chat.completions.stream x-stainless-lang: - python x-stainless-os: @@ -164,28 +166,28 @@ interactions: uri: https://api.openai.com/v1/chat/completions response: body: - string: 'data: {"id":"chatcmpl-AU9UMCyEsQwtrRJ1NQD4DQkDoIDTn","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_9b78b61c52","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_Eyrvzk11kYJM0px7uwZFVr7m","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} + string: 'data: {"id":"chatcmpl-AXM5iy0XofVHS1yfehrUoO6LZUuJ6","object":"chat.completion.chunk","created":1732513106,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_1BMcenaJshxLU4S6ndqKbYrL","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UMCyEsQwtrRJ1NQD4DQkDoIDTn","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_9b78b61c52","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5iy0XofVHS1yfehrUoO6LZUuJ6","object":"chat.completion.chunk","created":1732513106,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UMCyEsQwtrRJ1NQD4DQkDoIDTn","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_9b78b61c52","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5iy0XofVHS1yfehrUoO6LZUuJ6","object":"chat.completion.chunk","created":1732513106,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: 
{"id":"chatcmpl-AU9UMCyEsQwtrRJ1NQD4DQkDoIDTn","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_9b78b61c52","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5iy0XofVHS1yfehrUoO6LZUuJ6","object":"chat.completion.chunk","created":1732513106,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UMCyEsQwtrRJ1NQD4DQkDoIDTn","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_9b78b61c52","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Ireland"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5iy0XofVHS1yfehrUoO6LZUuJ6","object":"chat.completion.chunk","created":1732513106,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Ireland"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UMCyEsQwtrRJ1NQD4DQkDoIDTn","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_9b78b61c52","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AXM5iy0XofVHS1yfehrUoO6LZUuJ6","object":"chat.completion.chunk","created":1732513106,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AU9UMCyEsQwtrRJ1NQD4DQkDoIDTn","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_9b78b61c52","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} + data: {"id":"chatcmpl-AXM5iy0XofVHS1yfehrUoO6LZUuJ6","object":"chat.completion.chunk","created":1732513106,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} - data: {"id":"chatcmpl-AU9UMCyEsQwtrRJ1NQD4DQkDoIDTn","object":"chat.completion.chunk","created":1731749678,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_9b78b61c52","choices":[],"usage":{"prompt_tokens":125,"completion_tokens":5,"total_tokens":130,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} + data: {"id":"chatcmpl-AXM5iy0XofVHS1yfehrUoO6LZUuJ6","object":"chat.completion.chunk","created":1732513106,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[],"usage":{"prompt_tokens":125,"completion_tokens":5,"total_tokens":130,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} data: [DONE] @@ -196,13 +198,13 @@ interactions: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8e367c7f3e0dcf9f-SJC + - 8e7f4ae18d3c1742-SJC Connection: - keep-alive Content-Type: - 
text/event-stream; charset=utf-8 Date: - - Sat, 16 Nov 2024 09:34:38 GMT + - Mon, 25 Nov 2024 05:38:26 GMT Server: - cloudflare Transfer-Encoding: @@ -214,7 +216,7 @@ interactions: alt-svc: - h3=":443"; ma=86400 openai-processing-ms: - - '255' + - '220' openai-version: - '2020-10-01' strict-transport-security: @@ -226,13 +228,13 @@ interactions: x-ratelimit-remaining-requests: - '9998' x-ratelimit-remaining-tokens: - - '199922' + - '199923' x-ratelimit-reset-requests: - - 17.001s + - 16.863s x-ratelimit-reset-tokens: - 23ms x-request-id: - - req_2740ae04afa64140db7a7b0f3dd2516e + - req_d6456abe3299a66d7b7a4060c342d7d1 status: code: 200 message: OK From 7478c9ad3a7e57528536619ba03f364b5becf650 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Sun, 24 Nov 2024 23:19:01 -0800 Subject: [PATCH 19/40] Add back usage for OpenaiChatModel --- src/magentic/chat_model/openai_chat_model.py | 40 ++++++++------------ src/magentic/chat_model/stream.py | 31 ++++++--------- src/magentic/streaming.py | 2 + tests/chat_model/test_openai_chat_model.py | 4 -- 4 files changed, 30 insertions(+), 47 deletions(-) diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index 13baecaf..54503a7d 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -15,7 +15,6 @@ from openai.lib.streaming.chat import ( AsyncChatCompletionStreamManager, ChatCompletionStreamEvent, - ChunkEvent, ContentDeltaEvent, ContentDoneEvent, FunctionToolCallArgumentsDeltaEvent, @@ -279,25 +278,6 @@ def get_tool_name(self, item: ChatCompletionStreamEvent) -> str: return item.name -# TODO: Move usage tracking into OpenaiStreamState -class OpenaiUsageStreamParser(StreamParser[ChatCompletionStreamEvent, Usage]): - """Filters and transforms OpenAI usage events from a stream.""" - - def is_member(self, item: ChatCompletionStreamEvent) -> TypeGuard[ChunkEvent]: - return item.type == "chunk" and bool(item.chunk.usage) - - def is_end(self, item: ChatCompletionStreamEvent) -> Literal[True]: - return True # Single event so immediately end - - def transform(self, item: ChatCompletionStreamEvent) -> Usage: - assert self.is_member(item) # noqa: S101 - assert item.chunk.usage # noqa: S101 - return Usage( - input_tokens=item.chunk.usage.prompt_tokens, - output_tokens=item.chunk.usage.completion_tokens, - ) - - class OpenaiStreamState(StreamState): def __init__(self, function_schemas): self._function_schemas = function_schemas @@ -307,10 +287,12 @@ def __init__(self, function_schemas): response_format=openai.NOT_GIVEN, ) self._current_tool_call_id: str | None = None + self.usage_ref: list[Usage] = [] def update(self, item: ChatCompletionStreamEvent) -> None: if item.type == "chunk": self._chat_completion_stream_state.handle_chunk(item.chunk) + # TODO: Should loop through tool calls if ( item.type == "chunk" and item.chunk.choices @@ -321,6 +303,14 @@ def update(self, item: ChatCompletionStreamEvent) -> None: # openai keeps index consistent for chunks from the same tool_call, but id is null # mistral has null index, but keeps id consistent self._current_tool_call_id = item.chunk.choices[0].delta.tool_calls[0].id + if item.type == "chunk" and bool(item.chunk.usage): + assert not self.usage_ref # noqa: S101 + self.usage_ref.append( + Usage( + input_tokens=item.chunk.usage.prompt_tokens, + output_tokens=item.chunk.usage.completion_tokens, + ) + ) @property def current_tool_call_id(self) -> str | None: @@ -521,10 
+511,11 @@ def complete( # TODO: Consoldate these into a single parser / state object? content_parser=OpenaiContentStreamParser(), tool_parser=OpenaiToolStreamParser(), - usage_parser=OpenaiUsageStreamParser(), state=OpenaiStreamState(function_schemas=function_schemas), ) - return AssistantMessage(parse_stream(stream, output_types)) # type: ignore[return-type] + return AssistantMessage._with_usage( + parse_stream(stream, output_types), usage_ref=stream.usage_ref + ) @overload async def acomplete( @@ -603,7 +594,8 @@ async def acomplete( function_schemas=function_schemas, content_parser=OpenaiContentStreamParser(), tool_parser=OpenaiToolStreamParser(), - usage_parser=OpenaiUsageStreamParser(), state=OpenaiStreamState(function_schemas=function_schemas), ) - return AssistantMessage(await aparse_stream(stream, output_types)) # type: ignore[return-type] + return AssistantMessage._with_usage( + await aparse_stream(stream, output_types), usage_ref=stream.usage_ref + ) diff --git a/src/magentic/chat_model/stream.py b/src/magentic/chat_model/stream.py index ca07ed1b..479f69a8 100644 --- a/src/magentic/chat_model/stream.py +++ b/src/magentic/chat_model/stream.py @@ -1,14 +1,14 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Iterable, Iterator from itertools import chain -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import Generic, TypeVar from litellm.llms.files_apis.azure import Any from pydantic import ValidationError from magentic.chat_model.base import ToolSchemaParseError from magentic.chat_model.function_schema import FunctionSchema, select_function_schema -from magentic.chat_model.message import Message +from magentic.chat_model.message import Message, Usage from magentic.streaming import ( AsyncStreamedStr, StreamedStr, @@ -18,10 +18,6 @@ async_iter, ) -if TYPE_CHECKING: - from magentic.chat_model.message import Usage - - ItemT = TypeVar("ItemT") OutputT = TypeVar("OutputT") @@ -62,6 +58,8 @@ async def aiter( class StreamState(ABC, Generic[ItemT]): + usage_ref: list[Usage] + @abstractmethod def update(self, item: ItemT) -> None: ... 
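
The `usage_ref: list[Usage]` added to `StreamState` above is a shared mutable reference: `OpenaiStreamState.update` appends a `Usage` once the final chunk carrying token counts arrives, and `AssistantMessage._with_usage(..., usage_ref=stream.usage_ref)` holds the same list, so the value becomes readable after the response stream has been fully consumed. A rough, self-contained sketch of that pattern follows; it is not the magentic implementation, and the `DemoStreamState` / `demo_stream` names and the dict-shaped chunks are purely illustrative.

from dataclasses import dataclass, field
from typing import Iterator


@dataclass
class Usage:
    input_tokens: int
    output_tokens: int


@dataclass
class DemoStreamState:
    # Stands in for StreamState: appends a Usage when the final chunk arrives.
    usage_ref: list[Usage] = field(default_factory=list)

    def update(self, chunk: dict) -> None:
        if usage := chunk.get("usage"):
            # Only the last chunk carries usage, so the list holds at most one item.
            assert not self.usage_ref
            self.usage_ref.append(
                Usage(usage["prompt_tokens"], usage["completion_tokens"])
            )


def demo_stream(chunks: list[dict], state: DemoStreamState) -> Iterator[str]:
    # Equivalent of `apply(state.update, stream)`: update state as chunks pass through.
    for chunk in chunks:
        state.update(chunk)
        yield chunk.get("content", "")


state = DemoStreamState()
stream = demo_stream(
    [{"content": "Hello"}, {"usage": {"prompt_tokens": 5, "completion_tokens": 1}}],
    state,
)
usage_ref = state.usage_ref  # empty now, populated once the stream is consumed
assert not usage_ref
print("".join(stream))  # consumes the stream -> "Hello"
print(usage_ref[0])     # Usage(input_tokens=5, output_tokens=1)

Holding the list (rather than copying a value) is what lets the message object be constructed before the stream is read, which matches how the tests below access `message.usage` only after iterating the content.
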
@@ -83,7 +81,6 @@ def __init__( function_schemas: Iterable[FunctionSchema[OutputT]], content_parser: StreamParser, tool_parser: StreamParser, - usage_parser: StreamParser, state: StreamState[ItemT], ): self._stream = stream @@ -92,13 +89,10 @@ def __init__( self._content_parser = content_parser self._tool_parser = tool_parser - self._usage_parser = usage_parser self._state = state self._wrapped_stream = apply(self._state.update, stream) - self.usage: Usage | None = None - def __next__(self) -> StreamedStr | OutputT: return self._iterator.__next__() @@ -132,12 +126,13 @@ def __stream__(self) -> Iterator[StreamedStr | OutputT]: tool_call_id=self._state.current_tool_call_id, validation_error=e, ) from e - # TODO: Move usage tracking into StreamState - elif self._usage_parser.is_member(transition_item): - self.usage = self._usage_parser.transform(transition_item) elif new_transition_item := next(self._wrapped_stream, None): transition.append(new_transition_item) + @property + def usage_ref(self) -> list[Usage]: + return self._state.usage_ref + def close(self): self._stream.close() @@ -151,7 +146,6 @@ def __init__( function_schemas: Iterable[FunctionSchema[OutputT]], content_parser: StreamParser, tool_parser: StreamParser, - usage_parser: StreamParser, state: StreamState[ItemT], ): self._stream = stream @@ -160,13 +154,10 @@ def __init__( self._content_parser = content_parser self._tool_parser = tool_parser - self._usage_parser = usage_parser self._state = state self._wrapped_stream = aapply(self._state.update, stream) - self.usage: Usage | None = None - async def __anext__(self) -> AsyncStreamedStr | OutputT: return await self._iterator.__anext__() @@ -203,10 +194,12 @@ async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]: tool_call_id=self._state.current_tool_call_id, validation_error=e, ) from e - elif self._usage_parser.is_member(transition_item): - self.usage = self._usage_parser.transform(transition_item) elif new_transition_item := await anext(self._wrapped_stream, None): transition.append(new_transition_item) + @property + def usage_ref(self) -> list[Usage]: + return self._state.usage_ref + async def close(self): await self._stream.close() diff --git a/src/magentic/streaming.py b/src/magentic/streaming.py index c3154167..e87af083 100644 --- a/src/magentic/streaming.py +++ b/src/magentic/streaming.py @@ -225,6 +225,8 @@ async def __aiter__(self) -> AsyncIterator[T]: yield item +# TODO: Add close method to close the underlying stream if chunks is a stream +# TODO: Make it a context manager to automatically close class StreamedStr(Iterable[str]): """A string that is generated in chunks.""" diff --git a/tests/chat_model/test_openai_chat_model.py b/tests/chat_model/test_openai_chat_model.py index 51e33ce7..26f06efe 100644 --- a/tests/chat_model/test_openai_chat_model.py +++ b/tests/chat_model/test_openai_chat_model.py @@ -186,7 +186,6 @@ def plus(a: int, b: int) -> int: assert isinstance(message.content, FunctionCall) -@pytest.mark.skip("TODO: implement usage") @pytest.mark.openai def test_openai_chat_model_complete_usage(): chat_model = OpenaiChatModel("gpt-4o") @@ -199,7 +198,6 @@ def test_openai_chat_model_complete_usage(): assert message.usage.output_tokens > 0 -@pytest.mark.skip("TODO: implement usage") @pytest.mark.openai def test_openai_chat_model_complete_usage_structured_output(): chat_model = OpenaiChatModel("gpt-4o") @@ -240,7 +238,6 @@ class Test(BaseModel): ) -@pytest.mark.skip("TODO: implement usage") @pytest.mark.openai async def 
test_openai_chat_model_acomplete_usage(): chat_model = OpenaiChatModel("gpt-4o") @@ -253,7 +250,6 @@ async def test_openai_chat_model_acomplete_usage(): assert message.usage.output_tokens > 0 -@pytest.mark.skip("TODO: implement usage") @pytest.mark.openai async def test_openai_chat_model_acomplete_usage_structured_output(): chat_model = OpenaiChatModel("gpt-4o") From 5ca69cba760c3cc81338bfe558b7445eed67a838 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Tue, 26 Nov 2024 00:42:33 -0800 Subject: [PATCH 20/40] Consolodate parsers into one --- src/magentic/chat_model/openai_chat_model.py | 181 ++++++++-------- src/magentic/chat_model/stream.py | 209 ++++++++++++------- 2 files changed, 213 insertions(+), 177 deletions(-) diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index 54503a7d..5ba45fb2 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -1,25 +1,17 @@ import base64 from collections.abc import ( AsyncIterator, - Awaitable, Callable, Iterable, + Iterator, Sequence, ) from enum import Enum from functools import singledispatch -from typing import Any, Generic, Literal, TypeGuard, TypeVar, cast, overload +from typing import Any, Generic, Literal, TypeVar, cast, overload import filetype import openai -from openai.lib.streaming.chat import ( - AsyncChatCompletionStreamManager, - ChatCompletionStreamEvent, - ContentDeltaEvent, - ContentDoneEvent, - FunctionToolCallArgumentsDeltaEvent, - FunctionToolCallArgumentsDoneEvent, -) from openai.lib.streaming.chat._completions import ChatCompletionStreamState from openai.types.chat import ( ChatCompletionChunk, @@ -240,45 +232,68 @@ def to_dict(self) -> ChatCompletionToolParam: return {"type": "function", "function": self._function_schema.dict()} -class OpenaiContentStreamParser(StreamParser[ChatCompletionStreamEvent, str]): - """Filters and transforms OpenAI content events from a stream.""" +class OpenaiStreamParser(StreamParser[ChatCompletionChunk]): + def is_content(self, item: ChatCompletionChunk) -> bool: + return bool(item.choices and item.choices[0].delta.content) - def is_member( - self, item: ChatCompletionStreamEvent - ) -> TypeGuard[ContentDeltaEvent]: - return item.type == "content.delta" + def is_content_ended(self, item: ChatCompletionChunk) -> bool: + return self.is_tool_call(item) - def is_end(self, item: ChatCompletionStreamEvent) -> TypeGuard[ContentDoneEvent]: - return item.type == "content.done" + def get_content(self, item: ChatCompletionChunk) -> str: + if item.choices and item.choices[0].delta.content: + return item.choices[0].delta.content + return "" - def transform(self, item: ChatCompletionStreamEvent) -> str: - assert self.is_member(item) # noqa: S101 - return item.delta + def is_tool_call(self, item: ChatCompletionChunk) -> bool: + return bool(item.choices and item.choices[0].delta.tool_calls) + def get_tool_call_index(self, item: ChatCompletionChunk) -> int | None: + if ( + item.choices + and item.choices[0].delta.tool_calls + and item.choices[0].delta.tool_calls[0].index is not None + ): + return item.choices[0].delta.tool_calls[0].index + return None -class OpenaiToolStreamParser(StreamParser[ChatCompletionStreamEvent, str]): - """Filters and transforms OpenAI tool events from a stream.""" + def get_tool_call_id(self, item: ChatCompletionChunk) -> str | None: + if ( + item.choices + and item.choices[0].delta.tool_calls + and 
item.choices[0].delta.tool_calls[0].id + ): + return item.choices[0].delta.tool_calls[0].id + return None - def is_member( - self, item: ChatCompletionStreamEvent - ) -> TypeGuard[FunctionToolCallArgumentsDeltaEvent]: - return item.type == "tool_calls.function.arguments.delta" + def get_tool_name(self, item: ChatCompletionChunk) -> str | None: + if ( + item.choices + and item.choices[0].delta.tool_calls + and item.choices[0].delta.tool_calls[0].function + and item.choices[0].delta.tool_calls[0].function.name + ): + return item.choices[0].delta.tool_calls[0].function.name + return None - def is_end( - self, item: ChatCompletionStreamEvent - ) -> TypeGuard[FunctionToolCallArgumentsDoneEvent]: - return item.type == "tool_calls.function.arguments.done" + def get_tool_call_args(self, item: ChatCompletionChunk) -> str: + if ( + item.choices + and item.choices[0].delta.tool_calls + and item.choices[0].delta.tool_calls[0].function + and item.choices[0].delta.tool_calls[0].function.arguments + ): + return item.choices[0].delta.tool_calls[0].function.arguments + return "" - def transform(self, item: ChatCompletionStreamEvent) -> str: - assert self.is_member(item) # noqa: S101 - return item.arguments_delta - def get_tool_name(self, item: ChatCompletionStreamEvent) -> str: - assert self.is_member(item) # noqa: S101 - return item.name +class OpenaiStreamState(StreamState[ChatCompletionChunk]): + """Tracks the state of the OpenAI model output stream. + - message snapshot + - usage + - stop reason + """ -class OpenaiStreamState(StreamState): def __init__(self, function_schemas): self._function_schemas = function_schemas @@ -286,36 +301,19 @@ def __init__(self, function_schemas): input_tools=openai.NOT_GIVEN, response_format=openai.NOT_GIVEN, ) - self._current_tool_call_id: str | None = None self.usage_ref: list[Usage] = [] - def update(self, item: ChatCompletionStreamEvent) -> None: - if item.type == "chunk": - self._chat_completion_stream_state.handle_chunk(item.chunk) - # TODO: Should loop through tool calls - if ( - item.type == "chunk" - and item.chunk.choices - and item.chunk.choices[0].delta.tool_calls - and item.chunk.choices[0].delta.tool_calls[0].id - ): - # TODO: Mistral fix here ? 
- # openai keeps index consistent for chunks from the same tool_call, but id is null - # mistral has null index, but keeps id consistent - self._current_tool_call_id = item.chunk.choices[0].delta.tool_calls[0].id - if item.type == "chunk" and bool(item.chunk.usage): + def update(self, item: ChatCompletionChunk) -> None: + self._chat_completion_stream_state.handle_chunk(item) + if item.usage: assert not self.usage_ref # noqa: S101 self.usage_ref.append( Usage( - input_tokens=item.chunk.usage.prompt_tokens, - output_tokens=item.chunk.usage.completion_tokens, + input_tokens=item.usage.prompt_tokens, + output_tokens=item.usage.completion_tokens, ) ) - @property - def current_tool_call_id(self) -> str | None: - return self._current_tool_call_id - @property def current_message_snapshot(self) -> Message: message = ( @@ -486,8 +484,7 @@ def complete( streamed_str_in_output_types = is_any_origin_subclass(output_types, StreamedStr) allow_string_output = str_in_output_types or streamed_str_in_output_types - # TODO: Switch to the create method to avoid possible validation addition - _stream = self._client.beta.chat.completions.stream( + response: Iterator[ChatCompletionChunk] = self._client.chat.completions.create( model=self.model, messages=_add_missing_tool_calls_responses( [message_to_openai_message(m) for m in messages] @@ -495,6 +492,7 @@ def complete( max_tokens=_if_given(self.max_tokens), seed=_if_given(self.seed), stop=_if_given(stop), + stream=True, stream_options=self._get_stream_options(), temperature=_if_given(self.temperature), tools=[schema.to_dict() for schema in tool_schemas] or openai.NOT_GIVEN, @@ -504,13 +502,11 @@ def complete( parallel_tool_calls=self._get_parallel_tool_calls( tools_specified=bool(tool_schemas), output_types=output_types ), - ).__enter__() # Get stream directly, without context manager + ) stream = OutputStream( - _stream, + response, function_schemas=function_schemas, - # TODO: Consoldate these into a single parser / state object? 
- content_parser=OpenaiContentStreamParser(), - tool_parser=OpenaiToolStreamParser(), + parser=OpenaiStreamParser(), state=OpenaiStreamState(function_schemas=function_schemas), ) return AssistantMessage._with_usage( @@ -562,38 +558,31 @@ async def acomplete( ) allow_string_output = str_in_output_types or async_streamed_str_in_output_types - response: Awaitable[AsyncIterator[ChatCompletionChunk]] = ( - self._async_client.chat.completions.create( - model=self.model, - messages=_add_missing_tool_calls_responses( - [message_to_openai_message(m) for m in messages] - ), - max_tokens=_if_given(self.max_tokens), - seed=_if_given(self.seed), - stop=_if_given(stop), - stream=True, - stream_options=self._get_stream_options(), - temperature=_if_given(self.temperature), - tools=[schema.to_dict() for schema in tool_schemas] or openai.NOT_GIVEN, - tool_choice=self._get_tool_choice( - tool_schemas=tool_schemas, allow_string_output=allow_string_output - ), - parallel_tool_calls=self._get_parallel_tool_calls( - tools_specified=bool(tool_schemas), output_types=output_types - ), - ) + response: AsyncIterator[ + ChatCompletionChunk + ] = await self._async_client.chat.completions.create( + model=self.model, + messages=_add_missing_tool_calls_responses( + [message_to_openai_message(m) for m in messages] + ), + max_tokens=_if_given(self.max_tokens), + seed=_if_given(self.seed), + stop=_if_given(stop), + stream=True, + stream_options=self._get_stream_options(), + temperature=_if_given(self.temperature), + tools=[schema.to_dict() for schema in tool_schemas] or openai.NOT_GIVEN, + tool_choice=self._get_tool_choice( + tool_schemas=tool_schemas, allow_string_output=allow_string_output + ), + parallel_tool_calls=self._get_parallel_tool_calls( + tools_specified=bool(tool_schemas), output_types=output_types + ), ) - _stream = await AsyncChatCompletionStreamManager( - response, - response_format=openai.NOT_GIVEN, - input_tools=[schema.to_dict() for schema in tool_schemas] - or openai.NOT_GIVEN, - ).__aenter__() # Get stream directly, without context manager stream = AsyncOutputStream( - _stream, + response, function_schemas=function_schemas, - content_parser=OpenaiContentStreamParser(), - tool_parser=OpenaiToolStreamParser(), + parser=OpenaiStreamParser(), state=OpenaiStreamState(function_schemas=function_schemas), ) return AssistantMessage._with_usage( diff --git a/src/magentic/chat_model/stream.py b/src/magentic/chat_model/stream.py index 479f69a8..5d213ece 100644 --- a/src/magentic/chat_model/stream.py +++ b/src/magentic/chat_model/stream.py @@ -22,51 +22,45 @@ OutputT = TypeVar("OutputT") -class StreamParser(ABC, Generic[ItemT, OutputT]): - """Filters and transforms items from an iterator until the end condition is met.""" +class StreamParser(ABC, Generic[ItemT]): + @abstractmethod + def is_content(self, item: ItemT) -> bool: ... - def is_member(self, item: ItemT) -> bool: - return True + @abstractmethod + def is_content_ended(self, item: ItemT) -> bool: ... @abstractmethod - def is_end(self, item: ItemT) -> bool: ... + def get_content(self, item: ItemT) -> str: ... @abstractmethod - def transform(self, item: ItemT) -> OutputT: ... - - def iter( - self, iterator: Iterator[ItemT], transition: list[ItemT] - ) -> Iterator[OutputT]: - for item in iterator: - if self.is_member(item): - yield self.transform(item) - if self.is_end(item): - assert not transition # noqa: S101 - transition.append(item) - return + def is_tool_call(self, item: ItemT) -> bool: ... 
- async def aiter( - self, aiterator: AsyncIterator[ItemT], transition: list[ItemT] - ) -> AsyncIterator[OutputT]: - async for item in aiterator: - if self.is_member(item): - yield self.transform(item) - if self.is_end(item): - assert not transition # noqa: S101 - transition.append(item) - return + @abstractmethod + def get_tool_call_index(self, item: ItemT) -> int | None: ... + + @abstractmethod + def get_tool_call_id(self, item: ItemT) -> str | None: ... + + @abstractmethod + def get_tool_name(self, item: ItemT) -> str | None: ... + + @abstractmethod + def get_tool_call_args(self, item: ItemT) -> str: ... class StreamState(ABC, Generic[ItemT]): + """Tracks the state of the LLM output stream. + + - message snapshot + - usage + - stop reason + """ + usage_ref: list[Usage] @abstractmethod def update(self, item: ItemT) -> None: ... - @property - @abstractmethod - def current_tool_call_id(self) -> str | None: ... - @property @abstractmethod def current_message_snapshot(self) -> Message[Any]: ... @@ -79,19 +73,15 @@ def __init__( self, stream: Iterator[ItemT], function_schemas: Iterable[FunctionSchema[OutputT]], - content_parser: StreamParser, - tool_parser: StreamParser, + parser: StreamParser[ItemT], state: StreamState[ItemT], ): self._stream = stream self._function_schemas = function_schemas - self._iterator = self.__stream__() - - self._content_parser = content_parser - self._tool_parser = tool_parser + self._parser = parser self._state = state - self._wrapped_stream = apply(self._state.update, stream) + self._iterator = self.__stream__() def __next__(self) -> StreamedStr | OutputT: return self._iterator.__next__() @@ -99,35 +89,66 @@ def __next__(self) -> StreamedStr | OutputT: def __iter__(self) -> Iterator[StreamedStr | OutputT]: yield from self._iterator + def _streamed_str( + self, stream: Iterator[ItemT], current_item_ref: list[ItemT] + ) -> Iterator[str]: + for item in stream: + if self._parser.is_content_ended(item): + # TODO: Check if output types allow for early return + assert not current_item_ref # noqa: S101 + current_item_ref.append(item) + return + yield self._parser.get_content(item) + + def _tool_call( + self, + stream: Iterator[ItemT], + current_item_ref: list[ItemT], + tool_call_index: int, + ) -> Iterator[str]: + for item in stream: + item_tool_call_index = self._parser.get_tool_call_index(item) + if item_tool_call_index and item_tool_call_index != tool_call_index: + # TODO: Check if output types allow for early return + assert not current_item_ref # noqa: S101 + current_item_ref.append(item) + return + yield self._parser.get_tool_call_args(item) + def __stream__(self) -> Iterator[StreamedStr | OutputT]: - transition = [next(self._wrapped_stream)] - while transition: - transition_item = transition.pop() - stream_with_transition = chain([transition_item], self._wrapped_stream) - if self._content_parser.is_member(transition_item): - yield StreamedStr( - self._content_parser.iter(stream_with_transition, transition) - ) - elif self._tool_parser.is_member(transition_item): - # TODO: Add new base class for tool parser - tool_name = self._tool_parser.get_tool_name(transition_item) + stream = apply(self._state.update, self._stream) + current_item_ref = [next(stream)] + while current_item_ref: + current_item = current_item_ref.pop() + if self._parser.is_content(current_item): + stream = chain([current_item], stream) + yield StreamedStr(self._streamed_str(stream, current_item_ref)) + # TODO: Make is_tool_calls to handle multiple tools + elif 
self._parser.is_tool_call(current_item): + # TODO: Iterate until ID is found ? + current_tool_call_id = self._parser.get_tool_call_id(current_item) + current_tool_name = self._parser.get_tool_name(current_item) function_schema = select_function_schema( - self._function_schemas, tool_name + self._function_schemas, current_tool_name ) + current_tool_call_index = self._parser.get_tool_call_index(current_item) + stream = chain([current_item], stream) try: yield function_schema.parse_args( - self._tool_parser.iter(stream_with_transition, transition) + self._tool_call( + stream, current_item_ref, current_tool_call_index + ) ) # TODO: Catch/raise unknown tool call error here except ValidationError as e: - assert self._state.current_tool_call_id is not None # noqa: S101 + assert current_tool_call_id is not None # noqa: S101 raise ToolSchemaParseError( output_message=self._state.current_message_snapshot, - tool_call_id=self._state.current_tool_call_id, + tool_call_id=current_tool_call_id, validation_error=e, ) from e - elif new_transition_item := next(self._wrapped_stream, None): - transition.append(new_transition_item) + elif new_current_item := next(stream, None): + current_item_ref.append(new_current_item) @property def usage_ref(self) -> list[Usage]: @@ -144,19 +165,15 @@ def __init__( self, stream: AsyncIterator[ItemT], function_schemas: Iterable[FunctionSchema[OutputT]], - content_parser: StreamParser, - tool_parser: StreamParser, + parser: StreamParser[ItemT], state: StreamState[ItemT], ): self._stream = stream self._function_schemas = function_schemas - self._iterator = self.__stream__() - - self._content_parser = content_parser - self._tool_parser = tool_parser + self._parser = parser self._state = state - self._wrapped_stream = aapply(self._state.update, stream) + self._iterator = self.__stream__() async def __anext__(self) -> AsyncStreamedStr | OutputT: return await self._iterator.__anext__() @@ -165,37 +182,67 @@ async def __aiter__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]: async for item in self._iterator: yield item + async def _streamed_str( + self, stream: AsyncIterator[ItemT], current_item_ref: list[ItemT] + ) -> AsyncIterator[str]: + async for item in stream: + if self._parser.is_content_ended(item): + # TODO: Check if output types allow for early return + assert not current_item_ref # noqa: S101 + current_item_ref.append(item) + return + yield self._parser.get_content(item) + + async def _tool_call( + self, + stream: AsyncIterator[ItemT], + current_item_ref: list[ItemT], + tool_call_index: int, + ) -> AsyncIterator[str]: + async for item in stream: + item_tool_call_index = self._parser.get_tool_call_index(item) + if item_tool_call_index and item_tool_call_index != tool_call_index: + # TODO: Check if output types allow for early return + assert not current_item_ref # noqa: S101 + current_item_ref.append(item) + return + yield self._parser.get_tool_call_args(item) + async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]: - transition = [await anext(self._wrapped_stream)] - while transition: - transition_item = transition.pop() - stream_with_transition = achain( - async_iter([transition_item]), self._wrapped_stream - ) - if self._content_parser.is_member(transition_item): - yield AsyncStreamedStr( - self._content_parser.aiter(stream_with_transition, transition) - ) - elif self._tool_parser.is_member(transition_item): - # TODO: Add new base class for tool parser - tool_name = self._tool_parser.get_tool_name(transition_item) + stream = 
aapply(self._state.update, self._stream) + current_item_ref = [await anext(stream)] + while current_item_ref: + current_item = current_item_ref.pop() + if self._parser.is_content(current_item): + stream = achain(async_iter([current_item]), stream) + yield AsyncStreamedStr(self._streamed_str(stream, current_item_ref)) + # TODO: Make is_tool_calls to handle multiple tools + elif self._parser.is_tool_call(current_item): + # TODO: Iterate until ID is found ? + current_tool_call_id = self._parser.get_tool_call_id(current_item) + current_tool_name = self._parser.get_tool_name(current_item) function_schema = select_function_schema( - self._function_schemas, tool_name + self._function_schemas, current_tool_name ) + current_tool_call_index = self._parser.get_tool_call_index(current_item) + stream = achain(async_iter([current_item]), stream) try: yield await function_schema.aparse_args( - self._tool_parser.aiter(stream_with_transition, transition) + self._tool_call( + stream, current_item_ref, current_tool_call_index + ) ) # TODO: Catch/raise unknown tool call error here except ValidationError as e: - assert self._state.current_tool_call_id is not None # noqa: S101 + assert current_tool_call_id is not None # noqa: S101 raise ToolSchemaParseError( output_message=self._state.current_message_snapshot, - tool_call_id=self._state.current_tool_call_id, + # TODO: Take last tool call id from the message snapshot instead? + tool_call_id=current_tool_call_id, validation_error=e, ) from e - elif new_transition_item := await anext(self._wrapped_stream, None): - transition.append(new_transition_item) + elif new_current_item := await anext(stream, None): + current_item_ref.append(new_current_item) @property def usage_ref(self) -> list[Usage]: From 4b98d255913aea559077dfea261cc5bec7d4f366 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Tue, 26 Nov 2024 20:15:34 -0800 Subject: [PATCH 21/40] Make LitellmChatModel use new parsing format --- src/magentic/chat_model/litellm_chat_model.py | 121 +++++++++++------- src/magentic/chat_model/openai_chat_model.py | 19 ++- src/magentic/chat_model/stream.py | 4 +- ...y_chat_model_acomplete_litellm_openai.yaml | 75 +++++------ ...ry_chat_model_complete_litellm_openai.yaml | 75 +++++------ 5 files changed, 163 insertions(+), 131 deletions(-) diff --git a/src/magentic/chat_model/litellm_chat_model.py b/src/magentic/chat_model/litellm_chat_model.py index 137ec3ae..4ab980e5 100644 --- a/src/magentic/chat_model/litellm_chat_model.py +++ b/src/magentic/chat_model/litellm_chat_model.py @@ -1,7 +1,10 @@ from collections.abc import Callable, Iterable, Sequence from typing import Any, Literal, TypeVar, cast, overload +import litellm +import openai from litellm.litellm_core_utils.streaming_handler import StreamingChoices +from openai.lib.streaming.chat._completions import ChatCompletionStreamState from magentic.chat_model.base import ( ChatModel, @@ -17,13 +20,19 @@ AssistantMessage, Message, Usage, + _RawMessage, ) from magentic.chat_model.openai_chat_model import ( STR_OR_FUNCTIONCALL_TYPE, BaseFunctionToolSchema, message_to_openai_message, ) -from magentic.chat_model.stream import AsyncOutputStream, OutputStream, StreamParser +from magentic.chat_model.stream import ( + AsyncOutputStream, + OutputStream, + StreamParser, + StreamState, +) from magentic.streaming import ( AsyncStreamedStr, StreamedStr, @@ -38,60 +47,84 @@ raise ImportError(msg) from error -class LitellmContentStreamParser(StreamParser[ModelResponse, str]): - """Filters and 
transforms LiteLLM content chunks from a stream.""" - - def is_member(self, item: ModelResponse) -> bool: +class LitellmStreamParser(StreamParser[ModelResponse]): + def is_content(self, item: ModelResponse) -> bool: assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 return bool(item.choices[0].delta.content) - def is_end(self, item: ModelResponse) -> bool: - assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 - return bool(item.choices[0].delta.content is None) + def is_content_ended(self, item: ModelResponse) -> bool: + return self.is_tool_call(item) - def transform(self, item: ModelResponse) -> str: - assert self.is_member(item) # noqa: S101 + def get_content(self, item: ModelResponse) -> str: assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 return item.choices[0].delta.content or "" - -class LitellmToolStreamParser(StreamParser[ModelResponse, str]): - """Filters and transforms LiteLLM tool chunks from a stream.""" - - def is_member(self, item: ModelResponse) -> bool: + def is_tool_call(self, item: ModelResponse) -> bool: assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 - return bool(item.choices[0].delta.tool_calls is not None) + return bool(item.choices[0].delta.tool_calls) - def is_end(self, item: ModelResponse) -> bool: + def get_tool_call_index(self, item: ModelResponse) -> int | None: assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 - return item.choices[0].delta.tool_calls is None + if item.choices and item.choices[0].delta.tool_calls: + return item.choices[0].delta.tool_calls[0].index + return None - def transform(self, item: ModelResponse) -> str: - assert self.is_member(item) # noqa: S101 + def get_tool_call_id(self, item: ModelResponse) -> str | None: assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 - assert item.choices[0].delta.tool_calls is not None # noqa: S101 - return item.choices[0].delta.tool_calls[0].function.arguments + if item.choices and item.choices[0].delta.tool_calls: + return item.choices[0].delta.tool_calls[0].id + return None - def get_tool_name(self, item: ModelResponse) -> str: - assert self.is_member(item) # noqa: S101 + def get_tool_name(self, item: ModelResponse) -> str | None: assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 - assert item.choices[0].delta.tool_calls is not None # noqa: S101 - assert item.choices[0].delta.tool_calls[0].function.name # noqa: S101 - return item.choices[0].delta.tool_calls[0].function.name - - -# TODO: Implement LitellmToolStreamParser -class LitellmUsageStreamParser(StreamParser[ModelResponse, Usage]): - """Filters and transforms LiteLLM tool chunks from a stream.""" + if ( + item.choices + and item.choices[0].delta.tool_calls + and item.choices[0].delta.tool_calls[0].function.name + ): + return item.choices[0].delta.tool_calls[0].function.name + return None + + def get_tool_call_args(self, item: ModelResponse) -> str: + assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 + if item.choices and item.choices[0].delta.tool_calls: + return item.choices[0].delta.tool_calls[0].function.arguments + return "" - def is_member(self, item: ModelResponse) -> bool: - return False - def is_end(self, item: ModelResponse) -> bool: - return True +class LitellmStreamState(StreamState[ModelResponse]): + def __init__(self): + self._chat_completion_stream_state = ChatCompletionStreamState( + input_tools=openai.NOT_GIVEN, + response_format=openai.NOT_GIVEN, + ) + self.usage_ref: list[Usage] = [] + + def update(self, 
item: ModelResponse) -> None: + # usage attribute is required inside ChatCompletionStreamState.handle_chunk + # and litellm requires that this is not None for its total usage calculation + if not hasattr(item, "usage"): + item.usage = litellm.Usage() # type: ignore[attr-defined] + self._chat_completion_stream_state.handle_chunk(item) # type: ignore[arg-type] + usage = cast(litellm.Usage, item.usage) # type: ignore[attr-defined] + # Ignore usages with 0 tokens + if usage and usage.prompt_tokens and usage.completion_tokens: + # assert not self.usage_ref + self.usage_ref.append( + Usage( + input_tokens=usage.prompt_tokens, + output_tokens=usage.completion_tokens, + ) + ) - def transform(self, item: ModelResponse) -> Usage: - return Usage(input_tokens=0, output_tokens=0) + @property + def current_message_snapshot(self) -> Message: + snapshot = self._chat_completion_stream_state.current_completion_snapshot + message = snapshot.choices[0].message + # Fix incorrectly concatenated role + message.role = "assistant" + # TODO: Possible to return AssistantMessage here? + return _RawMessage(message.model_dump()) R = TypeVar("R") @@ -206,6 +239,7 @@ def complete( metadata=self.metadata, stop=stop, stream=True, + # TODO: Add usage for LitellmChatModel temperature=self.temperature, tools=[schema.to_dict() for schema in tool_schemas] or None, tool_choice=self._get_tool_choice( @@ -216,9 +250,8 @@ def complete( stream = OutputStream( stream=response, function_schemas=function_schemas, - content_parser=LitellmContentStreamParser(), - tool_parser=LitellmToolStreamParser(), - usage_parser=LitellmUsageStreamParser(), + parser=LitellmStreamParser(), + state=LitellmStreamState(), ) return AssistantMessage(parse_stream(stream, output_types)) @@ -276,6 +309,7 @@ async def acomplete( metadata=self.metadata, stop=stop, stream=True, + # TODO: Add usage for LitellmChatModel temperature=self.temperature, tools=[schema.to_dict() for schema in tool_schemas] or None, tool_choice=self._get_tool_choice( @@ -286,8 +320,7 @@ async def acomplete( stream = AsyncOutputStream( stream=response, function_schemas=function_schemas, - content_parser=LitellmContentStreamParser(), - tool_parser=LitellmToolStreamParser(), - usage_parser=LitellmUsageStreamParser(), + parser=LitellmStreamParser(), + state=LitellmStreamState(), ) return AssistantMessage(await aparse_stream(stream, output_types)) diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index 5ba45fb2..b38cf7c1 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -78,7 +78,9 @@ def message_to_openai_message(message: Message[Any]) -> ChatCompletionMessagePar @message_to_openai_message.register(_RawMessage) def _(message: _RawMessage[Any]) -> ChatCompletionMessageParam: - # TODO: Validate the message content + assert isinstance(message.content, dict) # noqa: S101 + assert "role" in message.content # noqa: S101 + assert "content" in message.content # noqa: S101 return message.content # type: ignore[no-any-return] @@ -294,9 +296,7 @@ class OpenaiStreamState(StreamState[ChatCompletionChunk]): - stop reason """ - def __init__(self, function_schemas): - self._function_schemas = function_schemas - + def __init__(self): self._chat_completion_stream_state = ChatCompletionStreamState( input_tools=openai.NOT_GIVEN, response_format=openai.NOT_GIVEN, @@ -316,11 +316,8 @@ def update(self, item: ChatCompletionChunk) -> None: @property def current_message_snapshot(self) -> Message: - 
message = ( - self._chat_completion_stream_state.current_completion_snapshot.choices[ - 0 - ].message - ) + snapshot = self._chat_completion_stream_state.current_completion_snapshot + message = snapshot.choices[0].message # TODO: Possible to return AssistantMessage here? return _RawMessage(message.model_dump()) @@ -507,7 +504,7 @@ def complete( response, function_schemas=function_schemas, parser=OpenaiStreamParser(), - state=OpenaiStreamState(function_schemas=function_schemas), + state=OpenaiStreamState(), ) return AssistantMessage._with_usage( parse_stream(stream, output_types), usage_ref=stream.usage_ref @@ -583,7 +580,7 @@ async def acomplete( response, function_schemas=function_schemas, parser=OpenaiStreamParser(), - state=OpenaiStreamState(function_schemas=function_schemas), + state=OpenaiStreamState(), ) return AssistantMessage._with_usage( await aparse_stream(stream, output_types), usage_ref=stream.usage_ref diff --git a/src/magentic/chat_model/stream.py b/src/magentic/chat_model/stream.py index 5d213ece..3bd15f8b 100644 --- a/src/magentic/chat_model/stream.py +++ b/src/magentic/chat_model/stream.py @@ -94,7 +94,7 @@ def _streamed_str( ) -> Iterator[str]: for item in stream: if self._parser.is_content_ended(item): - # TODO: Check if output types allow for early return + # TODO: Check if output types allow for early return and raise if not assert not current_item_ref # noqa: S101 current_item_ref.append(item) return @@ -109,7 +109,7 @@ def _tool_call( for item in stream: item_tool_call_index = self._parser.get_tool_call_index(item) if item_tool_call_index and item_tool_call_index != tool_call_index: - # TODO: Check if output types allow for early return + # TODO: Check if output types allow for early return and raise if not assert not current_item_ref # noqa: S101 current_item_ref.append(item) return diff --git a/tests/chat_model/cassettes/test_retry_chat_model/test_retry_chat_model_acomplete_litellm_openai.yaml b/tests/chat_model/cassettes/test_retry_chat_model/test_retry_chat_model_acomplete_litellm_openai.yaml index 87b11758..60f2f570 100644 --- a/tests/chat_model/cassettes/test_retry_chat_model/test_retry_chat_model_acomplete_litellm_openai.yaml +++ b/tests/chat_model/cassettes/test_retry_chat_model/test_retry_chat_model_acomplete_litellm_openai.yaml @@ -33,7 +33,7 @@ interactions: x-stainless-raw-response: - 'true' x-stainless-retry-count: - - '1' + - '0' x-stainless-runtime: - CPython x-stainless-runtime-version: @@ -42,25 +42,25 @@ interactions: uri: https://api.openai.com/v1/chat/completions response: body: - string: 'data: {"id":"chatcmpl-AUTwhFJhWQohwNWkFZ4g3gdxmE7bs","object":"chat.completion.chunk","created":1731828315,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_ZsPgGThV4XYMbjLcJjngBbUb","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}]} + string: 'data: {"id":"chatcmpl-AY3jeRjlAKpYEV8HODjFYXYB67FlJ","object":"chat.completion.chunk","created":1732680874,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_i18HmT82n9PGc3mC6bAV7aXm","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}]} - data: 
{"id":"chatcmpl-AUTwhFJhWQohwNWkFZ4g3gdxmE7bs","object":"chat.completion.chunk","created":1731828315,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]} + data: {"id":"chatcmpl-AY3jeRjlAKpYEV8HODjFYXYB67FlJ","object":"chat.completion.chunk","created":1732680874,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]} - data: {"id":"chatcmpl-AUTwhFJhWQohwNWkFZ4g3gdxmE7bs","object":"chat.completion.chunk","created":1731828315,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}]} + data: {"id":"chatcmpl-AY3jeRjlAKpYEV8HODjFYXYB67FlJ","object":"chat.completion.chunk","created":1732680874,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}]} - data: {"id":"chatcmpl-AUTwhFJhWQohwNWkFZ4g3gdxmE7bs","object":"chat.completion.chunk","created":1731828315,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]} + data: {"id":"chatcmpl-AY3jeRjlAKpYEV8HODjFYXYB67FlJ","object":"chat.completion.chunk","created":1732680874,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]} - data: {"id":"chatcmpl-AUTwhFJhWQohwNWkFZ4g3gdxmE7bs","object":"chat.completion.chunk","created":1731828315,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Canada"}}]},"logprobs":null,"finish_reason":null}]} + data: {"id":"chatcmpl-AY3jeRjlAKpYEV8HODjFYXYB67FlJ","object":"chat.completion.chunk","created":1732680874,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Canada"}}]},"logprobs":null,"finish_reason":null}]} - data: {"id":"chatcmpl-AUTwhFJhWQohwNWkFZ4g3gdxmE7bs","object":"chat.completion.chunk","created":1731828315,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]} + data: {"id":"chatcmpl-AY3jeRjlAKpYEV8HODjFYXYB67FlJ","object":"chat.completion.chunk","created":1732680874,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]} - data: {"id":"chatcmpl-AUTwhFJhWQohwNWkFZ4g3gdxmE7bs","object":"chat.completion.chunk","created":1731828315,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]} + data: 
{"id":"chatcmpl-AY3jeRjlAKpYEV8HODjFYXYB67FlJ","object":"chat.completion.chunk","created":1732680874,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]} data: [DONE] @@ -71,13 +71,13 @@ interactions: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8e3dfc5abc25645e-SJC + - 8e8f4acbbf48fa32-SJC Connection: - keep-alive Content-Type: - text/event-stream; charset=utf-8 Date: - - Sun, 17 Nov 2024 07:25:15 GMT + - Wed, 27 Nov 2024 04:14:35 GMT Server: - cloudflare Transfer-Encoding: @@ -89,7 +89,7 @@ interactions: alt-svc: - h3=":443"; ma=86400 openai-processing-ms: - - '159' + - '118' openai-version: - '2020-10-01' strict-transport-security: @@ -103,26 +103,27 @@ interactions: x-ratelimit-remaining-tokens: - '199978' x-ratelimit-reset-requests: - - 24.784s + - 22.868s x-ratelimit-reset-tokens: - 6ms x-request-id: - - req_f8598a9c42d753bba944e8e4b8f405e6 + - req_efe987f6575e1d241fa34d5a2126225c status: code: 200 message: OK - request: - body: '{"messages": [{"role": "user", "content": "Return a country."}, {"role": - "assistant", "content": null, "tool_calls": [{"id": "call_ZsPgGThV4XYMbjLcJjngBbUb", - "type": "function", "function": {"name": "return_country", "arguments": "{\"name\":\"Canada\"}"}}]}, - {"role": "tool", "tool_call_id": "call_ZsPgGThV4XYMbjLcJjngBbUb", "content": - "1 validation error for Country\nname\n Value error, Country must be Ireland. - [type=value_error, input_value=''Canada'', input_type=str]\n For further - information visit https://errors.pydantic.dev/2.9/v/value_error"}], "model": - "gpt-4o-mini", "stream": true, "tool_choice": {"type": "function", "function": - {"name": "return_country"}}, "tools": [{"type": "function", "function": {"name": - "return_country", "parameters": {"properties": {"name": {"title": "Name", "type": - "string"}}, "required": ["name"], "type": "object"}}}]}' + body: '{"messages": [{"role": "user", "content": "Return a country."}, {"content": + null, "refusal": null, "role": "assistant", "audio": null, "function_call": + null, "tool_calls": [{"id": "call_i18HmT82n9PGc3mC6bAV7aXm", "function": {"arguments": + "{\"name\":\"Canada\"}", "name": "return_country", "parsed_arguments": null}, + "type": "function", "index": 0}], "parsed": null}, {"role": "tool", "tool_call_id": + "call_i18HmT82n9PGc3mC6bAV7aXm", "content": "1 validation error for Country\nname\n Value + error, Country must be Ireland. 
[type=value_error, input_value=''Canada'', input_type=str]\n For + further information visit https://errors.pydantic.dev/2.9/v/value_error"}], + "model": "gpt-4o-mini", "stream": true, "tool_choice": {"type": "function", + "function": {"name": "return_country"}}, "tools": [{"type": "function", "function": + {"name": "return_country", "parameters": {"properties": {"name": {"title": "Name", + "type": "string"}}, "required": ["name"], "type": "object"}}}]}' headers: accept: - application/json @@ -131,7 +132,7 @@ interactions: connection: - keep-alive content-length: - - '863' + - '972' content-type: - application/json host: @@ -160,25 +161,25 @@ interactions: uri: https://api.openai.com/v1/chat/completions response: body: - string: 'data: {"id":"chatcmpl-AUTwhAs71Rws8mFQuhs3PAeRXzEn4","object":"chat.completion.chunk","created":1731828315,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_9b78b61c52","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_EMNqE0XUFpOKbZaSHqCbbo9H","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}]} + string: 'data: {"id":"chatcmpl-AY3jfwOFPGQ40r0K3uGluzm6jKA9c","object":"chat.completion.chunk","created":1732680875,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_KK1SxLUkLnZAdf88EfFdFOoV","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}]} - data: {"id":"chatcmpl-AUTwhAs71Rws8mFQuhs3PAeRXzEn4","object":"chat.completion.chunk","created":1731828315,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_9b78b61c52","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]} + data: {"id":"chatcmpl-AY3jfwOFPGQ40r0K3uGluzm6jKA9c","object":"chat.completion.chunk","created":1732680875,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]} - data: {"id":"chatcmpl-AUTwhAs71Rws8mFQuhs3PAeRXzEn4","object":"chat.completion.chunk","created":1731828315,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_9b78b61c52","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}]} + data: {"id":"chatcmpl-AY3jfwOFPGQ40r0K3uGluzm6jKA9c","object":"chat.completion.chunk","created":1732680875,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}]} - data: {"id":"chatcmpl-AUTwhAs71Rws8mFQuhs3PAeRXzEn4","object":"chat.completion.chunk","created":1731828315,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_9b78b61c52","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]} + data: {"id":"chatcmpl-AY3jfwOFPGQ40r0K3uGluzm6jKA9c","object":"chat.completion.chunk","created":1732680875,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]} - data: 
{"id":"chatcmpl-AUTwhAs71Rws8mFQuhs3PAeRXzEn4","object":"chat.completion.chunk","created":1731828315,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_9b78b61c52","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Ireland"}}]},"logprobs":null,"finish_reason":null}]} + data: {"id":"chatcmpl-AY3jfwOFPGQ40r0K3uGluzm6jKA9c","object":"chat.completion.chunk","created":1732680875,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Ireland"}}]},"logprobs":null,"finish_reason":null}]} - data: {"id":"chatcmpl-AUTwhAs71Rws8mFQuhs3PAeRXzEn4","object":"chat.completion.chunk","created":1731828315,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_9b78b61c52","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]} + data: {"id":"chatcmpl-AY3jfwOFPGQ40r0K3uGluzm6jKA9c","object":"chat.completion.chunk","created":1732680875,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]} - data: {"id":"chatcmpl-AUTwhAs71Rws8mFQuhs3PAeRXzEn4","object":"chat.completion.chunk","created":1731828315,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_9b78b61c52","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]} + data: {"id":"chatcmpl-AY3jfwOFPGQ40r0K3uGluzm6jKA9c","object":"chat.completion.chunk","created":1732680875,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]} data: [DONE] @@ -189,13 +190,13 @@ interactions: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8e3dfc5cbdaa645e-SJC + - 8e8f4acda93efa32-SJC Connection: - keep-alive Content-Type: - text/event-stream; charset=utf-8 Date: - - Sun, 17 Nov 2024 07:25:15 GMT + - Wed, 27 Nov 2024 04:14:35 GMT Server: - cloudflare Transfer-Encoding: @@ -207,7 +208,7 @@ interactions: alt-svc: - h3=":443"; ma=86400 openai-processing-ms: - - '317' + - '315' openai-version: - '2020-10-01' strict-transport-security: @@ -221,11 +222,11 @@ interactions: x-ratelimit-remaining-tokens: - '199922' x-ratelimit-reset-requests: - - 33.112s + - 31.204s x-ratelimit-reset-tokens: - 23ms x-request-id: - - req_96c61728b21ad535d6bff2a9140a03a4 + - req_75ce8cf8ead16fbcebfa4f6ed9ca1ccf status: code: 200 message: OK diff --git a/tests/chat_model/cassettes/test_retry_chat_model/test_retry_chat_model_complete_litellm_openai.yaml b/tests/chat_model/cassettes/test_retry_chat_model/test_retry_chat_model_complete_litellm_openai.yaml index 2e0df828..d6ac20cb 100644 --- a/tests/chat_model/cassettes/test_retry_chat_model/test_retry_chat_model_complete_litellm_openai.yaml +++ b/tests/chat_model/cassettes/test_retry_chat_model/test_retry_chat_model_complete_litellm_openai.yaml @@ -42,25 +42,25 @@ interactions: uri: https://api.openai.com/v1/chat/completions response: body: - string: 'data: {"id":"chatcmpl-AUTwgzqd1Ii0FWSoryyv3Vasgy8Eh","object":"chat.completion.chunk","created":1731828314,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_pZHNlBWDbuQmYbidOEik4lYu","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}]} + string: 'data: 
{"id":"chatcmpl-AY3jbXL9GkFlYh7fPTgMPOTZpd9Q3","object":"chat.completion.chunk","created":1732680871,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_lB2W7ToZQx1KUeNj7US5se75","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}]} - data: {"id":"chatcmpl-AUTwgzqd1Ii0FWSoryyv3Vasgy8Eh","object":"chat.completion.chunk","created":1731828314,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]} + data: {"id":"chatcmpl-AY3jbXL9GkFlYh7fPTgMPOTZpd9Q3","object":"chat.completion.chunk","created":1732680871,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]} - data: {"id":"chatcmpl-AUTwgzqd1Ii0FWSoryyv3Vasgy8Eh","object":"chat.completion.chunk","created":1731828314,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}]} + data: {"id":"chatcmpl-AY3jbXL9GkFlYh7fPTgMPOTZpd9Q3","object":"chat.completion.chunk","created":1732680871,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}]} - data: {"id":"chatcmpl-AUTwgzqd1Ii0FWSoryyv3Vasgy8Eh","object":"chat.completion.chunk","created":1731828314,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]} + data: {"id":"chatcmpl-AY3jbXL9GkFlYh7fPTgMPOTZpd9Q3","object":"chat.completion.chunk","created":1732680871,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]} - data: {"id":"chatcmpl-AUTwgzqd1Ii0FWSoryyv3Vasgy8Eh","object":"chat.completion.chunk","created":1731828314,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Country"}}]},"logprobs":null,"finish_reason":null}]} + data: {"id":"chatcmpl-AY3jbXL9GkFlYh7fPTgMPOTZpd9Q3","object":"chat.completion.chunk","created":1732680871,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Australia"}}]},"logprobs":null,"finish_reason":null}]} - data: {"id":"chatcmpl-AUTwgzqd1Ii0FWSoryyv3Vasgy8Eh","object":"chat.completion.chunk","created":1731828314,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]} + data: {"id":"chatcmpl-AY3jbXL9GkFlYh7fPTgMPOTZpd9Q3","object":"chat.completion.chunk","created":1732680871,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]} - data: 
{"id":"chatcmpl-AUTwgzqd1Ii0FWSoryyv3Vasgy8Eh","object":"chat.completion.chunk","created":1731828314,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]} + data: {"id":"chatcmpl-AY3jbXL9GkFlYh7fPTgMPOTZpd9Q3","object":"chat.completion.chunk","created":1732680871,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]} data: [DONE] @@ -71,13 +71,13 @@ interactions: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8e3dfc539ebfce84-SJC + - 8e8f4ab8aaae17de-SJC Connection: - keep-alive Content-Type: - text/event-stream; charset=utf-8 Date: - - Sun, 17 Nov 2024 07:25:14 GMT + - Wed, 27 Nov 2024 04:14:32 GMT Server: - cloudflare Transfer-Encoding: @@ -89,7 +89,7 @@ interactions: alt-svc: - h3=":443"; ma=86400 openai-processing-ms: - - '139' + - '439' openai-version: - '2020-10-01' strict-transport-security: @@ -107,22 +107,23 @@ interactions: x-ratelimit-reset-tokens: - 6ms x-request-id: - - req_1c81b5404787e8c0c0fa39a40a5add39 + - req_24555d46defa5b03d2b639b26a1ec654 status: code: 200 message: OK - request: - body: '{"messages": [{"role": "user", "content": "Return a country."}, {"role": - "assistant", "content": null, "tool_calls": [{"id": "call_pZHNlBWDbuQmYbidOEik4lYu", - "type": "function", "function": {"name": "return_country", "arguments": "{\"name\":\"Country\"}"}}]}, - {"role": "tool", "tool_call_id": "call_pZHNlBWDbuQmYbidOEik4lYu", "content": - "1 validation error for Country\nname\n Value error, Country must be Ireland. - [type=value_error, input_value=''Country'', input_type=str]\n For further - information visit https://errors.pydantic.dev/2.9/v/value_error"}], "model": - "gpt-4o-mini", "stream": true, "tool_choice": {"type": "function", "function": - {"name": "return_country"}}, "tools": [{"type": "function", "function": {"name": - "return_country", "parameters": {"properties": {"name": {"title": "Name", "type": - "string"}}, "required": ["name"], "type": "object"}}}]}' + body: '{"messages": [{"role": "user", "content": "Return a country."}, {"content": + null, "refusal": null, "role": "assistant", "audio": null, "function_call": + null, "tool_calls": [{"id": "call_lB2W7ToZQx1KUeNj7US5se75", "function": {"arguments": + "{\"name\":\"Australia\"}", "name": "return_country", "parsed_arguments": null}, + "type": "function", "index": 0}], "parsed": null}, {"role": "tool", "tool_call_id": + "call_lB2W7ToZQx1KUeNj7US5se75", "content": "1 validation error for Country\nname\n Value + error, Country must be Ireland. 
[type=value_error, input_value=''Australia'', + input_type=str]\n For further information visit https://errors.pydantic.dev/2.9/v/value_error"}], + "model": "gpt-4o-mini", "stream": true, "tool_choice": {"type": "function", + "function": {"name": "return_country"}}, "tools": [{"type": "function", "function": + {"name": "return_country", "parameters": {"properties": {"name": {"title": "Name", + "type": "string"}}, "required": ["name"], "type": "object"}}}]}' headers: accept: - application/json @@ -131,7 +132,7 @@ interactions: connection: - keep-alive content-length: - - '865' + - '978' content-type: - application/json host: @@ -160,25 +161,25 @@ interactions: uri: https://api.openai.com/v1/chat/completions response: body: - string: 'data: {"id":"chatcmpl-AUTwgV2Ef5xeCpenezTaQr3IFb6WY","object":"chat.completion.chunk","created":1731828314,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_L9h05hjrL8hqC935gkuWhRXN","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}]} + string: 'data: {"id":"chatcmpl-AY3jdjAv7PKYqsZYyBXB2QntEWtTt","object":"chat.completion.chunk","created":1732680873,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_GIpaN9ZLbC2qxUsQYARPQRDO","type":"function","function":{"name":"return_country","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}]} - data: {"id":"chatcmpl-AUTwgV2Ef5xeCpenezTaQr3IFb6WY","object":"chat.completion.chunk","created":1731828314,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]} + data: {"id":"chatcmpl-AY3jdjAv7PKYqsZYyBXB2QntEWtTt","object":"chat.completion.chunk","created":1732680873,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]} - data: {"id":"chatcmpl-AUTwgV2Ef5xeCpenezTaQr3IFb6WY","object":"chat.completion.chunk","created":1731828314,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}]} + data: {"id":"chatcmpl-AY3jdjAv7PKYqsZYyBXB2QntEWtTt","object":"chat.completion.chunk","created":1732680873,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"name"}}]},"logprobs":null,"finish_reason":null}]} - data: {"id":"chatcmpl-AUTwgV2Ef5xeCpenezTaQr3IFb6WY","object":"chat.completion.chunk","created":1731828314,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]} + data: {"id":"chatcmpl-AY3jdjAv7PKYqsZYyBXB2QntEWtTt","object":"chat.completion.chunk","created":1732680873,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]} - data: 
{"id":"chatcmpl-AUTwgV2Ef5xeCpenezTaQr3IFb6WY","object":"chat.completion.chunk","created":1731828314,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Ireland"}}]},"logprobs":null,"finish_reason":null}]} + data: {"id":"chatcmpl-AY3jdjAv7PKYqsZYyBXB2QntEWtTt","object":"chat.completion.chunk","created":1732680873,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Ireland"}}]},"logprobs":null,"finish_reason":null}]} - data: {"id":"chatcmpl-AUTwgV2Ef5xeCpenezTaQr3IFb6WY","object":"chat.completion.chunk","created":1731828314,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]} + data: {"id":"chatcmpl-AY3jdjAv7PKYqsZYyBXB2QntEWtTt","object":"chat.completion.chunk","created":1732680873,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]} - data: {"id":"chatcmpl-AUTwgV2Ef5xeCpenezTaQr3IFb6WY","object":"chat.completion.chunk","created":1731828314,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]} + data: {"id":"chatcmpl-AY3jdjAv7PKYqsZYyBXB2QntEWtTt","object":"chat.completion.chunk","created":1732680873,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0705bf87c0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]} data: [DONE] @@ -186,14 +187,16 @@ interactions: ' headers: + CF-Cache-Status: + - DYNAMIC CF-RAY: - - 8e3dfc556828ce84-SJC + - 8e8f4abf2a8c17de-SJC Connection: - keep-alive Content-Type: - text/event-stream; charset=utf-8 Date: - - Sun, 17 Nov 2024 07:25:14 GMT + - Wed, 27 Nov 2024 04:14:34 GMT Server: - cloudflare Transfer-Encoding: @@ -204,10 +207,8 @@ interactions: - X-Request-ID alt-svc: - h3=":443"; ma=86400 - cf-cache-status: - - DYNAMIC openai-processing-ms: - - '130' + - '1524' openai-version: - '2020-10-01' strict-transport-security: @@ -221,11 +222,11 @@ interactions: x-ratelimit-remaining-tokens: - '199922' x-ratelimit-reset-requests: - - 16.993s + - 16.281s x-ratelimit-reset-tokens: - 23ms x-request-id: - - req_cb0ee756dfff6d657b46133aae317271 + - req_c104cb3bc8f86806512bc2c26f592d68 status: code: 200 message: OK From 4fd7b4b5ffe81093facdbf9a45b35d541276769f Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Tue, 26 Nov 2024 20:20:00 -0800 Subject: [PATCH 22/40] Fix litellm_ollama --- src/magentic/chat_model/litellm_chat_model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/magentic/chat_model/litellm_chat_model.py b/src/magentic/chat_model/litellm_chat_model.py index 4ab980e5..2e240c98 100644 --- a/src/magentic/chat_model/litellm_chat_model.py +++ b/src/magentic/chat_model/litellm_chat_model.py @@ -101,10 +101,12 @@ def __init__(self): self.usage_ref: list[Usage] = [] def update(self, item: ModelResponse) -> None: - # usage attribute is required inside ChatCompletionStreamState.handle_chunk - # and litellm requires that this is not None for its total usage calculation + # Patch attributes required inside ChatCompletionStreamState.handle_chunk if not hasattr(item, "usage"): + # litellm 
requires usage is not None for its total usage calculation item.usage = litellm.Usage() # type: ignore[attr-defined] + if not hasattr(item, "refusal"): + item.choices[0].delta.refusal = None # type: ignore[attr-defined] self._chat_completion_stream_state.handle_chunk(item) # type: ignore[arg-type] usage = cast(litellm.Usage, item.usage) # type: ignore[attr-defined] # Ignore usages with 0 tokens From 7e7e64baf1ca974deafca43d58849ef6b2c74213 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Tue, 26 Nov 2024 22:45:45 -0800 Subject: [PATCH 23/40] Handle multiple tools in a chunk, for Mistral --- src/magentic/chat_model/litellm_chat_model.py | 34 ++-- src/magentic/chat_model/openai_chat_model.py | 62 +++---- src/magentic/chat_model/stream.py | 153 ++++++++++-------- 3 files changed, 119 insertions(+), 130 deletions(-) diff --git a/src/magentic/chat_model/litellm_chat_model.py b/src/magentic/chat_model/litellm_chat_model.py index 2e240c98..ef3d1298 100644 --- a/src/magentic/chat_model/litellm_chat_model.py +++ b/src/magentic/chat_model/litellm_chat_model.py @@ -29,6 +29,7 @@ ) from magentic.chat_model.stream import ( AsyncOutputStream, + FunctionCallChunk, OutputStream, StreamParser, StreamState, @@ -63,33 +64,16 @@ def is_tool_call(self, item: ModelResponse) -> bool: assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 return bool(item.choices[0].delta.tool_calls) - def get_tool_call_index(self, item: ModelResponse) -> int | None: + def iter_tool_calls(self, item: ModelResponse) -> Iterable[FunctionCallChunk]: assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 if item.choices and item.choices[0].delta.tool_calls: - return item.choices[0].delta.tool_calls[0].index - return None - - def get_tool_call_id(self, item: ModelResponse) -> str | None: - assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 - if item.choices and item.choices[0].delta.tool_calls: - return item.choices[0].delta.tool_calls[0].id - return None - - def get_tool_name(self, item: ModelResponse) -> str | None: - assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 - if ( - item.choices - and item.choices[0].delta.tool_calls - and item.choices[0].delta.tool_calls[0].function.name - ): - return item.choices[0].delta.tool_calls[0].function.name - return None - - def get_tool_call_args(self, item: ModelResponse) -> str: - assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 - if item.choices and item.choices[0].delta.tool_calls: - return item.choices[0].delta.tool_calls[0].function.arguments - return "" + for tool_call in item.choices[0].delta.tool_calls: + if tool_call.function: + yield FunctionCallChunk( + id=tool_call.id, + name=tool_call.function.name, + args=tool_call.function.arguments, + ) class LitellmStreamState(StreamState[ModelResponse]): diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index b38cf7c1..c6163acb 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -44,6 +44,7 @@ ) from magentic.chat_model.stream import ( AsyncOutputStream, + FunctionCallChunk, OutputStream, StreamParser, StreamState, @@ -249,43 +250,15 @@ def get_content(self, item: ChatCompletionChunk) -> str: def is_tool_call(self, item: ChatCompletionChunk) -> bool: return bool(item.choices and item.choices[0].delta.tool_calls) - def get_tool_call_index(self, item: ChatCompletionChunk) -> int | None: - if ( - item.choices - and 
item.choices[0].delta.tool_calls - and item.choices[0].delta.tool_calls[0].index is not None - ): - return item.choices[0].delta.tool_calls[0].index - return None - - def get_tool_call_id(self, item: ChatCompletionChunk) -> str | None: - if ( - item.choices - and item.choices[0].delta.tool_calls - and item.choices[0].delta.tool_calls[0].id - ): - return item.choices[0].delta.tool_calls[0].id - return None - - def get_tool_name(self, item: ChatCompletionChunk) -> str | None: - if ( - item.choices - and item.choices[0].delta.tool_calls - and item.choices[0].delta.tool_calls[0].function - and item.choices[0].delta.tool_calls[0].function.name - ): - return item.choices[0].delta.tool_calls[0].function.name - return None - - def get_tool_call_args(self, item: ChatCompletionChunk) -> str: - if ( - item.choices - and item.choices[0].delta.tool_calls - and item.choices[0].delta.tool_calls[0].function - and item.choices[0].delta.tool_calls[0].function.arguments - ): - return item.choices[0].delta.tool_calls[0].function.arguments - return "" + def iter_tool_calls(self, item: ChatCompletionChunk) -> Iterator[FunctionCallChunk]: + if item.choices and item.choices[0].delta.tool_calls: + for tool_call in item.choices[0].delta.tool_calls: + if tool_call.function: + yield FunctionCallChunk( + id=tool_call.id, + name=tool_call.function.name, + args=tool_call.function.arguments, + ) class OpenaiStreamState(StreamState[ChatCompletionChunk]): @@ -303,7 +276,22 @@ def __init__(self): ) self.usage_ref: list[Usage] = [] + # Keep track of tool call index to add this to Mistral tool calls + self._current_tool_call_index: int = -1 + self._seen_tool_call_ids: set[str] = set() + def update(self, item: ChatCompletionChunk) -> None: + # Add tool call index for Mistral tool calls to make compatible with OpenAI + # TODO: Remove this fix when MistralChatModel switched to mistral python package + if item.choices: + for tool_call_chunk in item.choices[0].delta.tool_calls or []: + if ( + tool_call_chunk.id is not None + and tool_call_chunk.id not in self._seen_tool_call_ids + ): + self._current_tool_call_index += 1 + self._seen_tool_call_ids.add(tool_call_chunk.id) + tool_call_chunk.index = self._current_tool_call_index self._chat_completion_stream_state.handle_chunk(item) if item.usage: assert not self.usage_ref # noqa: S101 diff --git a/src/magentic/chat_model/stream.py b/src/magentic/chat_model/stream.py index 3bd15f8b..ffde1b7b 100644 --- a/src/magentic/chat_model/stream.py +++ b/src/magentic/chat_model/stream.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Iterable, Iterator from itertools import chain -from typing import Generic, TypeVar +from typing import Generic, NamedTuple, TypeVar from litellm.llms.files_apis.azure import Any from pydantic import ValidationError @@ -22,6 +22,12 @@ OutputT = TypeVar("OutputT") +class FunctionCallChunk(NamedTuple): + id: str | None + name: str | None + args: str | None + + class StreamParser(ABC, Generic[ItemT]): @abstractmethod def is_content(self, item: ItemT) -> bool: ... @@ -36,16 +42,7 @@ def get_content(self, item: ItemT) -> str: ... def is_tool_call(self, item: ItemT) -> bool: ... @abstractmethod - def get_tool_call_index(self, item: ItemT) -> int | None: ... - - @abstractmethod - def get_tool_call_id(self, item: ItemT) -> str | None: ... - - @abstractmethod - def get_tool_name(self, item: ItemT) -> str | None: ... - - @abstractmethod - def get_tool_call_args(self, item: ItemT) -> str: ... 
+ def iter_tool_calls(self, item: ItemT) -> Iterable[FunctionCallChunk]: ... class StreamState(ABC, Generic[ItemT]): @@ -92,6 +89,8 @@ def __iter__(self) -> Iterator[StreamedStr | OutputT]: def _streamed_str( self, stream: Iterator[ItemT], current_item_ref: list[ItemT] ) -> Iterator[str]: + # TODO: Yield item then check if next ends? + # To ensure no ended immediately if both content and tool calls are present for item in stream: if self._parser.is_content_ended(item): # TODO: Check if output types allow for early return and raise if not @@ -102,18 +101,19 @@ def _streamed_str( def _tool_call( self, - stream: Iterator[ItemT], - current_item_ref: list[ItemT], - tool_call_index: int, + stream: Iterator[FunctionCallChunk], + current_tool_call_ref: list[FunctionCallChunk], + current_tool_call_id: str, ) -> Iterator[str]: for item in stream: - item_tool_call_index = self._parser.get_tool_call_index(item) - if item_tool_call_index and item_tool_call_index != tool_call_index: + # Only end the stream if we encounter a new tool call + # so that the whole stream is consumed including stop_reason/usage chunks + if item.id and item.id != current_tool_call_id: # TODO: Check if output types allow for early return and raise if not - assert not current_item_ref # noqa: S101 - current_item_ref.append(item) + assert not current_tool_call_ref # noqa: S101 + current_tool_call_ref.append(item) return - yield self._parser.get_tool_call_args(item) + yield item.args or "" def __stream__(self) -> Iterator[StreamedStr | OutputT]: stream = apply(self._state.update, self._stream) @@ -125,28 +125,36 @@ def __stream__(self) -> Iterator[StreamedStr | OutputT]: yield StreamedStr(self._streamed_str(stream, current_item_ref)) # TODO: Make is_tool_calls to handle multiple tools elif self._parser.is_tool_call(current_item): - # TODO: Iterate until ID is found ? 
- current_tool_call_id = self._parser.get_tool_call_id(current_item) - current_tool_name = self._parser.get_tool_name(current_item) - function_schema = select_function_schema( - self._function_schemas, current_tool_name + tool_calls_stream = ( + tool_call_chunk + for item in chain([current_item], stream) + for tool_call_chunk in self._parser.iter_tool_calls(item) ) - current_tool_call_index = self._parser.get_tool_call_index(current_item) - stream = chain([current_item], stream) - try: - yield function_schema.parse_args( - self._tool_call( - stream, current_item_ref, current_tool_call_index - ) - ) - # TODO: Catch/raise unknown tool call error here - except ValidationError as e: + tool_call_ref = [next(tool_calls_stream)] + while tool_call_ref: + current_tool_call_chunk = tool_call_ref.pop() + current_tool_call_id = current_tool_call_chunk.id assert current_tool_call_id is not None # noqa: S101 - raise ToolSchemaParseError( - output_message=self._state.current_message_snapshot, - tool_call_id=current_tool_call_id, - validation_error=e, - ) from e + assert current_tool_call_chunk.name is not None # noqa: S101 + function_schema = select_function_schema( + self._function_schemas, current_tool_call_chunk.name + ) + try: + yield function_schema.parse_args( + self._tool_call( + chain([current_tool_call_chunk], tool_calls_stream), + tool_call_ref, + current_tool_call_id, + ) + ) + # TODO: Catch/raise unknown tool call error here + except ValidationError as e: + assert current_tool_call_id is not None # noqa: S101 + raise ToolSchemaParseError( + output_message=self._state.current_message_snapshot, + tool_call_id=current_tool_call_id, + validation_error=e, + ) from e elif new_current_item := next(stream, None): current_item_ref.append(new_current_item) @@ -195,18 +203,17 @@ async def _streamed_str( async def _tool_call( self, - stream: AsyncIterator[ItemT], - current_item_ref: list[ItemT], - tool_call_index: int, + stream: AsyncIterator[FunctionCallChunk], + current_tool_call_ref: list[FunctionCallChunk], + current_tool_call_id: str, ) -> AsyncIterator[str]: async for item in stream: - item_tool_call_index = self._parser.get_tool_call_index(item) - if item_tool_call_index and item_tool_call_index != tool_call_index: + if item.id and item.id != current_tool_call_id: # TODO: Check if output types allow for early return - assert not current_item_ref # noqa: S101 - current_item_ref.append(item) + assert not current_tool_call_ref # noqa: S101 + current_tool_call_ref.append(item) return - yield self._parser.get_tool_call_args(item) + yield item.args or "" async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]: stream = aapply(self._state.update, self._stream) @@ -218,29 +225,39 @@ async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]: yield AsyncStreamedStr(self._streamed_str(stream, current_item_ref)) # TODO: Make is_tool_calls to handle multiple tools elif self._parser.is_tool_call(current_item): - # TODO: Iterate until ID is found ? 
- current_tool_call_id = self._parser.get_tool_call_id(current_item) - current_tool_name = self._parser.get_tool_name(current_item) - function_schema = select_function_schema( - self._function_schemas, current_tool_name + tool_calls_stream = ( + tool_call_chunk + async for item in achain(async_iter([current_item]), stream) + for tool_call_chunk in self._parser.iter_tool_calls(item) ) - current_tool_call_index = self._parser.get_tool_call_index(current_item) - stream = achain(async_iter([current_item]), stream) - try: - yield await function_schema.aparse_args( - self._tool_call( - stream, current_item_ref, current_tool_call_index - ) - ) - # TODO: Catch/raise unknown tool call error here - except ValidationError as e: + tool_call_ref = [await anext(tool_calls_stream)] + while tool_call_ref: + current_tool_call_chunk = tool_call_ref.pop() + current_tool_call_id = current_tool_call_chunk.id assert current_tool_call_id is not None # noqa: S101 - raise ToolSchemaParseError( - output_message=self._state.current_message_snapshot, - # TODO: Take last tool call id from the message snapshot instead? - tool_call_id=current_tool_call_id, - validation_error=e, - ) from e + assert current_tool_call_chunk.name is not None # noqa: S101 + function_schema = select_function_schema( + self._function_schemas, current_tool_call_chunk.name + ) + try: + yield await function_schema.aparse_args( + self._tool_call( + achain( + async_iter([current_tool_call_chunk]), + tool_calls_stream, + ), + tool_call_ref, + current_tool_call_id, + ) + ) + # TODO: Catch/raise unknown tool call error here + except ValidationError as e: + assert current_tool_call_id is not None # noqa: S101 + raise ToolSchemaParseError( + output_message=self._state.current_message_snapshot, + tool_call_id=current_tool_call_id, + validation_error=e, + ) from e elif new_current_item := await anext(stream, None): current_item_ref.append(new_current_item) From cfa0815e64122ddf59ad852cd0b4a6221ad6dc51 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Tue, 26 Nov 2024 22:52:18 -0800 Subject: [PATCH 24/40] Remove anthropic context manager usage --- .../chat_model/anthropic_chat_model.py | 92 +++++++++---------- 1 file changed, 43 insertions(+), 49 deletions(-) diff --git a/src/magentic/chat_model/anthropic_chat_model.py b/src/magentic/chat_model/anthropic_chat_model.py index dc54d567..8b115a9c 100644 --- a/src/magentic/chat_model/anthropic_chat_model.py +++ b/src/magentic/chat_model/anthropic_chat_model.py @@ -548,30 +548,27 @@ def complete( system, messages = _extract_system_message(messages) - def _response_generator() -> Iterator[MessageStreamEvent]: - with self._client.messages.stream( - model=self.model, - messages=_combine_messages( - [message_to_anthropic_message(m) for m in messages] - ), - max_tokens=self.max_tokens, - stop_sequences=stop or anthropic.NOT_GIVEN, - system=system, - temperature=( - self.temperature - if self.temperature is not None - else anthropic.NOT_GIVEN - ), - tools=( - [schema.to_dict() for schema in tool_schemas] or anthropic.NOT_GIVEN - ), - tool_choice=self._get_tool_choice( - tool_schemas=tool_schemas, allow_string_output=allow_string_output - ), - ) as stream: - yield from stream - - response = _response_generator() + response: Iterator[MessageStreamEvent] = self._client.messages.stream( + model=self.model, + messages=_combine_messages( + [message_to_anthropic_message(m) for m in messages] + ), + max_tokens=self.max_tokens, + stop_sequences=stop or anthropic.NOT_GIVEN, + 
system=system, + temperature=( + self.temperature + if self.temperature is not None + else anthropic.NOT_GIVEN + ), + tools=( + [schema.to_dict() for schema in tool_schemas] or anthropic.NOT_GIVEN + ), + tool_choice=self._get_tool_choice( + tool_schemas=tool_schemas, allow_string_output=allow_string_output + ), + ).__enter__() + usage_ref, response = _create_usage_ref(response) message_start_chunk = next(response) @@ -660,31 +657,28 @@ async def acomplete( system, messages = _extract_system_message(messages) - async def _response_generator() -> AsyncIterator[MessageStreamEvent]: - async with self._async_client.messages.stream( - model=self.model, - messages=_combine_messages( - [message_to_anthropic_message(m) for m in messages] - ), - max_tokens=self.max_tokens, - stop_sequences=stop or anthropic.NOT_GIVEN, - system=system, - temperature=( - self.temperature - if self.temperature is not None - else anthropic.NOT_GIVEN - ), - tools=( - [schema.to_dict() for schema in tool_schemas] or anthropic.NOT_GIVEN - ), - tool_choice=self._get_tool_choice( - tool_schemas=tool_schemas, allow_string_output=allow_string_output - ), - ) as stream: - async for chunk in stream: - yield chunk - - response = _response_generator() + response: AsyncIterator[ + MessageStreamEvent + ] = await self._async_client.messages.stream( + model=self.model, + messages=_combine_messages( + [message_to_anthropic_message(m) for m in messages] + ), + max_tokens=self.max_tokens, + stop_sequences=stop or anthropic.NOT_GIVEN, + system=system, + temperature=( + self.temperature + if self.temperature is not None + else anthropic.NOT_GIVEN + ), + tools=( + [schema.to_dict() for schema in tool_schemas] or anthropic.NOT_GIVEN + ), + tool_choice=self._get_tool_choice( + tool_schemas=tool_schemas, allow_string_output=allow_string_output + ), + ).__aenter__() usage_ref, response = _create_usage_ref_async(response) message_start_chunk = await anext(response) From ab407b403d9eb25e1578d56133082414687cc51f Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Tue, 26 Nov 2024 22:56:15 -0800 Subject: [PATCH 25/40] Add _if_given helper for anthropic --- .../chat_model/anthropic_chat_model.py | 28 +++++++------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/src/magentic/chat_model/anthropic_chat_model.py b/src/magentic/chat_model/anthropic_chat_model.py index 8b115a9c..e94749e1 100644 --- a/src/magentic/chat_model/anthropic_chat_model.py +++ b/src/magentic/chat_model/anthropic_chat_model.py @@ -433,6 +433,10 @@ async def agenerator( return usage_ref, agenerator(response) +def _if_given(value: T | None) -> T | anthropic.NotGiven: + return value if value is not None else anthropic.NOT_GIVEN + + R = TypeVar("R") @@ -554,16 +558,10 @@ def complete( [message_to_anthropic_message(m) for m in messages] ), max_tokens=self.max_tokens, - stop_sequences=stop or anthropic.NOT_GIVEN, + stop_sequences=_if_given(stop), system=system, - temperature=( - self.temperature - if self.temperature is not None - else anthropic.NOT_GIVEN - ), - tools=( - [schema.to_dict() for schema in tool_schemas] or anthropic.NOT_GIVEN - ), + temperature=_if_given(self.temperature), + tools=[schema.to_dict() for schema in tool_schemas] or anthropic.NOT_GIVEN, tool_choice=self._get_tool_choice( tool_schemas=tool_schemas, allow_string_output=allow_string_output ), @@ -665,16 +663,10 @@ async def acomplete( [message_to_anthropic_message(m) for m in messages] ), max_tokens=self.max_tokens, - stop_sequences=stop or 
anthropic.NOT_GIVEN, + stop_sequences=_if_given(stop), system=system, - temperature=( - self.temperature - if self.temperature is not None - else anthropic.NOT_GIVEN - ), - tools=( - [schema.to_dict() for schema in tool_schemas] or anthropic.NOT_GIVEN - ), + temperature=_if_given(self.temperature), + tools=[schema.to_dict() for schema in tool_schemas] or anthropic.NOT_GIVEN, tool_choice=self._get_tool_choice( tool_schemas=tool_schemas, allow_string_output=allow_string_output ), From 9ab2265e6316ac11090f6639735dd0ec21a7af20 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:08:01 -0800 Subject: [PATCH 26/40] Allow parser.get_content to return None --- src/magentic/chat_model/litellm_chat_model.py | 4 ++-- src/magentic/chat_model/openai_chat_model.py | 4 ++-- src/magentic/chat_model/stream.py | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/magentic/chat_model/litellm_chat_model.py b/src/magentic/chat_model/litellm_chat_model.py index ef3d1298..742be670 100644 --- a/src/magentic/chat_model/litellm_chat_model.py +++ b/src/magentic/chat_model/litellm_chat_model.py @@ -56,9 +56,9 @@ def is_content(self, item: ModelResponse) -> bool: def is_content_ended(self, item: ModelResponse) -> bool: return self.is_tool_call(item) - def get_content(self, item: ModelResponse) -> str: + def get_content(self, item: ModelResponse) -> str | None: assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 - return item.choices[0].delta.content or "" + return item.choices[0].delta.content def is_tool_call(self, item: ModelResponse) -> bool: assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index c6163acb..b6788a77 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -242,10 +242,10 @@ def is_content(self, item: ChatCompletionChunk) -> bool: def is_content_ended(self, item: ChatCompletionChunk) -> bool: return self.is_tool_call(item) - def get_content(self, item: ChatCompletionChunk) -> str: + def get_content(self, item: ChatCompletionChunk) -> str | None: if item.choices and item.choices[0].delta.content: return item.choices[0].delta.content - return "" + return None def is_tool_call(self, item: ChatCompletionChunk) -> bool: return bool(item.choices and item.choices[0].delta.tool_calls) diff --git a/src/magentic/chat_model/stream.py b/src/magentic/chat_model/stream.py index ffde1b7b..13791cd1 100644 --- a/src/magentic/chat_model/stream.py +++ b/src/magentic/chat_model/stream.py @@ -36,7 +36,7 @@ def is_content(self, item: ItemT) -> bool: ... def is_content_ended(self, item: ItemT) -> bool: ... @abstractmethod - def get_content(self, item: ItemT) -> str: ... + def get_content(self, item: ItemT) -> str | None: ... @abstractmethod def is_tool_call(self, item: ItemT) -> bool: ... @@ -89,15 +89,14 @@ def __iter__(self) -> Iterator[StreamedStr | OutputT]: def _streamed_str( self, stream: Iterator[ItemT], current_item_ref: list[ItemT] ) -> Iterator[str]: - # TODO: Yield item then check if next ends? 
- # To ensure no ended immediately if both content and tool calls are present for item in stream: + if content := self._parser.get_content(item): + yield content if self._parser.is_content_ended(item): # TODO: Check if output types allow for early return and raise if not assert not current_item_ref # noqa: S101 current_item_ref.append(item) return - yield self._parser.get_content(item) def _tool_call( self, @@ -194,12 +193,13 @@ async def _streamed_str( self, stream: AsyncIterator[ItemT], current_item_ref: list[ItemT] ) -> AsyncIterator[str]: async for item in stream: + if content := self._parser.get_content(item): + yield content if self._parser.is_content_ended(item): # TODO: Check if output types allow for early return assert not current_item_ref # noqa: S101 current_item_ref.append(item) return - yield self._parser.get_content(item) async def _tool_call( self, From b414c08485dbac5ed0122cc2364675312b2def6e Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:49:54 -0800 Subject: [PATCH 27/40] Switch AnthropicChatModel to new parsing logic --- .../chat_model/anthropic_chat_model.py | 350 ++++-------------- 1 file changed, 78 insertions(+), 272 deletions(-) diff --git a/src/magentic/chat_model/anthropic_chat_model.py b/src/magentic/chat_model/anthropic_chat_model.py index e94749e1..6448ee7a 100644 --- a/src/magentic/chat_model/anthropic_chat_model.py +++ b/src/magentic/chat_model/anthropic_chat_model.py @@ -1,7 +1,6 @@ import base64 import json from collections.abc import ( - AsyncIterable, AsyncIterator, Callable, Iterable, @@ -10,20 +9,17 @@ ) from enum import Enum from functools import singledispatch -from itertools import chain, groupby +from itertools import groupby from typing import Any, Generic, TypeVar, cast, overload import filetype -from pydantic import ValidationError from magentic.chat_model.base import ( ChatModel, - ToolSchemaParseError, - avalidate_str_content, - validate_str_content, + aparse_stream, + parse_stream, ) from magentic.chat_model.function_schema import ( - AsyncFunctionSchema, BaseFunctionSchema, FunctionCallFunctionSchema, FunctionSchema, @@ -39,6 +35,13 @@ UserMessage, _RawMessage, ) +from magentic.chat_model.stream import ( + AsyncOutputStream, + FunctionCallChunk, + OutputStream, + StreamParser, + StreamState, +) from magentic.function_call import ( AsyncParallelFunctionCall, FunctionCall, @@ -48,13 +51,6 @@ from magentic.streaming import ( AsyncStreamedStr, StreamedStr, - aapply, - achain, - agroupby, - apeek, - apply, - async_iter, - peek, ) from magentic.typing import is_any_origin_subclass, is_origin_subclass from magentic.vision import UserImageMessage @@ -63,13 +59,7 @@ import anthropic from anthropic.lib.streaming import MessageStreamEvent from anthropic.lib.streaming._messages import accumulate_event - from anthropic.types import ( - ContentBlockDeltaEvent, - ContentBlockStartEvent, - MessageParam, - ToolParam, - ToolUseBlock, - ) + from anthropic.types import MessageParam, ToolParam from anthropic.types.message_create_params import ToolChoice except ImportError as error: msg = "To use AnthropicChatModel you must install the `anthropic` package using `pip install 'magentic[anthropic]'`." 
@@ -241,127 +231,62 @@ def as_tool_choice(self) -> ToolChoice: return {"type": "tool", "name": self._function_schema.name} -# TODO: Generalize this to BaseToolSchema when that is created -BeseToolSchemaT = TypeVar("BeseToolSchemaT", bound=BaseFunctionToolSchema[Any]) - - -def select_tool_schema( - tool_call: ToolUseBlock, - tool_schemas: Iterable[BeseToolSchemaT], -) -> BeseToolSchemaT: - """Select the tool schema based on the response chunk.""" - for tool_schema in tool_schemas: - if tool_schema._function_schema.name == tool_call.name: - return tool_schema - - msg = f"Unknown tool call: {tool_call.model_dump_json()}" - raise ValueError(msg) +class AnthropicStreamParser(StreamParser[MessageStreamEvent]): + def is_content(self, item: MessageStreamEvent) -> bool: + return item.type == "content_block_delta" + def is_content_ended(self, item: MessageStreamEvent) -> bool: + return self.is_tool_call(item) -class FunctionToolSchema(BaseFunctionToolSchema[FunctionSchema[T]]): - def parse_tool_call(self, chunks: Iterable[MessageStreamEvent]) -> T: - return self._function_schema.parse_args( - chunk.delta.partial_json - for chunk in chunks - if chunk.type == "content_block_delta" - if chunk.delta.type == "input_json_delta" - ) - + def get_content(self, item: MessageStreamEvent) -> str | None: + if item.type == "text": + return item.text + return None -class AsyncFunctionToolSchema(BaseFunctionToolSchema[AsyncFunctionSchema[T]]): - async def aparse_tool_call(self, chunks: AsyncIterable[MessageStreamEvent]) -> T: - return await self._function_schema.aparse_args( - chunk.delta.partial_json - async for chunk in chunks - if chunk.type == "content_block_delta" - if chunk.delta.type == "input_json_delta" + def is_tool_call(self, item: MessageStreamEvent) -> bool: + return ( + item.type == "content_block_start" and item.content_block.type == "tool_use" ) + def iter_tool_calls(self, item: MessageStreamEvent) -> Iterable[FunctionCallChunk]: + if item.type == "content_block_start" and item.content_block.type == "tool_use": + return [ + FunctionCallChunk( + id=item.content_block.id, name=item.content_block.name, args=None + ) + ] + if item.type == "input_json": + return [FunctionCallChunk(id=None, name=None, args=item.partial_json)] + return [] -def _iter_streamed_tool_calls( - response: Iterable[MessageStreamEvent], -) -> Iterator[Iterator[ContentBlockStartEvent | ContentBlockDeltaEvent]]: - all_tool_call_chunks = ( - cast(ContentBlockStartEvent | ContentBlockDeltaEvent, chunk) - for chunk in response - if chunk.type in ("content_block_start", "content_block_delta") - ) - for _, tool_call_chunks in groupby(all_tool_call_chunks, lambda x: x.index): - yield tool_call_chunks +class AnthropicStreamState(StreamState[MessageStreamEvent]): + def __init__(self): + self._current_message_snapshot: anthropic.types.Message | None = ( + None # TODO: type + ) + self.usage_ref: list[Usage] = [] -async def _aiter_streamed_tool_calls( - response: AsyncIterable[MessageStreamEvent], -) -> AsyncIterator[AsyncIterator[ContentBlockStartEvent | ContentBlockDeltaEvent]]: - all_tool_call_chunks = ( - cast(ContentBlockStartEvent | ContentBlockDeltaEvent, chunk) - async for chunk in response - if chunk.type in ("content_block_start", "content_block_delta") - ) - async for _, tool_call_chunks in agroupby(all_tool_call_chunks, lambda x: x.index): - yield tool_call_chunks - - -def _join_streamed_response_to_message( - response: list[MessageStreamEvent], -) -> _RawMessage[MessageParam]: - snapshot = None - for event in response: - snapshot = 
accumulate_event( - event=event, # type: ignore[arg-type] - current_snapshot=snapshot, + def update(self, item: MessageStreamEvent) -> None: + self._current_message_snapshot = accumulate_event( + # Unrecognized event types are ignored + event=item, # type: ignore[arg-type] + current_snapshot=self._current_message_snapshot, ) - assert snapshot is not None # noqa: S101 - snapshot_content = snapshot.model_dump()["content"] - return _RawMessage({"role": snapshot.role, "content": snapshot_content}) - - -def _parse_streamed_tool_calls( - response: Iterable[MessageStreamEvent], - tool_schemas: Iterable[FunctionToolSchema[T]], -) -> Iterator[T]: - cached_response: list[MessageStreamEvent] = [] - response = apply(cached_response.append, response) - try: - for tool_call_chunks in _iter_streamed_tool_calls(response): - first_chunk, tool_call_chunks = peek(tool_call_chunks) - assert first_chunk.type == "content_block_start" # noqa: S101 - assert first_chunk.content_block.type == "tool_use" # noqa: S101 - tool_schema = select_tool_schema(first_chunk.content_block, tool_schemas) - tool_call = tool_schema.parse_tool_call(tool_call_chunks) - yield tool_call - # TODO: Catch/raise unknown tool call error here - except ValidationError as e: - raw_message = _join_streamed_response_to_message(cached_response) - raise ToolSchemaParseError( - output_message=raw_message, - tool_call_id=raw_message.content["content"][0]["id"], # type: ignore[index,unused-ignore] - validation_error=e, - ) from e - - -async def _aparse_streamed_tool_calls( - response: AsyncIterable[MessageStreamEvent], - tool_schemas: Iterable[AsyncFunctionToolSchema[T]], -) -> AsyncIterator[T]: - cached_response: list[MessageStreamEvent] = [] - response = aapply(cached_response.append, response) - try: - async for tool_call_chunks in _aiter_streamed_tool_calls(response): - first_chunk, tool_call_chunks = await apeek(tool_call_chunks) - assert first_chunk.type == "content_block_start" # noqa: S101 - assert first_chunk.content_block.type == "tool_use" # noqa: S101 - tool_schema = select_tool_schema(first_chunk.content_block, tool_schemas) - tool_call = await tool_schema.aparse_tool_call(tool_call_chunks) - yield tool_call - # TODO: Catch/raise unknown tool call error here - except ValidationError as e: - raw_message = _join_streamed_response_to_message(cached_response) - raise ToolSchemaParseError( - output_message=raw_message, - tool_call_id=raw_message.content["content"][0]["id"], # type: ignore[index,unused-ignore] - validation_error=e, - ) from e + if item.type == "message_stop": + assert not self.usage_ref # noqa: S101 + self.usage_ref.append( + Usage( + input_tokens=item.message.usage.input_tokens, + output_tokens=item.message.usage.output_tokens, + ) + ) + + @property + def current_message_snapshot(self) -> Message: + assert self._current_message_snapshot is not None # noqa: S101 + # TODO: Possible to return AssistantMessage here? 
+ return _RawMessage(self._current_message_snapshot.model_dump()) def _extract_system_message( @@ -377,62 +302,6 @@ def _extract_system_message( ) -def _create_usage_ref( - response: Iterable[MessageStreamEvent], -) -> tuple[list[Usage], Iterator[MessageStreamEvent]]: - """Returns a pointer to a Usage object that is created at the end of the response.""" - usage_ref: list[Usage] = [] - - def generator( - response: Iterable[MessageStreamEvent], - ) -> Iterator[MessageStreamEvent]: - message_start_usage = None - output_tokens = None - for chunk in response: - if chunk.type == "message_start": - message_start_usage = chunk.message.usage - if chunk.type == "message_delta": - output_tokens = chunk.usage.output_tokens - yield chunk - if message_start_usage and output_tokens: - usage_ref.append( - Usage( - input_tokens=message_start_usage.input_tokens, - output_tokens=message_start_usage.output_tokens + output_tokens, - ) - ) - - return usage_ref, generator(response) - - -def _create_usage_ref_async( - response: AsyncIterable[MessageStreamEvent], -) -> tuple[list[Usage], AsyncIterator[MessageStreamEvent]]: - """Async version of `_create_usage_ref`.""" - usage_ref: list[Usage] = [] - - async def agenerator( - response: AsyncIterable[MessageStreamEvent], - ) -> AsyncIterator[MessageStreamEvent]: - message_start_usage = None - output_tokens = None - async for chunk in response: - if chunk.type == "message_start": - message_start_usage = chunk.message.usage - if chunk.type == "message_delta": - output_tokens = chunk.usage.output_tokens - yield chunk - if message_start_usage and output_tokens: - usage_ref.append( - Usage( - input_tokens=message_start_usage.input_tokens, - output_tokens=message_start_usage.output_tokens + output_tokens, - ) - ) - - return usage_ref, agenerator(response) - - def _if_given(value: T | None) -> T | anthropic.NotGiven: return value if value is not None else anthropic.NOT_GIVEN @@ -544,7 +413,7 @@ def complete( for type_ in output_types if not is_origin_subclass(type_, STR_OR_FUNCTIONCALL_TYPE) ] - tool_schemas = [FunctionToolSchema(schema) for schema in function_schemas] + tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas] str_in_output_types = is_any_origin_subclass(output_types, str) streamed_str_in_output_types = is_any_origin_subclass(output_types, StreamedStr) @@ -566,47 +435,15 @@ def complete( tool_schemas=tool_schemas, allow_string_output=allow_string_output ), ).__enter__() - - usage_ref, response = _create_usage_ref(response) - - message_start_chunk = next(response) - assert message_start_chunk.type == "message_start" # noqa: S101 - first_chunk = next(response) - assert first_chunk.type == "content_block_start" # noqa: S101 - response = chain([message_start_chunk, first_chunk], response) - - if ( - first_chunk.type == "content_block_start" - and first_chunk.content_block.type == "text" - ): - streamed_str = StreamedStr( - chunk.delta.text - for chunk in response - if chunk.type == "content_block_delta" - and chunk.delta.type == "text_delta" - ) - str_content = validate_str_content( - streamed_str, - allow_string_output=allow_string_output, - streamed=streamed_str_in_output_types, - ) - return AssistantMessage._with_usage(str_content, usage_ref) # type: ignore[return-value] - - if ( - first_chunk.type == "content_block_start" - and first_chunk.content_block.type == "tool_use" - ): - tool_calls = _parse_streamed_tool_calls(response, tool_schemas) - if is_any_origin_subclass(output_types, ParallelFunctionCall): - content = 
ParallelFunctionCall(tool_calls) - return AssistantMessage._with_usage(content, usage_ref) # type: ignore[return-value] - # Take only the first tool_call, silently ignore extra chunks - # TODO: Create generator here that raises error or warns if multiple tool_calls - content = next(tool_calls) - return AssistantMessage._with_usage(content, usage_ref) # type: ignore[return-value] - - msg = f"Could not determine response type for first chunk: {first_chunk.model_dump_json()}" - raise ValueError(msg) + stream = OutputStream( + response, + function_schemas=function_schemas, + parser=AnthropicStreamParser(), + state=AnthropicStreamState(), + ) + return AssistantMessage._with_usage( + parse_stream(stream, output_types), usage_ref=stream.usage_ref + ) @overload async def acomplete( @@ -645,7 +482,7 @@ async def acomplete( for type_ in output_types if not is_origin_subclass(type_, STR_OR_FUNCTIONCALL_TYPE) ] - tool_schemas = [AsyncFunctionToolSchema(schema) for schema in function_schemas] + tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas] str_in_output_types = is_any_origin_subclass(output_types, str) async_streamed_str_in_output_types = is_any_origin_subclass( @@ -671,43 +508,12 @@ async def acomplete( tool_schemas=tool_schemas, allow_string_output=allow_string_output ), ).__aenter__() - usage_ref, response = _create_usage_ref_async(response) - - message_start_chunk = await anext(response) - assert message_start_chunk.type == "message_start" # noqa: S101 - first_chunk = await anext(response) - assert first_chunk.type == "content_block_start" # noqa: S101 - response = achain(async_iter([message_start_chunk, first_chunk]), response) - - if ( - first_chunk.type == "content_block_start" - and first_chunk.content_block.type == "text" - ): - async_streamed_str = AsyncStreamedStr( - chunk.delta.text - async for chunk in response - if chunk.type == "content_block_delta" - and chunk.delta.type == "text_delta" - ) - str_content = await avalidate_str_content( - async_streamed_str, - allow_string_output=allow_string_output, - streamed=async_streamed_str_in_output_types, - ) - return AssistantMessage._with_usage(str_content, usage_ref) # type: ignore[return-value] - - if ( - first_chunk.type == "content_block_start" - and first_chunk.content_block.type == "tool_use" - ): - tool_calls = _aparse_streamed_tool_calls(response, tool_schemas) - if is_any_origin_subclass(output_types, AsyncParallelFunctionCall): - content = AsyncParallelFunctionCall(tool_calls) - return AssistantMessage._with_usage(content, usage_ref) # type: ignore[return-value] - # Take only the first tool_call, silently ignore extra chunks - # TODO: Create generator here that raises error or warns if multiple tool_calls - content = await anext(tool_calls) - return AssistantMessage._with_usage(content, usage_ref) # type: ignore[return-value] - - msg = "Could not determine response type" - raise ValueError(msg) + stream = AsyncOutputStream( + response, + function_schemas=function_schemas, + parser=AnthropicStreamParser(), + state=AnthropicStreamState(), + ) + return AssistantMessage._with_usage( + await aparse_stream(stream, output_types), usage_ref=stream.usage_ref + ) From f5070275e8d9246e39a97c25656f13f22af01f38 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:51:46 -0800 Subject: [PATCH 28/40] Delete unused validate_str_content functions --- src/magentic/chat_model/base.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git 
a/src/magentic/chat_model/base.py b/src/magentic/chat_model/base.py index 8d153753..1fa432a8 100644 --- a/src/magentic/chat_model/base.py +++ b/src/magentic/chat_model/base.py @@ -58,32 +58,6 @@ def __init__( self.validation_error = validation_error -# TODO: Delete this function -def validate_str_content( - streamed_str: StreamedStr, *, allow_string_output: bool, streamed: bool -) -> StreamedStr | str: - """Raise error if string output not expected. Otherwise return correct string type.""" - if not allow_string_output: - model_output = streamed_str.truncate(100) - raise StringNotAllowedError(AssistantMessage(model_output)) - if streamed: - return streamed_str - return str(streamed_str) - - -# TODO: Delete this function -async def avalidate_str_content( - async_streamed_str: AsyncStreamedStr, *, allow_string_output: bool, streamed: bool -) -> AsyncStreamedStr | str: - """Async version of `validate_str_content`.""" - if not allow_string_output: - model_output = await async_streamed_str.truncate(100) - raise StringNotAllowedError(AssistantMessage(model_output)) - if streamed: - return async_streamed_str - return await async_streamed_str.to_string() - - # TODO: Make this a stream class with a close method and context management def parse_stream(stream: Iterator[Any], output_types: Iterable[type[R]]) -> R: """Parse and validate the LLM output stream against the allowed output types.""" From 8f8e3d4dde1994ab9c98f7e19392424ab830b38e Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:57:34 -0800 Subject: [PATCH 29/40] Remove redundant TODOs --- src/magentic/chat_model/stream.py | 2 -- src/magentic/typing.py | 1 - 2 files changed, 3 deletions(-) diff --git a/src/magentic/chat_model/stream.py b/src/magentic/chat_model/stream.py index 13791cd1..becd24f5 100644 --- a/src/magentic/chat_model/stream.py +++ b/src/magentic/chat_model/stream.py @@ -122,7 +122,6 @@ def __stream__(self) -> Iterator[StreamedStr | OutputT]: if self._parser.is_content(current_item): stream = chain([current_item], stream) yield StreamedStr(self._streamed_str(stream, current_item_ref)) - # TODO: Make is_tool_calls to handle multiple tools elif self._parser.is_tool_call(current_item): tool_calls_stream = ( tool_call_chunk @@ -223,7 +222,6 @@ async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]: if self._parser.is_content(current_item): stream = achain(async_iter([current_item]), stream) yield AsyncStreamedStr(self._streamed_str(stream, current_item_ref)) - # TODO: Make is_tool_calls to handle multiple tools elif self._parser.is_tool_call(current_item): tool_calls_stream = ( tool_call_chunk diff --git a/src/magentic/typing.py b/src/magentic/typing.py index efec30c7..8df9a23b 100644 --- a/src/magentic/typing.py +++ b/src/magentic/typing.py @@ -52,7 +52,6 @@ def is_instance_origin( return isinstance(obj, cls_or_tuple_origin) -# TODO: Remove once unused def is_any_origin_subclass( types: Iterable[type], cls_or_tuple: TypeT | tuple[TypeT, ...] 
) -> bool: From 647e3624efc16fee875974d2073866a72f9332fc Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Wed, 27 Nov 2024 00:12:20 -0800 Subject: [PATCH 30/40] Fix prompt_chain and unskip tests --- src/magentic/prompt_chain.py | 6 +- .../test_async_prompt_chain.yaml | 93 +++++++++++-------- .../test_prompt_chain/test_prompt_chain.yaml | 68 +++++++------- tests/test_prompt_chain.py | 2 - 4 files changed, 92 insertions(+), 77 deletions(-) diff --git a/src/magentic/prompt_chain.py b/src/magentic/prompt_chain.py index dda2e7fc..f7a4c502 100644 --- a/src/magentic/prompt_chain.py +++ b/src/magentic/prompt_chain.py @@ -37,7 +37,8 @@ def decorator(func: Callable[P, R]) -> Callable[P, R]: async_prompt_function = AsyncPromptFunction[P, Any]( name=func.__name__, parameters=list(func_signature.parameters.values()), - return_type=func_signature.return_annotation, + # TODO: Also allow ParallelFunctionCall. Support this more neatly + return_type=func_signature.return_annotation | FunctionCall, # type: ignore[arg-type] template=template, functions=functions, model=model, @@ -70,7 +71,8 @@ async def awrapper(*args: P.args, **kwargs: P.kwargs) -> Any: prompt_function = PromptFunction[P, R]( name=func.__name__, parameters=list(func_signature.parameters.values()), - return_type=func_signature.return_annotation, + # TODO: Also allow ParallelFunctionCall. Support this more neatly + return_type=func_signature.return_annotation | FunctionCall, # type: ignore[arg-type] template=template, functions=functions, model=model, diff --git a/tests/cassettes/test_prompt_chain/test_async_prompt_chain.yaml b/tests/cassettes/test_prompt_chain/test_async_prompt_chain.yaml index ed49702e..837bded5 100644 --- a/tests/cassettes/test_prompt_chain/test_async_prompt_chain.yaml +++ b/tests/cassettes/test_prompt_chain/test_async_prompt_chain.yaml @@ -42,28 +42,43 @@ interactions: uri: https://api.openai.com/v1/chat/completions response: body: - string: 'data: {"id":"chatcmpl-AWvDJeef7lugHCWP856cRKsxN6wzr","object":"chat.completion.chunk","created":1732409789,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_831e067d82","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_ToRc3P23Nq9mz68k7QXKvG7y","type":"function","function":{"name":"get_current_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} + string: 'data: {"id":"chatcmpl-AY7NvhQY8twSiXbZnXoHygtw9OZvK","object":"chat.completion.chunk","created":1732694903,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_831e067d82","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_OdGRePgCvnJSeHHjYuviwBrB","type":"function","function":{"name":"get_current_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AWvDJeef7lugHCWP856cRKsxN6wzr","object":"chat.completion.chunk","created":1732409789,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_831e067d82","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AY7NvhQY8twSiXbZnXoHygtw9OZvK","object":"chat.completion.chunk","created":1732694903,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_831e067d82","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: 
{"id":"chatcmpl-AWvDJeef7lugHCWP856cRKsxN6wzr","object":"chat.completion.chunk","created":1732409789,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_831e067d82","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AY7NvhQY8twSiXbZnXoHygtw9OZvK","object":"chat.completion.chunk","created":1732694903,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_831e067d82","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AWvDJeef7lugHCWP856cRKsxN6wzr","object":"chat.completion.chunk","created":1732409789,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_831e067d82","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AY7NvhQY8twSiXbZnXoHygtw9OZvK","object":"chat.completion.chunk","created":1732694903,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_831e067d82","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AWvDJeef7lugHCWP856cRKsxN6wzr","object":"chat.completion.chunk","created":1732409789,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_831e067d82","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Boston"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AY7NvhQY8twSiXbZnXoHygtw9OZvK","object":"chat.completion.chunk","created":1732694903,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_831e067d82","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Boston"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AWvDJeef7lugHCWP856cRKsxN6wzr","object":"chat.completion.chunk","created":1732409789,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_831e067d82","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AY7NvhQY8twSiXbZnXoHygtw9OZvK","object":"chat.completion.chunk","created":1732694903,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_831e067d82","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\",\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AWvDJeef7lugHCWP856cRKsxN6wzr","object":"chat.completion.chunk","created":1732409789,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_831e067d82","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null} + data: {"id":"chatcmpl-AY7NvhQY8twSiXbZnXoHygtw9OZvK","object":"chat.completion.chunk","created":1732694903,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_831e067d82","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"unit"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: 
{"id":"chatcmpl-AWvDJeef7lugHCWP856cRKsxN6wzr","object":"chat.completion.chunk","created":1732409789,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_831e067d82","choices":[],"usage":{"prompt_tokens":70,"completion_tokens":15,"total_tokens":85,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} + data: {"id":"chatcmpl-AY7NvhQY8twSiXbZnXoHygtw9OZvK","object":"chat.completion.chunk","created":1732694903,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_831e067d82","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + + data: {"id":"chatcmpl-AY7NvhQY8twSiXbZnXoHygtw9OZvK","object":"chat.completion.chunk","created":1732694903,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_831e067d82","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"fahren"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + + data: {"id":"chatcmpl-AY7NvhQY8twSiXbZnXoHygtw9OZvK","object":"chat.completion.chunk","created":1732694903,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_831e067d82","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"heit"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + + data: {"id":"chatcmpl-AY7NvhQY8twSiXbZnXoHygtw9OZvK","object":"chat.completion.chunk","created":1732694903,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_831e067d82","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + + data: {"id":"chatcmpl-AY7NvhQY8twSiXbZnXoHygtw9OZvK","object":"chat.completion.chunk","created":1732694903,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_831e067d82","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null} + + + data: {"id":"chatcmpl-AY7NvhQY8twSiXbZnXoHygtw9OZvK","object":"chat.completion.chunk","created":1732694903,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_831e067d82","choices":[],"usage":{"prompt_tokens":70,"completion_tokens":20,"total_tokens":90,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} data: [DONE] @@ -74,13 +89,13 @@ interactions: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8e7570824d451742-SJC + - 8e90a14a0c8d22d2-SJC Connection: - keep-alive Content-Type: - text/event-stream; charset=utf-8 Date: - - Sun, 24 Nov 2024 00:56:30 GMT + - Wed, 27 Nov 2024 08:08:23 GMT Server: - cloudflare Transfer-Encoding: @@ -92,7 +107,7 @@ interactions: alt-svc: - h3=":443"; ma=86400 openai-processing-ms: - - '429' + - '369' openai-version: - '2020-10-01' strict-transport-security: @@ -104,13 +119,13 @@ interactions: x-ratelimit-remaining-requests: - '499' x-ratelimit-remaining-tokens: - - '29961' + - '29974' x-ratelimit-reset-requests: - 120ms x-ratelimit-reset-tokens: - - 77ms + - 52ms x-request-id: - - req_82bb4d277af45ba184f65ba214a706e8 + - req_29acff04284aaa4e4342f2af54d4a4b2 status: code: 200 message: OK @@ -118,8 +133,8 @@ interactions: body: '{"messages": [{"role": "user", "content": "What''s the weather like in Boston?"}, {"role": "assistant", "content": null, "tool_calls": [{"id": "000000000", "type": "function", "function": {"name": "get_current_weather", "arguments": - 
"{\"location\":\"Boston\"}"}}]}, {"role": "tool", "tool_call_id": "000000000", - "content": "{\"location\":\"Boston\",\"temperature\":\"72\",\"unit\":\"fahrenheit\",\"forecast\":[\"sunny\",\"windy\"]}"}], + "{\"location\":\"Boston\",\"unit\":\"fahrenheit\"}"}}]}, {"role": "tool", "tool_call_id": + "000000000", "content": "{\"location\":\"Boston\",\"temperature\":\"72\",\"unit\":\"fahrenheit\",\"forecast\":[\"sunny\",\"windy\"]}"}], "model": "gpt-4o", "parallel_tool_calls": false, "stream": true, "stream_options": {"include_usage": true}, "tools": [{"type": "function", "function": {"name": "get_current_weather", "parameters": {"properties": {"location": {"title": "Location"}, @@ -133,7 +148,7 @@ interactions: connection: - keep-alive content-length: - - '845' + - '869' content-type: - application/json host: @@ -160,46 +175,46 @@ interactions: uri: https://api.openai.com/v1/chat/completions response: body: - string: "data: {\"id\":\"chatcmpl-AWvDKT3iXncr2YaxeGAtmHxL7mBvO\",\"object\":\"chat.completion.chunk\",\"created\":1732409790,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\",\"refusal\":null},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDKT3iXncr2YaxeGAtmHxL7mBvO\",\"object\":\"chat.completion.chunk\",\"created\":1732409790,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"The\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDKT3iXncr2YaxeGAtmHxL7mBvO\",\"object\":\"chat.completion.chunk\",\"created\":1732409790,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + string: "data: {\"id\":\"chatcmpl-AY7Nw0Ld9DsdOxrbVxns7DBN8eA9a\",\"object\":\"chat.completion.chunk\",\"created\":1732694904,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\",\"refusal\":null},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"chatcmpl-AY7Nw0Ld9DsdOxrbVxns7DBN8eA9a\",\"object\":\"chat.completion.chunk\",\"created\":1732694904,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"The\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"chatcmpl-AY7Nw0Ld9DsdOxrbVxns7DBN8eA9a\",\"object\":\"chat.completion.chunk\",\"created\":1732694904,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" current\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDKT3iXncr2YaxeGAtmHxL7mBvO\",\"object\":\"chat.completion.chunk\",\"created\":1732409790,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + {\"id\":\"chatcmpl-AY7Nw0Ld9DsdOxrbVxns7DBN8eA9a\",\"object\":\"chat.completion.chunk\",\"created\":1732694904,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" weather\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - 
{\"id\":\"chatcmpl-AWvDKT3iXncr2YaxeGAtmHxL7mBvO\",\"object\":\"chat.completion.chunk\",\"created\":1732409790,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + {\"id\":\"chatcmpl-AY7Nw0Ld9DsdOxrbVxns7DBN8eA9a\",\"object\":\"chat.completion.chunk\",\"created\":1732694904,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" in\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDKT3iXncr2YaxeGAtmHxL7mBvO\",\"object\":\"chat.completion.chunk\",\"created\":1732409790,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + {\"id\":\"chatcmpl-AY7Nw0Ld9DsdOxrbVxns7DBN8eA9a\",\"object\":\"chat.completion.chunk\",\"created\":1732694904,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" Boston\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDKT3iXncr2YaxeGAtmHxL7mBvO\",\"object\":\"chat.completion.chunk\",\"created\":1732409790,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + {\"id\":\"chatcmpl-AY7Nw0Ld9DsdOxrbVxns7DBN8eA9a\",\"object\":\"chat.completion.chunk\",\"created\":1732694904,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" is\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDKT3iXncr2YaxeGAtmHxL7mBvO\",\"object\":\"chat.completion.chunk\",\"created\":1732409790,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" - \"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"chatcmpl-AWvDKT3iXncr2YaxeGAtmHxL7mBvO\",\"object\":\"chat.completion.chunk\",\"created\":1732409790,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"72\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDKT3iXncr2YaxeGAtmHxL7mBvO\",\"object\":\"chat.completion.chunk\",\"created\":1732409790,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"\xB0F\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDKT3iXncr2YaxeGAtmHxL7mBvO\",\"object\":\"chat.completion.chunk\",\"created\":1732409790,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + {\"id\":\"chatcmpl-AY7Nw0Ld9DsdOxrbVxns7DBN8eA9a\",\"object\":\"chat.completion.chunk\",\"created\":1732694904,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + \"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"chatcmpl-AY7Nw0Ld9DsdOxrbVxns7DBN8eA9a\",\"object\":\"chat.completion.chunk\",\"created\":1732694904,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"72\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + 
{\"id\":\"chatcmpl-AY7Nw0Ld9DsdOxrbVxns7DBN8eA9a\",\"object\":\"chat.completion.chunk\",\"created\":1732694904,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"\xB0F\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"chatcmpl-AY7Nw0Ld9DsdOxrbVxns7DBN8eA9a\",\"object\":\"chat.completion.chunk\",\"created\":1732694904,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" with\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDKT3iXncr2YaxeGAtmHxL7mBvO\",\"object\":\"chat.completion.chunk\",\"created\":1732409790,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + {\"id\":\"chatcmpl-AY7Nw0Ld9DsdOxrbVxns7DBN8eA9a\",\"object\":\"chat.completion.chunk\",\"created\":1732694904,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" sunny\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDKT3iXncr2YaxeGAtmHxL7mBvO\",\"object\":\"chat.completion.chunk\",\"created\":1732409790,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + {\"id\":\"chatcmpl-AY7Nw0Ld9DsdOxrbVxns7DBN8eA9a\",\"object\":\"chat.completion.chunk\",\"created\":1732694904,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" and\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDKT3iXncr2YaxeGAtmHxL7mBvO\",\"object\":\"chat.completion.chunk\",\"created\":1732409790,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + {\"id\":\"chatcmpl-AY7Nw0Ld9DsdOxrbVxns7DBN8eA9a\",\"object\":\"chat.completion.chunk\",\"created\":1732694904,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" windy\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDKT3iXncr2YaxeGAtmHxL7mBvO\",\"object\":\"chat.completion.chunk\",\"created\":1732409790,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + {\"id\":\"chatcmpl-AY7Nw0Ld9DsdOxrbVxns7DBN8eA9a\",\"object\":\"chat.completion.chunk\",\"created\":1732694904,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" conditions\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDKT3iXncr2YaxeGAtmHxL7mBvO\",\"object\":\"chat.completion.chunk\",\"created\":1732409790,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\".\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDKT3iXncr2YaxeGAtmHxL7mBvO\",\"object\":\"chat.completion.chunk\",\"created\":1732409790,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{},\"logprobs\":null,\"finish_reason\":\"stop\"}],\"usage\":null}\n\ndata: - 
{\"id\":\"chatcmpl-AWvDKT3iXncr2YaxeGAtmHxL7mBvO\",\"object\":\"chat.completion.chunk\",\"created\":1732409790,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[],\"usage\":{\"prompt_tokens\":117,\"completion_tokens\":16,\"total_tokens\":133,\"prompt_tokens_details\":{\"cached_tokens\":0,\"audio_tokens\":0},\"completion_tokens_details\":{\"reasoning_tokens\":0,\"audio_tokens\":0,\"accepted_prediction_tokens\":0,\"rejected_prediction_tokens\":0}}}\n\ndata: + {\"id\":\"chatcmpl-AY7Nw0Ld9DsdOxrbVxns7DBN8eA9a\",\"object\":\"chat.completion.chunk\",\"created\":1732694904,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\".\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"chatcmpl-AY7Nw0Ld9DsdOxrbVxns7DBN8eA9a\",\"object\":\"chat.completion.chunk\",\"created\":1732694904,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{},\"logprobs\":null,\"finish_reason\":\"stop\"}],\"usage\":null}\n\ndata: + {\"id\":\"chatcmpl-AY7Nw0Ld9DsdOxrbVxns7DBN8eA9a\",\"object\":\"chat.completion.chunk\",\"created\":1732694904,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[],\"usage\":{\"prompt_tokens\":122,\"completion_tokens\":16,\"total_tokens\":138,\"prompt_tokens_details\":{\"cached_tokens\":0,\"audio_tokens\":0},\"completion_tokens_details\":{\"reasoning_tokens\":0,\"audio_tokens\":0,\"accepted_prediction_tokens\":0,\"rejected_prediction_tokens\":0}}}\n\ndata: [DONE]\n\n" headers: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8e757086cd80d00d-SJC + - 8e90a14fcc6b67c7-SJC Connection: - keep-alive Content-Type: - text/event-stream; charset=utf-8 Date: - - Sun, 24 Nov 2024 00:56:30 GMT + - Wed, 27 Nov 2024 08:08:24 GMT Server: - cloudflare Transfer-Encoding: @@ -211,7 +226,7 @@ interactions: alt-svc: - h3=":443"; ma=86400 openai-processing-ms: - - '217' + - '255' openai-version: - '2020-10-01' strict-transport-security: @@ -229,7 +244,7 @@ interactions: x-ratelimit-reset-tokens: - 100ms x-request-id: - - req_b48ce435e67d7bbfadffa112e5647e88 + - req_700e5a9d06cbc5721c78aa22d0fbce6d status: code: 200 message: OK diff --git a/tests/cassettes/test_prompt_chain/test_prompt_chain.yaml b/tests/cassettes/test_prompt_chain/test_prompt_chain.yaml index e8cc72c7..128a2f8d 100644 --- a/tests/cassettes/test_prompt_chain/test_prompt_chain.yaml +++ b/tests/cassettes/test_prompt_chain/test_prompt_chain.yaml @@ -42,28 +42,28 @@ interactions: uri: https://api.openai.com/v1/chat/completions response: body: - string: 'data: {"id":"chatcmpl-AWvDIAUsQhMsrZNf6Y80Ycv2FfndY","object":"chat.completion.chunk","created":1732409788,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_hkxs9H7xFCzJ4aZkX3a1cCri","type":"function","function":{"name":"get_current_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} + string: 'data: {"id":"chatcmpl-AY7NtLye868hzjFznwHUkCMpHwVn6","object":"chat.completion.chunk","created":1732694901,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_P3DrUvqprybs0ewD94G0ate8","type":"function","function":{"name":"get_current_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} - 
data: {"id":"chatcmpl-AWvDIAUsQhMsrZNf6Y80Ycv2FfndY","object":"chat.completion.chunk","created":1732409788,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AY7NtLye868hzjFznwHUkCMpHwVn6","object":"chat.completion.chunk","created":1732694901,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AWvDIAUsQhMsrZNf6Y80Ycv2FfndY","object":"chat.completion.chunk","created":1732409788,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AY7NtLye868hzjFznwHUkCMpHwVn6","object":"chat.completion.chunk","created":1732694901,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AWvDIAUsQhMsrZNf6Y80Ycv2FfndY","object":"chat.completion.chunk","created":1732409788,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AY7NtLye868hzjFznwHUkCMpHwVn6","object":"chat.completion.chunk","created":1732694901,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AWvDIAUsQhMsrZNf6Y80Ycv2FfndY","object":"chat.completion.chunk","created":1732409788,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Boston"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AY7NtLye868hzjFznwHUkCMpHwVn6","object":"chat.completion.chunk","created":1732694901,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Boston"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AWvDIAUsQhMsrZNf6Y80Ycv2FfndY","object":"chat.completion.chunk","created":1732409788,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + data: {"id":"chatcmpl-AY7NtLye868hzjFznwHUkCMpHwVn6","object":"chat.completion.chunk","created":1732694901,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - data: {"id":"chatcmpl-AWvDIAUsQhMsrZNf6Y80Ycv2FfndY","object":"chat.completion.chunk","created":1732409788,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null} + data: 
{"id":"chatcmpl-AY7NtLye868hzjFznwHUkCMpHwVn6","object":"chat.completion.chunk","created":1732694901,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null} - data: {"id":"chatcmpl-AWvDIAUsQhMsrZNf6Y80Ycv2FfndY","object":"chat.completion.chunk","created":1732409788,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[],"usage":{"prompt_tokens":70,"completion_tokens":15,"total_tokens":85,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} + data: {"id":"chatcmpl-AY7NtLye868hzjFznwHUkCMpHwVn6","object":"chat.completion.chunk","created":1732694901,"model":"gpt-4o-2024-08-06","system_fingerprint":"fp_7f6be3efb0","choices":[],"usage":{"prompt_tokens":70,"completion_tokens":15,"total_tokens":85,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} data: [DONE] @@ -74,13 +74,13 @@ interactions: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8e75707a1cf29e6a-SJC + - 8e90a13d0e5d17e6-SJC Connection: - keep-alive Content-Type: - text/event-stream; charset=utf-8 Date: - - Sun, 24 Nov 2024 00:56:28 GMT + - Wed, 27 Nov 2024 08:08:22 GMT Server: - cloudflare Transfer-Encoding: @@ -92,7 +92,7 @@ interactions: alt-svc: - h3=":443"; ma=86400 openai-processing-ms: - - '315' + - '681' openai-version: - '2020-10-01' strict-transport-security: @@ -110,7 +110,7 @@ interactions: x-ratelimit-reset-tokens: - 52ms x-request-id: - - req_d2ba77cde731e6a09799f02ea452a5d3 + - req_0d1d0bcbe79f8af2fa00aab53e512b43 status: code: 200 message: OK @@ -160,46 +160,46 @@ interactions: uri: https://api.openai.com/v1/chat/completions response: body: - string: "data: {\"id\":\"chatcmpl-AWvDJw8GmT0OyqR7pIJ62gwE1yijz\",\"object\":\"chat.completion.chunk\",\"created\":1732409789,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\",\"refusal\":null},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDJw8GmT0OyqR7pIJ62gwE1yijz\",\"object\":\"chat.completion.chunk\",\"created\":1732409789,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"The\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDJw8GmT0OyqR7pIJ62gwE1yijz\",\"object\":\"chat.completion.chunk\",\"created\":1732409789,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + string: "data: {\"id\":\"chatcmpl-AY7NutD05eI0kpgHRMnrDqTtYnvMW\",\"object\":\"chat.completion.chunk\",\"created\":1732694902,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_a7d06e42a7\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\",\"refusal\":null},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"chatcmpl-AY7NutD05eI0kpgHRMnrDqTtYnvMW\",\"object\":\"chat.completion.chunk\",\"created\":1732694902,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_a7d06e42a7\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"The\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + 
{\"id\":\"chatcmpl-AY7NutD05eI0kpgHRMnrDqTtYnvMW\",\"object\":\"chat.completion.chunk\",\"created\":1732694902,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_a7d06e42a7\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" current\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDJw8GmT0OyqR7pIJ62gwE1yijz\",\"object\":\"chat.completion.chunk\",\"created\":1732409789,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + {\"id\":\"chatcmpl-AY7NutD05eI0kpgHRMnrDqTtYnvMW\",\"object\":\"chat.completion.chunk\",\"created\":1732694902,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_a7d06e42a7\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" weather\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDJw8GmT0OyqR7pIJ62gwE1yijz\",\"object\":\"chat.completion.chunk\",\"created\":1732409789,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + {\"id\":\"chatcmpl-AY7NutD05eI0kpgHRMnrDqTtYnvMW\",\"object\":\"chat.completion.chunk\",\"created\":1732694902,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_a7d06e42a7\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" in\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDJw8GmT0OyqR7pIJ62gwE1yijz\",\"object\":\"chat.completion.chunk\",\"created\":1732409789,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + {\"id\":\"chatcmpl-AY7NutD05eI0kpgHRMnrDqTtYnvMW\",\"object\":\"chat.completion.chunk\",\"created\":1732694902,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_a7d06e42a7\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" Boston\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDJw8GmT0OyqR7pIJ62gwE1yijz\",\"object\":\"chat.completion.chunk\",\"created\":1732409789,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + {\"id\":\"chatcmpl-AY7NutD05eI0kpgHRMnrDqTtYnvMW\",\"object\":\"chat.completion.chunk\",\"created\":1732694902,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_a7d06e42a7\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" is\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDJw8GmT0OyqR7pIJ62gwE1yijz\",\"object\":\"chat.completion.chunk\",\"created\":1732409789,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" - \"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"chatcmpl-AWvDJw8GmT0OyqR7pIJ62gwE1yijz\",\"object\":\"chat.completion.chunk\",\"created\":1732409789,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"72\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDJw8GmT0OyqR7pIJ62gwE1yijz\",\"object\":\"chat.completion.chunk\",\"created\":1732409789,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"\xB0F\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - 
{\"id\":\"chatcmpl-AWvDJw8GmT0OyqR7pIJ62gwE1yijz\",\"object\":\"chat.completion.chunk\",\"created\":1732409789,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + {\"id\":\"chatcmpl-AY7NutD05eI0kpgHRMnrDqTtYnvMW\",\"object\":\"chat.completion.chunk\",\"created\":1732694902,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_a7d06e42a7\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + \"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"chatcmpl-AY7NutD05eI0kpgHRMnrDqTtYnvMW\",\"object\":\"chat.completion.chunk\",\"created\":1732694902,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_a7d06e42a7\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"72\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"chatcmpl-AY7NutD05eI0kpgHRMnrDqTtYnvMW\",\"object\":\"chat.completion.chunk\",\"created\":1732694902,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_a7d06e42a7\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"\xB0F\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"chatcmpl-AY7NutD05eI0kpgHRMnrDqTtYnvMW\",\"object\":\"chat.completion.chunk\",\"created\":1732694902,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_a7d06e42a7\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" with\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDJw8GmT0OyqR7pIJ62gwE1yijz\",\"object\":\"chat.completion.chunk\",\"created\":1732409789,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + {\"id\":\"chatcmpl-AY7NutD05eI0kpgHRMnrDqTtYnvMW\",\"object\":\"chat.completion.chunk\",\"created\":1732694902,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_a7d06e42a7\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" sunny\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDJw8GmT0OyqR7pIJ62gwE1yijz\",\"object\":\"chat.completion.chunk\",\"created\":1732409789,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + {\"id\":\"chatcmpl-AY7NutD05eI0kpgHRMnrDqTtYnvMW\",\"object\":\"chat.completion.chunk\",\"created\":1732694902,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_a7d06e42a7\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" and\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDJw8GmT0OyqR7pIJ62gwE1yijz\",\"object\":\"chat.completion.chunk\",\"created\":1732409789,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + {\"id\":\"chatcmpl-AY7NutD05eI0kpgHRMnrDqTtYnvMW\",\"object\":\"chat.completion.chunk\",\"created\":1732694902,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_a7d06e42a7\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" windy\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDJw8GmT0OyqR7pIJ62gwE1yijz\",\"object\":\"chat.completion.chunk\",\"created\":1732409789,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + 
{\"id\":\"chatcmpl-AY7NutD05eI0kpgHRMnrDqTtYnvMW\",\"object\":\"chat.completion.chunk\",\"created\":1732694902,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_a7d06e42a7\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" conditions\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDJw8GmT0OyqR7pIJ62gwE1yijz\",\"object\":\"chat.completion.chunk\",\"created\":1732409789,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\".\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDJw8GmT0OyqR7pIJ62gwE1yijz\",\"object\":\"chat.completion.chunk\",\"created\":1732409789,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[{\"index\":0,\"delta\":{},\"logprobs\":null,\"finish_reason\":\"stop\"}],\"usage\":null}\n\ndata: - {\"id\":\"chatcmpl-AWvDJw8GmT0OyqR7pIJ62gwE1yijz\",\"object\":\"chat.completion.chunk\",\"created\":1732409789,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_7f6be3efb0\",\"choices\":[],\"usage\":{\"prompt_tokens\":117,\"completion_tokens\":16,\"total_tokens\":133,\"prompt_tokens_details\":{\"cached_tokens\":0,\"audio_tokens\":0},\"completion_tokens_details\":{\"reasoning_tokens\":0,\"audio_tokens\":0,\"accepted_prediction_tokens\":0,\"rejected_prediction_tokens\":0}}}\n\ndata: + {\"id\":\"chatcmpl-AY7NutD05eI0kpgHRMnrDqTtYnvMW\",\"object\":\"chat.completion.chunk\",\"created\":1732694902,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_a7d06e42a7\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\".\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"chatcmpl-AY7NutD05eI0kpgHRMnrDqTtYnvMW\",\"object\":\"chat.completion.chunk\",\"created\":1732694902,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_a7d06e42a7\",\"choices\":[{\"index\":0,\"delta\":{},\"logprobs\":null,\"finish_reason\":\"stop\"}],\"usage\":null}\n\ndata: + {\"id\":\"chatcmpl-AY7NutD05eI0kpgHRMnrDqTtYnvMW\",\"object\":\"chat.completion.chunk\",\"created\":1732694902,\"model\":\"gpt-4o-2024-08-06\",\"system_fingerprint\":\"fp_a7d06e42a7\",\"choices\":[],\"usage\":{\"prompt_tokens\":117,\"completion_tokens\":16,\"total_tokens\":133,\"prompt_tokens_details\":{\"cached_tokens\":0,\"audio_tokens\":0},\"completion_tokens_details\":{\"reasoning_tokens\":0,\"audio_tokens\":0,\"accepted_prediction_tokens\":0,\"rejected_prediction_tokens\":0}}}\n\ndata: [DONE]\n\n" headers: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8e75707daabacf2e-SJC + - 8e90a1446b19270c-SJC Connection: - keep-alive Content-Type: - text/event-stream; charset=utf-8 Date: - - Sun, 24 Nov 2024 00:56:29 GMT + - Wed, 27 Nov 2024 08:08:22 GMT Server: - cloudflare Transfer-Encoding: @@ -211,7 +211,7 @@ interactions: alt-svc: - h3=":443"; ma=86400 openai-processing-ms: - - '358' + - '421' openai-version: - '2020-10-01' strict-transport-security: @@ -229,7 +229,7 @@ interactions: x-ratelimit-reset-tokens: - 100ms x-request-id: - - req_646939d1b32780ef61665e5b65261787 + - req_039b14c0ae6a88154cd0e379aeae92e0 status: code: 200 message: OK diff --git a/tests/test_prompt_chain.py b/tests/test_prompt_chain.py index 69d85665..63796f41 100644 --- a/tests/test_prompt_chain.py +++ b/tests/test_prompt_chain.py @@ -7,7 +7,6 @@ from magentic.prompt_chain import MaxFunctionCallsError, prompt_chain -@pytest.mark.skip("TODO: Add FunctionCall to output_types internal to prompt_chain") @pytest.mark.openai def 
test_prompt_chain(): def get_current_weather(location, unit="fahrenheit"): @@ -51,7 +50,6 @@ def make_function_call() -> str: ... assert mock_function.call_count == 1 -@pytest.mark.skip("TODO: Add FunctionCall to output_types internal to prompt_chain") @pytest.mark.openai async def test_async_prompt_chain(): async def get_current_weather(location, unit="fahrenheit"): From ac4f27ce14d20b76bff669b0eacfeed8eeec1e18 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Wed, 27 Nov 2024 00:23:19 -0800 Subject: [PATCH 31/40] Only yield tool call args if not falsy --- src/magentic/chat_model/litellm_chat_model.py | 2 +- src/magentic/chat_model/stream.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/magentic/chat_model/litellm_chat_model.py b/src/magentic/chat_model/litellm_chat_model.py index 742be670..1a65b40b 100644 --- a/src/magentic/chat_model/litellm_chat_model.py +++ b/src/magentic/chat_model/litellm_chat_model.py @@ -95,7 +95,7 @@ def update(self, item: ModelResponse) -> None: usage = cast(litellm.Usage, item.usage) # type: ignore[attr-defined] # Ignore usages with 0 tokens if usage and usage.prompt_tokens and usage.completion_tokens: - # assert not self.usage_ref + assert not self.usage_ref # noqa: S101 self.usage_ref.append( Usage( input_tokens=usage.prompt_tokens, diff --git a/src/magentic/chat_model/stream.py b/src/magentic/chat_model/stream.py index becd24f5..fb7d7e1d 100644 --- a/src/magentic/chat_model/stream.py +++ b/src/magentic/chat_model/stream.py @@ -112,7 +112,8 @@ def _tool_call( assert not current_tool_call_ref # noqa: S101 current_tool_call_ref.append(item) return - yield item.args or "" + if item.args: + yield item.args def __stream__(self) -> Iterator[StreamedStr | OutputT]: stream = apply(self._state.update, self._stream) @@ -212,7 +213,8 @@ async def _tool_call( assert not current_tool_call_ref # noqa: S101 current_tool_call_ref.append(item) return - yield item.args or "" + if item.args: + yield item.args async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]: stream = aapply(self._state.update, self._stream) From 56974ef06e097663fa991598b987c7415421249d Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Wed, 27 Nov 2024 00:43:35 -0800 Subject: [PATCH 32/40] Add FunctionCallNotAllowedError, ObjectNotAllowedError --- src/magentic/chat_model/base.py | 54 ++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 14 deletions(-) diff --git a/src/magentic/chat_model/base.py b/src/magentic/chat_model/base.py index 1fa432a8..86c2f867 100644 --- a/src/magentic/chat_model/base.py +++ b/src/magentic/chat_model/base.py @@ -24,17 +24,45 @@ # TODO: Parent class with `output_message` attribute ? class StringNotAllowedError(Exception): - """Raised when a string is returned by the LLM but not expected.""" + """Raised when a string is returned by the LLM but not allowed.""" _MESSAGE = ( - "A string was returned by the LLM but was not an allowed output type." - ' Consider updating the prompt to encourage the LLM to "use the tool".' + "A string was returned by the LLM but is not an allowed output type." + " Consider updating the allowed output types or modifying the prompt." 
" Model output: {model_output!r}" ) - def __init__(self, output_message: Message[Any]): - super().__init__(self._MESSAGE.format(model_output=output_message.content)) - self.output_message = output_message + def __init__(self, model_output: str): + super().__init__(self._MESSAGE.format(model_output=model_output)) + self.output_message = AssistantMessage(model_output) + + +class FunctionCallNotAllowedError(Exception): + """Raised when a FunctionCall is returned by the LLM but not allowed.""" + + _MESSAGE = ( + "A function call was returned by the LLM but is not an allowed output type." + " Consider updating the allowed output types or modifying the prompt." + " FunctionCall: {function_call!r}" + ) + + def __init__(self, function_call: FunctionCall[Any]): + super().__init__(self._MESSAGE.format(function_call=function_call)) + self.output_message = AssistantMessage(function_call) + + +class ObjectNotAllowedError(Exception): + """Raised when a Python object is returned by the LLM but not allowed.""" + + _MESSAGE = ( + "An object was returned by the LLM but is not an allowed output type." + " Consider updating the allowed output types or modifying the prompt." + " Object: {obj!r}" + ) + + def __init__(self, obj: Any): + super().__init__(self._MESSAGE.format(obj=obj)) + self.output_message = AssistantMessage(obj) class ToolSchemaParseError(Exception): @@ -71,18 +99,17 @@ def parse_stream(stream: Iterator[Any], output_types: Iterable[type[R]]) -> R: return cast(R, obj) if str in output_type_origins: return cast(R, str(obj)) - model_output = obj.truncate(100) - raise StringNotAllowedError(AssistantMessage(model_output)) + raise StringNotAllowedError(obj.truncate(100)) if isinstance(obj, FunctionCall): if ParallelFunctionCall in output_type_origins: return cast(R, ParallelFunctionCall(chain([obj], stream))) if FunctionCall in output_type_origins: # TODO: Check that FunctionCall type matches ? 
return cast(R, obj) - raise ValueError("FunctionCall not allowed") + raise FunctionCallNotAllowedError(obj) if isinstance(obj, tuple(output_type_origins)): return obj - raise ValueError(f"Unexpected output type: {type(obj)}") + raise ObjectNotAllowedError(obj) async def aparse_stream( @@ -96,17 +123,16 @@ async def aparse_stream( return cast(R, obj) if str in output_type_origins: return cast(R, await obj.to_string()) - model_output = await obj.truncate(100) - raise StringNotAllowedError(AssistantMessage(model_output)) + raise StringNotAllowedError(await obj.truncate(100)) if isinstance(obj, FunctionCall): if AsyncParallelFunctionCall in output_type_origins: return cast(R, AsyncParallelFunctionCall(achain(async_iter([obj]), stream))) if FunctionCall in output_type_origins: return cast(R, obj) - raise ValueError("FunctionCall not allowed") + raise FunctionCallNotAllowedError(obj) if isinstance(obj, tuple(output_type_origins)): return obj - raise ValueError(f"Unexpected output type: {type(obj)}") + raise ObjectNotAllowedError(obj) class ChatModel(ABC): From c4b73e91097e214fb4256bbeec5ad852300ad627 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Thu, 28 Nov 2024 16:09:38 -0800 Subject: [PATCH 33/40] Add UnknownToolError, raise in OutputStream --- src/magentic/chat_model/base.py | 14 ++++++++++++++ src/magentic/chat_model/function_schema.py | 5 ++--- src/magentic/chat_model/retry_chat_model.py | 4 ++++ src/magentic/chat_model/stream.py | 16 +++++++++++++++- 4 files changed, 35 insertions(+), 4 deletions(-) diff --git a/src/magentic/chat_model/base.py b/src/magentic/chat_model/base.py index 86c2f867..b53e0819 100644 --- a/src/magentic/chat_model/base.py +++ b/src/magentic/chat_model/base.py @@ -65,6 +65,20 @@ def __init__(self, obj: Any): self.output_message = AssistantMessage(obj) +class UnknownToolError(Exception): + """Raised when the LLM returns a tool call for an unknown tool.""" + + _MESSAGE = ( + "The LLM returned a tool call for a tool name that is not recognized." 
+ " Tool name: {tool_name!r}" + ) + + def __init__(self, output_message: Message, tool_call_id: str, tool_name: str): + super().__init__(self._MESSAGE.format(tool_name=tool_name)) + self.output_message = output_message + self.tool_call_id = tool_call_id + + class ToolSchemaParseError(Exception): """Raised when the LLM output could not be parsed by the tool schema.""" diff --git a/src/magentic/chat_model/function_schema.py b/src/magentic/chat_model/function_schema.py index 40e27785..f104b717 100644 --- a/src/magentic/chat_model/function_schema.py +++ b/src/magentic/chat_model/function_schema.py @@ -64,13 +64,12 @@ def dict(self) -> FunctionDefinition: def select_function_schema( function_schemas: Iterable[BaseFunctionSchemaT], name: str -) -> BaseFunctionSchemaT: +) -> BaseFunctionSchemaT | None: """Select the function schema with the given name.""" for schema in function_schemas: if schema.name == name: return schema - # TODO: Catch/raise unknown tool call error here - raise ValueError(f"No function schema found for name {name}") + return None class AsyncFunctionSchema(BaseFunctionSchema[T], Generic[T]): diff --git a/src/magentic/chat_model/retry_chat_model.py b/src/magentic/chat_model/retry_chat_model.py index 8d008c3d..a5c901b6 100644 --- a/src/magentic/chat_model/retry_chat_model.py +++ b/src/magentic/chat_model/retry_chat_model.py @@ -29,10 +29,14 @@ def __init__( self._max_retries = max_retries # TODO: Make this public to allow modifying error handling behavior + # User should be able to add handlers to instance using decorator + # e.g. `@my_retry_chat_model.exception_handler(exc_type)` + # TODO: Add exception base class for those with output_message attribute @singledispatchmethod def _make_retry_messages(self, error: Exception) -> list[Message[Any]]: raise NotImplementedError + # TODO: Catch UnknownToolError here @_make_retry_messages.register def _(self, error: ToolSchemaParseError) -> list[Message[Any]]: return [ diff --git a/src/magentic/chat_model/stream.py b/src/magentic/chat_model/stream.py index fb7d7e1d..1e515917 100644 --- a/src/magentic/chat_model/stream.py +++ b/src/magentic/chat_model/stream.py @@ -6,7 +6,7 @@ from litellm.llms.files_apis.azure import Any from pydantic import ValidationError -from magentic.chat_model.base import ToolSchemaParseError +from magentic.chat_model.base import ToolSchemaParseError, UnknownToolError from magentic.chat_model.function_schema import FunctionSchema, select_function_schema from magentic.chat_model.message import Message, Usage from magentic.streaming import ( @@ -138,6 +138,13 @@ def __stream__(self) -> Iterator[StreamedStr | OutputT]: function_schema = select_function_schema( self._function_schemas, current_tool_call_chunk.name ) + if function_schema is None: + assert current_tool_call_id is not None # noqa: S101 + raise UnknownToolError( + output_message=self._state.current_message_snapshot, + tool_call_id=current_tool_call_id, + tool_name=current_tool_call_chunk.name, + ) try: yield function_schema.parse_args( self._tool_call( @@ -239,6 +246,13 @@ async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]: function_schema = select_function_schema( self._function_schemas, current_tool_call_chunk.name ) + if function_schema is None: + assert current_tool_call_id is not None # noqa: S101 + raise UnknownToolError( + output_message=self._state.current_message_snapshot, + tool_call_id=current_tool_call_id, + tool_name=current_tool_call_chunk.name, + ) try: yield await function_schema.aparse_args( self._tool_call( From 
b964aad1637421d461a4932b019b3783aeed06bc Mon Sep 17 00:00:00 2001
From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com>
Date: Thu, 28 Nov 2024 16:15:24 -0800
Subject: [PATCH 34/40] Remove done todo for unknown tool call

---
 src/magentic/chat_model/stream.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/magentic/chat_model/stream.py b/src/magentic/chat_model/stream.py
index 1e515917..ec1e2632 100644
--- a/src/magentic/chat_model/stream.py
+++ b/src/magentic/chat_model/stream.py
@@ -153,7 +153,6 @@ def __stream__(self) -> Iterator[StreamedStr | OutputT]:
                         current_tool_call_id,
                     )
                 )
-            # TODO: Catch/raise unknown tool call error here
             except ValidationError as e:
                 assert current_tool_call_id is not None  # noqa: S101
                 raise ToolSchemaParseError(
@@ -264,7 +263,6 @@ async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]:
                             current_tool_call_id,
                         )
                     )
-                # TODO: Catch/raise unknown tool call error here
                 except ValidationError as e:
                     assert current_tool_call_id is not None  # noqa: S101
                     raise ToolSchemaParseError(

From 6b408fd9b6863d328bb41f6fc7aa6a827acdc412 Mon Sep 17 00:00:00 2001
From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com>
Date: Thu, 28 Nov 2024 16:20:49 -0800
Subject: [PATCH 35/40] Remove is_content_ended in favor of is_tool_call

---
 src/magentic/chat_model/anthropic_chat_model.py |  3 ---
 src/magentic/chat_model/litellm_chat_model.py   |  3 ---
 src/magentic/chat_model/openai_chat_model.py    |  3 ---
 src/magentic/chat_model/stream.py               | 10 +++-------
 4 files changed, 3 insertions(+), 16 deletions(-)

diff --git a/src/magentic/chat_model/anthropic_chat_model.py b/src/magentic/chat_model/anthropic_chat_model.py
index 6448ee7a..c99367d6 100644
--- a/src/magentic/chat_model/anthropic_chat_model.py
+++ b/src/magentic/chat_model/anthropic_chat_model.py
@@ -235,9 +235,6 @@ class AnthropicStreamParser(StreamParser[MessageStreamEvent]):
     def is_content(self, item: MessageStreamEvent) -> bool:
         return item.type == "content_block_delta"
 
-    def is_content_ended(self, item: MessageStreamEvent) -> bool:
-        return self.is_tool_call(item)
-
     def get_content(self, item: MessageStreamEvent) -> str | None:
         if item.type == "text":
             return item.text
diff --git a/src/magentic/chat_model/litellm_chat_model.py b/src/magentic/chat_model/litellm_chat_model.py
index 1a65b40b..c13dec54 100644
--- a/src/magentic/chat_model/litellm_chat_model.py
+++ b/src/magentic/chat_model/litellm_chat_model.py
@@ -53,9 +53,6 @@ def is_content(self, item: ModelResponse) -> bool:
         assert isinstance(item.choices[0], StreamingChoices)  # noqa: S101
         return bool(item.choices[0].delta.content)
 
-    def is_content_ended(self, item: ModelResponse) -> bool:
-        return self.is_tool_call(item)
-
     def get_content(self, item: ModelResponse) -> str | None:
         assert isinstance(item.choices[0], StreamingChoices)  # noqa: S101
         return item.choices[0].delta.content
diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py
index b6788a77..acbdb6a8 100644
--- a/src/magentic/chat_model/openai_chat_model.py
+++ b/src/magentic/chat_model/openai_chat_model.py
@@ -239,9 +239,6 @@ class OpenaiStreamParser(StreamParser[ChatCompletionChunk]):
     def is_content(self, item: ChatCompletionChunk) -> bool:
         return bool(item.choices and item.choices[0].delta.content)
 
-    def is_content_ended(self, item: ChatCompletionChunk) -> bool:
-        return self.is_tool_call(item)
-
     def get_content(self, item: ChatCompletionChunk) -> str | None:
         if item.choices and item.choices[0].delta.content:
             return 
item.choices[0].delta.content diff --git a/src/magentic/chat_model/stream.py b/src/magentic/chat_model/stream.py index ec1e2632..4b67f899 100644 --- a/src/magentic/chat_model/stream.py +++ b/src/magentic/chat_model/stream.py @@ -1,9 +1,8 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Iterable, Iterator from itertools import chain -from typing import Generic, NamedTuple, TypeVar +from typing import Any, Generic, NamedTuple, TypeVar -from litellm.llms.files_apis.azure import Any from pydantic import ValidationError from magentic.chat_model.base import ToolSchemaParseError, UnknownToolError @@ -32,9 +31,6 @@ class StreamParser(ABC, Generic[ItemT]): @abstractmethod def is_content(self, item: ItemT) -> bool: ... - @abstractmethod - def is_content_ended(self, item: ItemT) -> bool: ... - @abstractmethod def get_content(self, item: ItemT) -> str | None: ... @@ -92,7 +88,7 @@ def _streamed_str( for item in stream: if content := self._parser.get_content(item): yield content - if self._parser.is_content_ended(item): + if self._parser.is_tool_call(item): # TODO: Check if output types allow for early return and raise if not assert not current_item_ref # noqa: S101 current_item_ref.append(item) @@ -201,7 +197,7 @@ async def _streamed_str( async for item in stream: if content := self._parser.get_content(item): yield content - if self._parser.is_content_ended(item): + if self._parser.is_tool_call(item): # TODO: Check if output types allow for early return assert not current_item_ref # noqa: S101 current_item_ref.append(item) From cb2e5657612832f5ad2d5802e8426be5eccda761 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Thu, 28 Nov 2024 16:22:15 -0800 Subject: [PATCH 36/40] Add typecheck to make all --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index b98bf4d7..11370df5 100644 --- a/Makefile +++ b/Makefile @@ -56,4 +56,4 @@ docs-serve: # Build and serve the documentation uv run mkdocs serve .PHONY: all -all: format lint test +all: format lint typecheck test From d6c7bfa24a07b348132e595f8d59c4e770dca3ba Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Thu, 28 Nov 2024 17:06:25 -0800 Subject: [PATCH 37/40] Add get_function_schemas. 
Fix some mypy errors --- .../chat_model/anthropic_chat_model.py | 31 +++--------- src/magentic/chat_model/base.py | 6 +-- src/magentic/chat_model/function_schema.py | 50 ++++++++++++++++++- src/magentic/chat_model/litellm_chat_model.py | 22 +++----- src/magentic/chat_model/openai_chat_model.py | 31 +++--------- src/magentic/chat_model/stream.py | 14 +++--- 6 files changed, 76 insertions(+), 78 deletions(-) diff --git a/src/magentic/chat_model/anthropic_chat_model.py b/src/magentic/chat_model/anthropic_chat_model.py index c99367d6..8cbd4794 100644 --- a/src/magentic/chat_model/anthropic_chat_model.py +++ b/src/magentic/chat_model/anthropic_chat_model.py @@ -23,8 +23,9 @@ BaseFunctionSchema, FunctionCallFunctionSchema, FunctionSchema, - async_function_schema_for_type, function_schema_for_type, + get_async_function_schemas, + get_function_schemas, ) from magentic.chat_model.message import ( AssistantMessage, @@ -43,7 +44,6 @@ StreamState, ) from magentic.function_call import ( - AsyncParallelFunctionCall, FunctionCall, ParallelFunctionCall, _create_unique_id, @@ -52,7 +52,7 @@ AsyncStreamedStr, StreamedStr, ) -from magentic.typing import is_any_origin_subclass, is_origin_subclass +from magentic.typing import is_any_origin_subclass from magentic.vision import UserImageMessage try: @@ -280,7 +280,7 @@ def update(self, item: MessageStreamEvent) -> None: ) @property - def current_message_snapshot(self) -> Message: + def current_message_snapshot(self) -> Message[Any]: assert self._current_message_snapshot is not None # noqa: S101 # TODO: Possible to return AssistantMessage here? return _RawMessage(self._current_message_snapshot.model_dump()) @@ -306,16 +306,6 @@ def _if_given(value: T | None) -> T | anthropic.NotGiven: R = TypeVar("R") -STR_OR_FUNCTIONCALL_TYPE = ( - str, - StreamedStr, - AsyncStreamedStr, - FunctionCall, - ParallelFunctionCall, - AsyncParallelFunctionCall, -) - - class AnthropicChatModel(ChatModel): """An LLM chat model that uses the `anthropic` python package.""" @@ -404,12 +394,7 @@ def complete( if output_types is None: output_types = [] if functions else cast(list[type[R]], [str]) - # TODO: Check that Function calls types match functions - function_schemas = [FunctionCallFunctionSchema(f) for f in functions or []] + [ - function_schema_for_type(type_) - for type_ in output_types - if not is_origin_subclass(type_, STR_OR_FUNCTIONCALL_TYPE) - ] + function_schemas = get_function_schemas(functions, output_types) tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas] str_in_output_types = is_any_origin_subclass(output_types, str) @@ -474,11 +459,7 @@ async def acomplete( if output_types is None: output_types = [] if functions else cast(list[type[R]], [str]) - function_schemas = [FunctionCallFunctionSchema(f) for f in functions or []] + [ - async_function_schema_for_type(type_) - for type_ in output_types - if not is_origin_subclass(type_, STR_OR_FUNCTIONCALL_TYPE) - ] + function_schemas = get_async_function_schemas(functions, output_types) tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas] str_in_output_types = is_any_origin_subclass(output_types, str) diff --git a/src/magentic/chat_model/base.py b/src/magentic/chat_model/base.py index b53e0819..2575097b 100644 --- a/src/magentic/chat_model/base.py +++ b/src/magentic/chat_model/base.py @@ -73,7 +73,7 @@ class UnknownToolError(Exception): " Tool name: {tool_name!r}" ) - def __init__(self, output_message: Message, tool_call_id: str, tool_name: str): + def __init__(self, 
output_message: Message[Any], tool_call_id: str, tool_name: str): super().__init__(self._MESSAGE.format(tool_name=tool_name)) self.output_message = output_message self.tool_call_id = tool_call_id @@ -122,7 +122,7 @@ def parse_stream(stream: Iterator[Any], output_types: Iterable[type[R]]) -> R: return cast(R, obj) raise FunctionCallNotAllowedError(obj) if isinstance(obj, tuple(output_type_origins)): - return obj + return cast(R, obj) raise ObjectNotAllowedError(obj) @@ -145,7 +145,7 @@ async def aparse_stream( return cast(R, obj) raise FunctionCallNotAllowedError(obj) if isinstance(obj, tuple(output_type_origins)): - return obj + return cast(R, obj) raise ObjectNotAllowedError(obj) diff --git a/src/magentic/chat_model/function_schema.py b/src/magentic/chat_model/function_schema.py index f104b717..2d84c114 100644 --- a/src/magentic/chat_model/function_schema.py +++ b/src/magentic/chat_model/function_schema.py @@ -9,12 +9,18 @@ from pydantic import BaseModel, TypeAdapter, create_model from magentic._pydantic import ConfigDict, get_pydantic_config, json_schema -from magentic.function_call import FunctionCall +from magentic.function_call import ( + AsyncParallelFunctionCall, + FunctionCall, + ParallelFunctionCall, +) from magentic.streaming import ( + AsyncStreamedStr, + StreamedStr, aiter_streamed_json_array, iter_streamed_json_array, ) -from magentic.typing import is_origin_abstract, name_type +from magentic.typing import is_origin_abstract, is_origin_subclass, name_type T = TypeVar("T") @@ -444,3 +450,43 @@ def serialize_args(self, value: FunctionCall[T]) -> str: return self._model.model_construct(**value.arguments).model_dump_json( exclude_unset=True ) + + +R = TypeVar("R") + +_NON_FUNCTION_CALL_TYPES = ( + str, + StreamedStr, + AsyncStreamedStr, + FunctionCall, + ParallelFunctionCall, + AsyncParallelFunctionCall, +) + + +def get_function_schemas( + functions: Iterable[Callable[..., R]] | None, + output_types: Iterable[type[T]], +) -> Iterable[FunctionSchema[FunctionCall[R] | T]]: + return [ + *(FunctionCallFunctionSchema(f) for f in functions or []), # type: ignore[list-item] + *( + function_schema_for_type(type_) + for type_ in output_types + if not is_origin_subclass(type_, _NON_FUNCTION_CALL_TYPES) # type: ignore[list-item] + ), + ] + + +def get_async_function_schemas( + functions: Iterable[Callable[..., R]] | None, + output_types: Iterable[type[T]], +) -> Iterable[FunctionSchema[FunctionCall[R] | T]]: + return [ + *(FunctionCallFunctionSchema(f) for f in functions or []), # type: ignore[list-item] + *( + async_function_schema_for_type(type_) + for type_ in output_types + if not is_origin_subclass(type_, _NON_FUNCTION_CALL_TYPES) # type: ignore[list-item] + ), + ] diff --git a/src/magentic/chat_model/litellm_chat_model.py b/src/magentic/chat_model/litellm_chat_model.py index c13dec54..79b66963 100644 --- a/src/magentic/chat_model/litellm_chat_model.py +++ b/src/magentic/chat_model/litellm_chat_model.py @@ -12,9 +12,8 @@ parse_stream, ) from magentic.chat_model.function_schema import ( - FunctionCallFunctionSchema, - async_function_schema_for_type, - function_schema_for_type, + get_async_function_schemas, + get_function_schemas, ) from magentic.chat_model.message import ( AssistantMessage, @@ -23,7 +22,6 @@ _RawMessage, ) from magentic.chat_model.openai_chat_model import ( - STR_OR_FUNCTIONCALL_TYPE, BaseFunctionToolSchema, message_to_openai_message, ) @@ -38,7 +36,7 @@ AsyncStreamedStr, StreamedStr, ) -from magentic.typing import is_any_origin_subclass, is_origin_subclass +from 
magentic.typing import is_any_origin_subclass try: import litellm @@ -101,7 +99,7 @@ def update(self, item: ModelResponse) -> None: ) @property - def current_message_snapshot(self) -> Message: + def current_message_snapshot(self) -> Message[Any]: snapshot = self._chat_completion_stream_state.current_completion_snapshot message = snapshot.choices[0].message # Fix incorrectly concatenated role @@ -202,11 +200,7 @@ def complete( if output_types is None: output_types = cast(Iterable[type[R]], [] if functions else [str]) - function_schemas = [FunctionCallFunctionSchema(f) for f in functions or []] + [ - function_schema_for_type(type_) - for type_ in output_types - if not is_origin_subclass(type_, STR_OR_FUNCTIONCALL_TYPE) - ] + function_schemas = get_function_schemas(functions, output_types) tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas] str_in_output_types = is_any_origin_subclass(output_types, str) @@ -270,11 +264,7 @@ async def acomplete( if output_types is None: output_types = cast(Iterable[type[R]], [] if functions else [str]) - function_schemas = [FunctionCallFunctionSchema(f) for f in functions or []] + [ - async_function_schema_for_type(type_) - for type_ in output_types - if not is_origin_subclass(type_, STR_OR_FUNCTIONCALL_TYPE) - ] + function_schemas = get_async_function_schemas(functions, output_types) tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas] str_in_output_types = is_any_origin_subclass(output_types, str) diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index acbdb6a8..36d23ca4 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -30,8 +30,9 @@ BaseFunctionSchema, FunctionCallFunctionSchema, FunctionSchema, - async_function_schema_for_type, function_schema_for_type, + get_async_function_schemas, + get_function_schemas, ) from magentic.chat_model.message import ( AssistantMessage, @@ -59,7 +60,7 @@ AsyncStreamedStr, StreamedStr, ) -from magentic.typing import is_any_origin_subclass, is_origin_subclass +from magentic.typing import is_any_origin_subclass from magentic.vision import UserImageMessage @@ -82,7 +83,7 @@ def _(message: _RawMessage[Any]) -> ChatCompletionMessageParam: assert isinstance(message.content, dict) # noqa: S101 assert "role" in message.content # noqa: S101 assert "content" in message.content # noqa: S101 - return message.content # type: ignore[no-any-return] + return cast(ChatCompletionMessageParam, message.content) @message_to_openai_message.register @@ -300,7 +301,7 @@ def update(self, item: ChatCompletionChunk) -> None: ) @property - def current_message_snapshot(self) -> Message: + def current_message_snapshot(self) -> Message[Any]: snapshot = self._chat_completion_stream_state.current_completion_snapshot message = snapshot.choices[0].message # TODO: Possible to return AssistantMessage here? 
@@ -311,15 +312,6 @@ def _if_given(value: T | None) -> T | openai.NotGiven: return value if value is not None else openai.NOT_GIVEN -STR_OR_FUNCTIONCALL_TYPE = ( - str, - StreamedStr, - AsyncStreamedStr, - FunctionCall, - ParallelFunctionCall, - AsyncParallelFunctionCall, -) - R = TypeVar("R") @@ -453,12 +445,7 @@ def complete( if output_types is None: output_types = cast(Iterable[type[R]], [] if functions else [str]) - # TODO: Check that Function calls types match functions - function_schemas = [FunctionCallFunctionSchema(f) for f in functions or []] + [ - function_schema_for_type(type_) - for type_ in output_types - if not is_origin_subclass(type_, STR_OR_FUNCTIONCALL_TYPE) - ] + function_schemas = get_function_schemas(functions, output_types) tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas] # TODO: pass output_types to _get_tool_choice directly and remove these @@ -527,11 +514,7 @@ async def acomplete( if output_types is None: output_types = [] if functions else cast(list[type[R]], [str]) - function_schemas = [FunctionCallFunctionSchema(f) for f in functions or []] + [ - async_function_schema_for_type(type_) - for type_ in output_types - if not is_origin_subclass(type_, STR_OR_FUNCTIONCALL_TYPE) - ] + function_schemas = get_async_function_schemas(functions, output_types) tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas] str_in_output_types = is_any_origin_subclass(output_types, str) diff --git a/src/magentic/chat_model/stream.py b/src/magentic/chat_model/stream.py index 4b67f899..1ca3c79b 100644 --- a/src/magentic/chat_model/stream.py +++ b/src/magentic/chat_model/stream.py @@ -6,7 +6,11 @@ from pydantic import ValidationError from magentic.chat_model.base import ToolSchemaParseError, UnknownToolError -from magentic.chat_model.function_schema import FunctionSchema, select_function_schema +from magentic.chat_model.function_schema import ( + AsyncFunctionSchema, + FunctionSchema, + select_function_schema, +) from magentic.chat_model.message import Message, Usage from magentic.streaming import ( AsyncStreamedStr, @@ -163,9 +167,6 @@ def __stream__(self) -> Iterator[StreamedStr | OutputT]: def usage_ref(self) -> list[Usage]: return self._state.usage_ref - def close(self): - self._stream.close() - class AsyncOutputStream(Generic[ItemT, OutputT]): """Async version of `OutputStream`.""" @@ -173,7 +174,7 @@ class AsyncOutputStream(Generic[ItemT, OutputT]): def __init__( self, stream: AsyncIterator[ItemT], - function_schemas: Iterable[FunctionSchema[OutputT]], + function_schemas: Iterable[AsyncFunctionSchema[OutputT]], parser: StreamParser[ItemT], state: StreamState[ItemT], ): @@ -272,6 +273,3 @@ async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]: @property def usage_ref(self) -> list[Usage]: return self._state.usage_ref - - async def close(self): - await self._stream.close() From a92217f031de17011471f0607f4fdbb090243e28 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Thu, 28 Nov 2024 17:10:01 -0800 Subject: [PATCH 38/40] Tidy calculation of allow_string_output --- src/magentic/chat_model/anthropic_chat_model.py | 10 +++------- src/magentic/chat_model/litellm_chat_model.py | 10 +++------- src/magentic/chat_model/openai_chat_model.py | 11 +++-------- 3 files changed, 9 insertions(+), 22 deletions(-) diff --git a/src/magentic/chat_model/anthropic_chat_model.py b/src/magentic/chat_model/anthropic_chat_model.py index 8cbd4794..a84b67a9 100644 --- 
a/src/magentic/chat_model/anthropic_chat_model.py +++ b/src/magentic/chat_model/anthropic_chat_model.py @@ -397,9 +397,7 @@ def complete( function_schemas = get_function_schemas(functions, output_types) tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas] - str_in_output_types = is_any_origin_subclass(output_types, str) - streamed_str_in_output_types = is_any_origin_subclass(output_types, StreamedStr) - allow_string_output = str_in_output_types or streamed_str_in_output_types + allow_string_output = is_any_origin_subclass(output_types, (str, StreamedStr)) system, messages = _extract_system_message(messages) @@ -462,11 +460,9 @@ async def acomplete( function_schemas = get_async_function_schemas(functions, output_types) tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas] - str_in_output_types = is_any_origin_subclass(output_types, str) - async_streamed_str_in_output_types = is_any_origin_subclass( - output_types, AsyncStreamedStr + allow_string_output = is_any_origin_subclass( + output_types, (str, AsyncStreamedStr) ) - allow_string_output = str_in_output_types or async_streamed_str_in_output_types system, messages = _extract_system_message(messages) diff --git a/src/magentic/chat_model/litellm_chat_model.py b/src/magentic/chat_model/litellm_chat_model.py index 79b66963..5a10d037 100644 --- a/src/magentic/chat_model/litellm_chat_model.py +++ b/src/magentic/chat_model/litellm_chat_model.py @@ -203,9 +203,7 @@ def complete( function_schemas = get_function_schemas(functions, output_types) tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas] - str_in_output_types = is_any_origin_subclass(output_types, str) - streamed_str_in_output_types = is_any_origin_subclass(output_types, StreamedStr) - allow_string_output = str_in_output_types or streamed_str_in_output_types + allow_string_output = is_any_origin_subclass(output_types, (str, StreamedStr)) response = litellm.completion( model=self.model, @@ -267,11 +265,9 @@ async def acomplete( function_schemas = get_async_function_schemas(functions, output_types) tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas] - str_in_output_types = is_any_origin_subclass(output_types, str) - async_streamed_str_in_output_types = is_any_origin_subclass( - output_types, AsyncStreamedStr + allow_string_output = is_any_origin_subclass( + output_types, (str, AsyncStreamedStr) ) - allow_string_output = str_in_output_types or async_streamed_str_in_output_types response = await litellm.acompletion( model=self.model, diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index 36d23ca4..1caf8c5c 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -448,10 +448,7 @@ def complete( function_schemas = get_function_schemas(functions, output_types) tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas] - # TODO: pass output_types to _get_tool_choice directly and remove these - str_in_output_types = is_any_origin_subclass(output_types, str) - streamed_str_in_output_types = is_any_origin_subclass(output_types, StreamedStr) - allow_string_output = str_in_output_types or streamed_str_in_output_types + allow_string_output = is_any_origin_subclass(output_types, (str, StreamedStr)) response: Iterator[ChatCompletionChunk] = self._client.chat.completions.create( model=self.model, @@ -517,11 +514,9 @@ async def acomplete( function_schemas = 
get_async_function_schemas(functions, output_types) tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas] - str_in_output_types = is_any_origin_subclass(output_types, str) - async_streamed_str_in_output_types = is_any_origin_subclass( - output_types, AsyncStreamedStr + allow_string_output = is_any_origin_subclass( + output_types, (str, AsyncStreamedStr) ) - allow_string_output = str_in_output_types or async_streamed_str_in_output_types response: AsyncIterator[ ChatCompletionChunk From 846be9ea8f7bd017865ec0e4eda14d48b1eff713 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Thu, 28 Nov 2024 17:21:24 -0800 Subject: [PATCH 39/40] Fix remaining mypy errors --- .../chat_model/anthropic_chat_model.py | 2 +- src/magentic/chat_model/litellm_chat_model.py | 33 ++++++++----------- src/magentic/chat_model/openai_chat_model.py | 5 +-- src/magentic/prompt_chain.py | 4 +-- 4 files changed, 20 insertions(+), 24 deletions(-) diff --git a/src/magentic/chat_model/anthropic_chat_model.py b/src/magentic/chat_model/anthropic_chat_model.py index a84b67a9..b4ce4e7d 100644 --- a/src/magentic/chat_model/anthropic_chat_model.py +++ b/src/magentic/chat_model/anthropic_chat_model.py @@ -258,7 +258,7 @@ def iter_tool_calls(self, item: MessageStreamEvent) -> Iterable[FunctionCallChun class AnthropicStreamState(StreamState[MessageStreamEvent]): - def __init__(self): + def __init__(self) -> None: self._current_message_snapshot: anthropic.types.Message | None = ( None # TODO: type ) diff --git a/src/magentic/chat_model/litellm_chat_model.py b/src/magentic/chat_model/litellm_chat_model.py index 5a10d037..1bbde01e 100644 --- a/src/magentic/chat_model/litellm_chat_model.py +++ b/src/magentic/chat_model/litellm_chat_model.py @@ -1,26 +1,16 @@ from collections.abc import Callable, Iterable, Sequence from typing import Any, Literal, TypeVar, cast, overload -import litellm import openai -from litellm.litellm_core_utils.streaming_handler import StreamingChoices from openai.lib.streaming.chat._completions import ChatCompletionStreamState +from openai.types.chat import ChatCompletionNamedToolChoiceParam -from magentic.chat_model.base import ( - ChatModel, - aparse_stream, - parse_stream, -) +from magentic.chat_model.base import ChatModel, aparse_stream, parse_stream from magentic.chat_model.function_schema import ( get_async_function_schemas, get_function_schemas, ) -from magentic.chat_model.message import ( - AssistantMessage, - Message, - Usage, - _RawMessage, -) +from magentic.chat_model.message import AssistantMessage, Message, Usage, _RawMessage from magentic.chat_model.openai_chat_model import ( BaseFunctionToolSchema, message_to_openai_message, @@ -40,6 +30,9 @@ try: import litellm + from litellm.litellm_core_utils.streaming_handler import ( # type: ignore[attr-defined] + StreamingChoices, + ) from litellm.types.utils import ModelResponse except ImportError as error: msg = "To use LitellmChatModel you must install the `litellm` package using `pip install 'magentic[litellm]'`." 
@@ -53,6 +46,7 @@ def is_content(self, item: ModelResponse) -> bool:
 
     def get_content(self, item: ModelResponse) -> str | None:
         assert isinstance(item.choices[0], StreamingChoices)  # noqa: S101
+        assert isinstance(item.choices[0].delta.content, str | None)  # noqa: S101
         return item.choices[0].delta.content
 
     def is_tool_call(self, item: ModelResponse) -> bool:
@@ -72,7 +66,7 @@ def iter_tool_calls(self, item: ModelResponse) -> Iterable[FunctionCallChunk]:
 
 
 class LitellmStreamState(StreamState[ModelResponse]):
-    def __init__(self):
+    def __init__(self) -> None:
         self._chat_completion_stream_state = ChatCompletionStreamState(
             input_tools=openai.NOT_GIVEN,
             response_format=openai.NOT_GIVEN,
@@ -85,9 +79,10 @@ def update(self, item: ModelResponse) -> None:
             # litellm requires usage is not None for its total usage calculation
             item.usage = litellm.Usage()  # type: ignore[attr-defined]
         if not hasattr(item, "refusal"):
+            assert isinstance(item.choices[0], StreamingChoices)  # noqa: S101
             item.choices[0].delta.refusal = None  # type: ignore[attr-defined]
         self._chat_completion_stream_state.handle_chunk(item)  # type: ignore[arg-type]
-        usage = cast(litellm.Usage, item.usage)  # type: ignore[attr-defined]
+        usage = cast(litellm.Usage, item.usage)  # type: ignore[attr-defined,name-defined]
         # Ignore usages with 0 tokens
         if usage and usage.prompt_tokens and usage.completion_tokens:
             assert not self.usage_ref  # noqa: S101
@@ -160,12 +155,12 @@ def _get_tool_choice(
         *,
         tool_schemas: Sequence[BaseFunctionToolSchema[Any]],
         allow_string_output: bool,
-    ) -> dict | Literal["none", "auto", "required"] | None:
+    ) -> ChatCompletionNamedToolChoiceParam | Literal["required"] | None:
         """Create the tool choice argument."""
         if allow_string_output:
             return None
         if len(tool_schemas) == 1:
-            return tool_schemas[0].as_tool_choice()  # type: ignore[return-value]
+            return tool_schemas[0].as_tool_choice()
         return "required"
 
     @overload
@@ -219,7 +214,7 @@ def complete(
             tools=[schema.to_dict() for schema in tool_schemas] or None,
             tool_choice=self._get_tool_choice(
                 tool_schemas=tool_schemas, allow_string_output=allow_string_output
-            ),
+            ),  # type: ignore[arg-type,unused-ignore]
         )
         assert not isinstance(response, ModelResponse)  # noqa: S101
         stream = OutputStream(
@@ -283,7 +278,7 @@ async def acomplete(
             tools=[schema.to_dict() for schema in tool_schemas] or None,
             tool_choice=self._get_tool_choice(
                 tool_schemas=tool_schemas, allow_string_output=allow_string_output
-            ),  # type: ignore[arg-type]
+            ),  # type: ignore[arg-type,unused-ignore]
         )
         assert not isinstance(response, ModelResponse)  # noqa: S101
         stream = AsyncOutputStream(
diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py
index 1caf8c5c..9461c6fb 100644
--- a/src/magentic/chat_model/openai_chat_model.py
+++ b/src/magentic/chat_model/openai_chat_model.py
@@ -16,6 +16,7 @@
 from openai.types.chat import (
     ChatCompletionChunk,
     ChatCompletionMessageParam,
+    ChatCompletionNamedToolChoiceParam,
     ChatCompletionStreamOptionsParam,
     ChatCompletionToolChoiceOptionParam,
     ChatCompletionToolParam,
@@ -229,7 +230,7 @@ class BaseFunctionToolSchema(Generic[BaseFunctionSchemaT]):
     def __init__(self, function_schema: BaseFunctionSchemaT):
         self._function_schema = function_schema
 
-    def as_tool_choice(self) -> ChatCompletionToolChoiceOptionParam:
+    def as_tool_choice(self) -> ChatCompletionNamedToolChoiceParam:
         return {"type": "function", "function": {"name": self._function_schema.name}}
 
     def to_dict(self) -> ChatCompletionToolParam:
         return {"type": "function", "function": self._function_schema.dict()}
@@ -267,7 +268,7 @@ class OpenaiStreamState(StreamState[ChatCompletionChunk]):
     - stop reason
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._chat_completion_stream_state = ChatCompletionStreamState(
             input_tools=openai.NOT_GIVEN,
             response_format=openai.NOT_GIVEN,
diff --git a/src/magentic/prompt_chain.py b/src/magentic/prompt_chain.py
index f7a4c502..ab9b7fa8 100644
--- a/src/magentic/prompt_chain.py
+++ b/src/magentic/prompt_chain.py
@@ -38,7 +38,7 @@ def decorator(func: Callable[P, R]) -> Callable[P, R]:
             name=func.__name__,
             parameters=list(func_signature.parameters.values()),
             # TODO: Also allow ParallelFunctionCall. Support this more neatly
-            return_type=func_signature.return_annotation | FunctionCall,  # type: ignore[arg-type]
+            return_type=func_signature.return_annotation | FunctionCall,  # type: ignore[arg-type,unused-ignore]
             template=template,
             functions=functions,
             model=model,
@@ -72,7 +72,7 @@ async def awrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
             name=func.__name__,
             parameters=list(func_signature.parameters.values()),
             # TODO: Also allow ParallelFunctionCall. Support this more neatly
-            return_type=func_signature.return_annotation | FunctionCall,  # type: ignore[arg-type]
+            return_type=func_signature.return_annotation | FunctionCall,  # type: ignore[arg-type,unused-ignore]
             template=template,
             functions=functions,
             model=model,

From 169a640e40fca96602a8c3df981e868274627ec1 Mon Sep 17 00:00:00 2001
From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com>
Date: Thu, 28 Nov 2024 17:53:26 -0800
Subject: [PATCH 40/40] Delete is_instance_origin

---
 src/magentic/typing.py | 13 -------------
 1 file changed, 13 deletions(-)

diff --git a/src/magentic/typing.py b/src/magentic/typing.py
index 8df9a23b..21045e92 100644
--- a/src/magentic/typing.py
+++ b/src/magentic/typing.py
@@ -17,7 +17,6 @@ def is_union_type(type_: type) -> bool:
     return type_ is Union or type_ is types.UnionType
 
 
-T = TypeVar("T")
 TypeT = TypeVar("TypeT", bound=type)
 
 
@@ -40,18 +39,6 @@ def is_origin_subclass(
     return issubclass(get_origin(type_) or type_, cls_or_tuple)
 
 
-def is_instance_origin(
-    obj: Any, cls_or_tuple: type[T] | tuple[type[T], ...]
-) -> TypeGuard[T]:
-    """Check if the object is an instance of the origin(s) of the given type(s)."""
-    cls_or_tuple_origin = (
-        tuple(get_origin(cls) or cls for cls in cls_or_tuple)
-        if isinstance(cls_or_tuple, tuple)
-        else get_origin(cls_or_tuple) or cls_or_tuple
-    )
-    return isinstance(obj, cls_or_tuple_origin)
-
-
 def is_any_origin_subclass(
     types: Iterable[type], cls_or_tuple: TypeT | tuple[TypeT, ...]
 ) -> bool:
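
Note: the `allow_string_output` cleanup in the earlier hunks relies on `is_any_origin_subclass` accepting a tuple of classes, the same way `issubclass` does. A rough, self-contained sketch of the check, simplified from `is_origin_subclass` shown in the typing.py hunk above (the helper name here is illustrative, not part of magentic):

    from typing import get_origin

    def _is_any_origin_subclass_sketch(types, cls_or_tuple):
        # A generic alias such as list[int] is reduced to its origin (list);
        # plain classes such as str are used as-is.
        return any(issubclass(get_origin(t) or t, cls_or_tuple) for t in types)

    assert _is_any_origin_subclass_sketch((str, list[int]), (str,))    # str present: string output allowed
    assert not _is_any_origin_subclass_sketch((list[int],), (str,))    # structured output only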
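Note: the narrowed `_get_tool_choice` return type in the litellm hunk matches the dict shape produced by `as_tool_choice` in the openai_chat_model.py hunk. A minimal sketch of the selection logic, assuming each schema exposes its function name (the helper and example name below are illustrative only, not the magentic API):

    from typing import Literal

    def _tool_choice_sketch(
        schema_names: list[str], allow_string_output: bool
    ) -> dict | Literal["required"] | None:
        # None: let the model answer with plain text.
        # Named choice: force the single available tool by name.
        # "required": the model must call one of the provided tools.
        if allow_string_output:
            return None
        if len(schema_names) == 1:
            return {"type": "function", "function": {"name": schema_names[0]}}
        return "required"

    print(_tool_choice_sketch(["return_list_of_int"], allow_string_output=False))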