From 7d90a58f1e77e776d5682ccbb50c87d71eee99c4 Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Mon, 12 Aug 2024 15:36:58 +0200 Subject: [PATCH] fix: support streaming_callback param in amazon bedrock generators (#927) * fix: support streaming_callback param in amazon bedrock generators * fix chat generator merge * reformat --------- Co-authored-by: Thomas Stadelmann --- .../generators/amazon_bedrock/adapters.py | 93 ++++----- .../amazon_bedrock/chat/adapters.py | 55 +++--- .../amazon_bedrock/chat/chat_generator.py | 53 +++--- .../generators/amazon_bedrock/generator.py | 91 ++++----- .../generators/amazon_bedrock/handlers.py | 33 ---- .../tests/test_chat_generator.py | 9 - .../amazon_bedrock/tests/test_generator.py | 179 ++++++++---------- 7 files changed, 217 insertions(+), 296 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py index 7c7fdd7ce..8b5c2b530 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py @@ -1,8 +1,8 @@ import json from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional -from .handlers import TokenStreamingHandler +from haystack.dataclasses import StreamingChunk class BedrockModelAdapter(ABC): @@ -39,22 +39,24 @@ def get_responses(self, response_body: Dict[str, Any]) -> List[str]: responses = [completion.lstrip() for completion in completions] return responses - def get_stream_responses(self, stream, stream_handler: TokenStreamingHandler) -> List[str]: + def get_stream_responses(self, stream, streaming_callback: Callable[[StreamingChunk], None]) -> List[str]: """ Extracts the responses from the Amazon Bedrock streaming response. :param stream: The streaming response from the Amazon Bedrock request. - :param stream_handler: The handler for the streaming response. + :param streaming_callback: The handler for the streaming response. :returns: A list of string responses. """ - tokens: List[str] = [] + streaming_chunks: List[StreamingChunk] = [] for event in stream: chunk = event.get("chunk") if chunk: decoded_chunk = json.loads(chunk["bytes"].decode("utf-8")) - token = self._extract_token_from_stream(decoded_chunk) - tokens.append(stream_handler(token, event_data=decoded_chunk)) - responses = ["".join(tokens).lstrip()] + streaming_chunk: StreamingChunk = self._build_streaming_chunk(decoded_chunk) + streaming_chunks.append(streaming_chunk) + streaming_callback(streaming_chunk) + + responses = ["".join(streaming_chunk.content for streaming_chunk in streaming_chunks).lstrip()] return responses def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any]) -> Dict[str, Any]: @@ -84,12 +86,12 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L """ @abstractmethod - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: """ - Extracts the token from a streaming chunk. + Extracts the content and meta from a streaming chunk. - :param chunk: The streaming chunk. - :returns: A string token. + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. """ @@ -150,17 +152,17 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L return [response_body["completion"]] - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: """ - Extracts the token from a streaming chunk. + Extracts the content and meta from a streaming chunk. - :param chunk: The streaming chunk. - :returns: A string token. + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. """ if self.use_messages_api: - return chunk.get("delta", {}).get("text", "") + return StreamingChunk(content=chunk.get("delta", {}).get("text", ""), meta=chunk) - return chunk.get("completion", "") + return StreamingChunk(content=chunk.get("completion", ""), meta=chunk) class MistralAdapter(BedrockModelAdapter): @@ -199,17 +201,18 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L """ return [output.get("text", "") for output in response_body.get("outputs", [])] - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: """ - Extracts the token from a streaming chunk. + Extracts the content and meta from a streaming chunk. - :param chunk: The streaming chunk. - :returns: A string token. + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. """ + content = "" chunk_list = chunk.get("outputs", []) if chunk_list: - return chunk_list[0].get("text", "") - return "" + content = chunk_list[0].get("text", "") + return StreamingChunk(content=content, meta=chunk) class CohereCommandAdapter(BedrockModelAdapter): @@ -254,14 +257,14 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L responses = [generation["text"] for generation in response_body["generations"]] return responses - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: """ - Extracts the token from a streaming chunk. + Extracts the content and meta from a streaming chunk. - :param chunk: The streaming chunk. - :returns: A string token. + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. """ - return chunk.get("text", "") + return StreamingChunk(content=chunk.get("text", ""), meta=chunk) class CohereCommandRAdapter(BedrockModelAdapter): @@ -313,15 +316,15 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L responses = [response_body["text"]] return responses - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: """ - Extracts the token from a streaming chunk. + Extracts the content and meta from a streaming chunk. - :param chunk: The streaming chunk. - :returns: A string token. + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. """ token: str = chunk.get("text", "") - return token + return StreamingChunk(content=token, meta=chunk) class AI21LabsJurassic2Adapter(BedrockModelAdapter): @@ -357,7 +360,7 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L responses = [completion["data"]["text"] for completion in response_body["completions"]] return responses - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: msg = "Streaming is not supported for AI21 Jurassic 2 models." raise NotImplementedError(msg) @@ -398,14 +401,14 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L responses = [result["outputText"] for result in response_body["results"]] return responses - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: """ - Extracts the token from a streaming chunk. + Extracts the content and meta from a streaming chunk. - :param chunk: The streaming chunk. - :returns: A string token. + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. """ - return chunk.get("outputText", "") + return StreamingChunk(content=chunk.get("outputText", ""), meta=chunk) class MetaLlamaAdapter(BedrockModelAdapter): @@ -442,11 +445,11 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L """ return [response_body["generation"]] - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: """ - Extracts the token from a streaming chunk. + Extracts the content and meta from a streaming chunk. - :param chunk: The streaming chunk. - :returns: A string token. + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. """ - return chunk.get("generation", "") + return StreamingChunk(content=chunk.get("generation", ""), meta=chunk) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py index 162100934..67e833f73 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py @@ -48,19 +48,18 @@ def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]: return self._extract_messages_from_response(response_body) def get_stream_responses( - self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None] + self, stream: EventStream, streaming_callback: Callable[[StreamingChunk], None] ) -> List[ChatMessage]: - tokens: List[str] = [] + streaming_chunks: List[StreamingChunk] = [] last_decoded_chunk: Dict[str, Any] = {} for event in stream: chunk = event.get("chunk") if chunk: last_decoded_chunk = json.loads(chunk["bytes"].decode("utf-8")) - token = self._extract_token_from_stream(last_decoded_chunk) - stream_chunk = StreamingChunk(content=token) # don't extract meta, we care about tokens only - stream_handler(stream_chunk) # callback the stream handler with StreamingChunk - tokens.append(token) - responses = ["".join(tokens).lstrip()] + streaming_chunk = self._build_streaming_chunk(last_decoded_chunk) + streaming_callback(streaming_chunk) # callback the stream handler with StreamingChunk + streaming_chunks.append(streaming_chunk) + responses = ["".join(chunk.content for chunk in streaming_chunks).lstrip()] return [ChatMessage.from_assistant(response, meta=last_decoded_chunk) for response in responses] @staticmethod @@ -142,12 +141,12 @@ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List """ @abstractmethod - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: """ - Extracts the token from a streaming chunk. + Extracts the content and meta from a streaming chunk. - :param chunk: The streaming chunk. - :returns: The extracted token. + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. """ @@ -252,16 +251,16 @@ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List messages.append(ChatMessage.from_assistant(content["text"], meta=meta)) return messages - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: """ - Extracts the token from a streaming chunk. + Extracts the content and meta from a streaming chunk. - :param chunk: The streaming chunk. - :returns: The extracted token. + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. """ if chunk.get("type") == "content_block_delta" and chunk.get("delta", {}).get("type") == "text_delta": - return chunk.get("delta", {}).get("text", "") - return "" + return StreamingChunk(content=chunk.get("delta", {}).get("text", ""), meta=chunk) + return StreamingChunk(content="", meta=chunk) def _to_anthropic_message(self, m: ChatMessage) -> Dict[str, Any]: """ @@ -425,17 +424,17 @@ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List messages.append(ChatMessage.from_assistant(response["text"], meta=meta)) return messages - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: """ - Extracts the token from a streaming chunk. + Extracts the content and meta from a streaming chunk. - :param chunk: The streaming chunk. - :returns: The extracted token. + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. """ response_chunk = chunk.get("outputs", []) if response_chunk: - return response_chunk[0].get("text", "") - return "" + return StreamingChunk(content=response_chunk[0].get("text", ""), meta=chunk) + return StreamingChunk(content="", meta=chunk) class MetaLlama2ChatAdapter(BedrockModelChatAdapter): @@ -543,11 +542,11 @@ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List metadata = {k: v for (k, v) in response_body.items() if k != message_tag} return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)] - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: """ - Extracts the token from a streaming chunk. + Extracts the content and meta from a streaming chunk. - :param chunk: The streaming chunk. - :returns: The extracted token. + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. """ - return chunk.get("generation", "") + return StreamingChunk(content=chunk.get("generation", ""), meta=chunk) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 7485a96c5..206cb0b9a 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -156,18 +156,29 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: stacklevel=2, ) - def invoke(self, *args, **kwargs): + @component.output_types(replies=List[ChatMessage]) + def run( + self, + messages: List[ChatMessage], + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ): """ - Invokes the Amazon Bedrock LLM with the given parameters. The parameters are passed to the Amazon Bedrock - client. + Generates a list of `ChatMessage` response to the given messages using the Amazon Bedrock LLM. - :param args: The positional arguments passed to the generator. - :param kwargs: The keyword arguments passed to the generator. - :returns: List of `ChatMessage` generated by LLM. + :param messages: The messages to generate a response to. + :param streaming_callback: + A callback function that is called when a new token is received from the stream. + :param generation_kwargs: Additional generation keyword arguments passed to the model. + :returns: A dictionary with the following keys: + - `replies`: The generated List of `ChatMessage` objects. """ + generation_kwargs = generation_kwargs or {} + generation_kwargs = generation_kwargs.copy() + + streaming_callback = streaming_callback or self.streaming_callback + generation_kwargs["stream"] = streaming_callback is not None - kwargs = kwargs.copy() - messages: List[ChatMessage] = kwargs.pop("messages", []) # check if the prompt is a list of ChatMessage objects if not ( isinstance(messages, list) @@ -177,39 +188,29 @@ def invoke(self, *args, **kwargs): msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt." raise ValueError(msg) - body = self.model_adapter.prepare_body(messages=messages, **{"stop_words": self.stop_words, **kwargs}) + body = self.model_adapter.prepare_body( + messages=messages, **{"stop_words": self.stop_words, **generation_kwargs} + ) try: - if self.streaming_callback: + if streaming_callback: response = self.client.invoke_model_with_response_stream( body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" ) response_stream = response["body"] - responses = self.model_adapter.get_stream_responses( - stream=response_stream, stream_handler=self.streaming_callback + replies = self.model_adapter.get_stream_responses( + stream=response_stream, streaming_callback=streaming_callback ) else: response = self.client.invoke_model( body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" ) response_body = json.loads(response.get("body").read().decode("utf-8")) - responses = self.model_adapter.get_responses(response_body=response_body) + replies = self.model_adapter.get_responses(response_body=response_body) except ClientError as exception: msg = f"Could not inference Amazon Bedrock model {self.model} due: {exception}" raise AmazonBedrockInferenceError(msg) from exception - return responses - - @component.output_types(replies=List[ChatMessage]) - def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): - """ - Generates a list of `ChatMessage` responses to the given messages using the Amazon Bedrock LLM. - - :param messages: The messages to generate a response to. - :param generation_kwargs: Additional generation keyword arguments passed to the model. - :returns: A dictionary with the following keys: - - `replies`: The generated list of `ChatMessage` objects. - """ - return {"replies": self.invoke(messages=messages, **(generation_kwargs or {}))} + return {"replies": replies} @classmethod def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelChatAdapter]]: diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index b15000aa2..6ef0a4765 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -1,11 +1,12 @@ import json import logging import re -from typing import Any, ClassVar, Dict, List, Optional, Type, Union +from typing import Any, Callable, ClassVar, Dict, List, Optional, Type from botocore.exceptions import ClientError from haystack import component, default_from_dict, default_to_dict -from haystack.utils.auth import Secret, deserialize_secrets_inplace +from haystack.dataclasses import StreamingChunk +from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable from haystack_integrations.common.amazon_bedrock.errors import ( AmazonBedrockConfigurationError, @@ -25,8 +26,6 @@ ) from .handlers import ( DefaultPromptHandler, - DefaultTokenStreamingHandler, - TokenStreamingHandler, ) logger = logging.getLogger(__name__) @@ -87,6 +86,7 @@ def __init__( aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008 max_length: Optional[int] = 100, truncate: Optional[bool] = True, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, **kwargs, ): """ @@ -100,6 +100,8 @@ def __init__( :param aws_profile_name: The AWS profile name. :param max_length: The maximum length of the generated text. :param truncate: Whether to truncate the prompt or not. + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. :param kwargs: Additional keyword arguments to be passed to the model. These arguments are specific to the model. You can find them in the model's documentation. :raises ValueError: If the model name is empty or None. @@ -117,6 +119,7 @@ def __init__( self.aws_session_token = aws_session_token self.aws_region_name = aws_region_name self.aws_profile_name = aws_profile_name + self.streaming_callback = streaming_callback self.kwargs = kwargs def resolve_secret(secret: Optional[Secret]) -> Optional[str]: @@ -158,7 +161,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: raise AmazonBedrockConfigurationError(msg) self.model_adapter = model_adapter_cls(model_kwargs=model_input_kwargs, max_length=self.max_length) - def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]: + def _ensure_token_limit(self, prompt: str) -> str: """ Ensures that the prompt and answer token lengths together are within the model_max_length specified during the initialization of the component. @@ -166,14 +169,6 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union :param prompt: The prompt to be sent to the model. :returns: The resized prompt. """ - # the prompt for this model will be of the type str - if isinstance(prompt, List): - msg = ( - "AmazonBedrockGenerator only supports a string as a prompt, " - "while currently, the prompt is of type List." - ) - raise ValueError(msg) - resize_info = self.prompt_handler(prompt) if resize_info["prompt_length"] != resize_info["new_prompt_length"]: logger.warning( @@ -187,31 +182,36 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union ) return str(resize_info["resized_prompt"]) - def invoke(self, *args, **kwargs): + @component.output_types(replies=List[str]) + def run( + self, + prompt: str, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ): """ - Invokes the model with the given prompt. + Generates a list of string response to the given prompt. - :param args: Additional positional arguments passed to the generator. - :param kwargs: Additional keyword arguments passed to the generator. - :returns: A list of generated responses (strings). + :param prompt: The prompt to generate a response for. + :param streaming_callback: + A callback function that is called when a new token is received from the stream. + :param generation_kwargs: Additional keyword arguments passed to the generator. + :returns: A dictionary with the following keys: + - `replies`: A list of generated responses. + :raises ValueError: If the prompt is empty or None. + :raises AmazonBedrockInferenceError: If the model cannot be invoked. """ - kwargs = kwargs.copy() - prompt: str = kwargs.pop("prompt", None) - stream: bool = kwargs.get("stream", self.model_adapter.model_kwargs.get("stream", False)) - - if not prompt or not isinstance(prompt, (str, list)): - msg = ( - f"The model {self.model} requires a valid prompt, but currently, it has no prompt. " - f"Make sure to provide a prompt in the format that the model expects." - ) - raise ValueError(msg) + generation_kwargs = generation_kwargs or {} + generation_kwargs = generation_kwargs.copy() + streaming_callback = streaming_callback or self.streaming_callback + generation_kwargs["stream"] = streaming_callback is not None if self.truncate: prompt = self._ensure_token_limit(prompt) - body = self.model_adapter.prepare_body(prompt=prompt, **kwargs) + body = self.model_adapter.prepare_body(prompt=prompt, **generation_kwargs) try: - if stream: + if streaming_callback: response = self.client.invoke_model_with_response_stream( body=json.dumps(body), modelId=self.model, @@ -219,11 +219,9 @@ def invoke(self, *args, **kwargs): contentType="application/json", ) response_stream = response["body"] - handler: TokenStreamingHandler = kwargs.get( - "stream_handler", - self.model_adapter.model_kwargs.get("stream_handler", DefaultTokenStreamingHandler()), + replies = self.model_adapter.get_stream_responses( + stream=response_stream, streaming_callback=streaming_callback ) - responses = self.model_adapter.get_stream_responses(stream=response_stream, stream_handler=handler) else: response = self.client.invoke_model( body=json.dumps(body), @@ -232,7 +230,7 @@ def invoke(self, *args, **kwargs): contentType="application/json", ) response_body = json.loads(response.get("body").read().decode("utf-8")) - responses = self.model_adapter.get_responses(response_body=response_body) + replies = self.model_adapter.get_responses(response_body=response_body) except ClientError as exception: msg = ( f"Could not connect to Amazon Bedrock model {self.model}. " @@ -241,22 +239,7 @@ def invoke(self, *args, **kwargs): ) raise AmazonBedrockInferenceError(msg) from exception - return responses - - @component.output_types(replies=List[str]) - def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): - """ - Generates a list of string response to the given prompt. - - :param prompt: Instructions for the model. - :param generation_kwargs: Additional keyword arguments to customize text generation. - These arguments are specific to the model. You can find them in the model's documentation. - :returns: A dictionary with the following keys: - - `replies`: A list of generated responses. - :raises ValueError: If the prompt is empty or None. - :raises AmazonBedrockInferenceError: If the model cannot be invoked. - """ - return {"replies": self.invoke(prompt=prompt, **(generation_kwargs or {}))} + return {"replies": replies} @classmethod def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelAdapter]]: @@ -278,6 +261,7 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. """ + callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None return default_to_dict( self, aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None, @@ -288,6 +272,7 @@ def to_dict(self) -> Dict[str, Any]: model=self.model, max_length=self.max_length, truncate=self.truncate, + streaming_callback=callback_name, **self.kwargs, ) @@ -305,4 +290,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockGenerator": data["init_parameters"], ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"], ) + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py index f4dc1aa4f..07db2742f 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py @@ -1,4 +1,3 @@ -from abc import ABC, abstractmethod from typing import Dict, Union from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast @@ -61,35 +60,3 @@ def __call__(self, prompt: str, **kwargs) -> Dict[str, Union[str, int]]: "model_max_length": self.model_max_length, "max_length": self.max_length, } - - -class TokenStreamingHandler(ABC): - """ - TokenStreamingHandler implementations handle the streaming of tokens from the stream. - """ - - DONE_MARKER = "[DONE]" - - @abstractmethod - def __call__(self, token_received: str, **kwargs) -> str: - """ - This callback method is called when a new token is received from the stream. - - :param token_received: The token received from the stream. - :param kwargs: Additional keyword arguments passed to the handler. - :returns: The token to be sent to the stream. - """ - pass - - -class DefaultTokenStreamingHandler(TokenStreamingHandler): - def __call__(self, token_received, **kwargs) -> str: - """ - This callback method is called when a new token is received from the stream. - - :param token_received: The token received from the stream. - :param kwargs: Additional keyword arguments passed to the handler. - :returns: The token to be sent to the stream. - """ - print(token_received, flush=True, end="") # noqa: T201 - return token_received diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 3e62b56ea..79a04d52b 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -121,15 +121,6 @@ def test_constructor_with_empty_model(): AmazonBedrockChatGenerator(model="") -def test_invoke_with_no_kwargs(mock_boto3_session): - """ - Test invoke raises an error if no messages are provided - """ - layer = AmazonBedrockChatGenerator(model="anthropic.claude-v2") - with pytest.raises(ValueError, match="The model anthropic.claude-v2 requires"): - layer.invoke() - - @pytest.mark.parametrize( "model, expected_model_adapter", [ diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index 65463caae..f0233888c 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, call, patch import pytest +from haystack.dataclasses import StreamingChunk from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator from haystack_integrations.components.generators.amazon_bedrock.adapters import ( @@ -34,6 +35,7 @@ def test_to_dict(mock_boto3_session): "max_length": 99, "truncate": False, "temperature": 10, + "streaming_callback": None, }, } @@ -120,15 +122,6 @@ def test_constructor_with_empty_model(): AmazonBedrockGenerator(model="") -def test_invoke_with_no_kwargs(mock_boto3_session): - """ - Test invoke raises an error if no prompt is provided - """ - layer = AmazonBedrockGenerator(model="anthropic.claude-v2") - with pytest.raises(ValueError, match="The model anthropic.claude-v2 requires a valid prompt."): - layer.invoke() - - def test_short_prompt_is_not_truncated(mock_boto3_session): """ Test that a short prompt is not truncated @@ -224,13 +217,13 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): generator.model_adapter.get_responses = MagicMock(return_value=["response"]) # Invoke the generator - generator.invoke(prompt=long_prompt_text) + generator.run(prompt=long_prompt_text) # Ensure _ensure_token_limit was not called mock_ensure_token_limit.assert_not_called(), # Check the prompt passed to prepare_body - generator.model_adapter.prepare_body.assert_called_with(prompt=long_prompt_text) + generator.model_adapter.prepare_body.assert_called_with(prompt=long_prompt_text, stream=False) @pytest.mark.parametrize( @@ -407,7 +400,7 @@ def test_get_responses_leading_whitespace(self) -> None: def test_get_stream_responses(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() + streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [ {"chunk": {"bytes": b'{"delta": {"text": " This"}}'}}, @@ -417,35 +410,31 @@ def test_get_stream_responses(self) -> None: {"chunk": {"bytes": b'{"delta": {"text": " response."}}'}}, ] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) expected_responses = ["This is a single response."] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_has_calls( + streaming_callback_mock.assert_has_calls( [ - call(" This", event_data={"delta": {"text": " This"}}), - call(" is", event_data={"delta": {"text": " is"}}), - call(" a", event_data={"delta": {"text": " a"}}), - call(" single", event_data={"delta": {"text": " single"}}), - call(" response.", event_data={"delta": {"text": " response."}}), + call(StreamingChunk(content=" This", meta={"delta": {"text": " This"}})), + call(StreamingChunk(content=" is", meta={"delta": {"text": " is"}})), + call(StreamingChunk(content=" a", meta={"delta": {"text": " a"}})), + call(StreamingChunk(content=" single", meta={"delta": {"text": " single"}})), + call(StreamingChunk(content=" response.", meta={"delta": {"text": " response."}})), ] ) def test_get_stream_responses_empty(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() + streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) expected_responses = [""] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_not_called() + streaming_callback_mock.assert_not_called() class TestAnthropicClaudeAdapterNoMessagesAPI: @@ -553,7 +542,7 @@ def test_get_responses_leading_whitespace(self) -> None: def test_get_stream_responses(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() + streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [ {"chunk": {"bytes": b'{"completion": " This"}'}}, @@ -563,35 +552,31 @@ def test_get_stream_responses(self) -> None: {"chunk": {"bytes": b'{"completion": " response."}'}}, ] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) expected_responses = ["This is a single response."] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_has_calls( + streaming_callback_mock.assert_has_calls( [ - call(" This", event_data={"completion": " This"}), - call(" is", event_data={"completion": " is"}), - call(" a", event_data={"completion": " a"}), - call(" single", event_data={"completion": " single"}), - call(" response.", event_data={"completion": " response."}), + call(StreamingChunk(content=" This", meta={"completion": " This"})), + call(StreamingChunk(content=" is", meta={"completion": " is"})), + call(StreamingChunk(content=" a", meta={"completion": " a"})), + call(StreamingChunk(content=" single", meta={"completion": " single"})), + call(StreamingChunk(content=" response.", meta={"completion": " response."})), ] ) def test_get_stream_responses_empty(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() + streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) expected_responses = [""] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_not_called() + streaming_callback_mock.assert_not_called() class TestMistralAdapter: @@ -686,7 +671,7 @@ def test_get_responses(self) -> None: def test_get_stream_responses(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() + streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [ {"chunk": {"bytes": b'{"outputs": [{"text": " This"}]}'}}, @@ -696,35 +681,33 @@ def test_get_stream_responses(self) -> None: {"chunk": {"bytes": b'{"outputs": [{"text": " response."}]}'}}, ] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = MistralAdapter(model_kwargs={}, max_length=99) expected_responses = ["This is a single response."] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_has_calls( + streaming_callback_mock.assert_has_calls( [ - call(" This", event_data={"outputs": [{"text": " This"}]}), - call(" is", event_data={"outputs": [{"text": " is"}]}), - call(" a", event_data={"outputs": [{"text": " a"}]}), - call(" single", event_data={"outputs": [{"text": " single"}]}), - call(" response.", event_data={"outputs": [{"text": " response."}]}), + call(StreamingChunk(content=" This", meta={"outputs": [{"text": " This"}]})), + call(StreamingChunk(content=" is", meta={"outputs": [{"text": " is"}]})), + call(StreamingChunk(content=" a", meta={"outputs": [{"text": " a"}]})), + call(StreamingChunk(content=" single", meta={"outputs": [{"text": " single"}]})), + call(StreamingChunk(content=" response.", meta={"outputs": [{"text": " response."}]})), ] ) def test_get_stream_responses_empty(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() + streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received + streaming_callback_mock.side_effect = lambda token_received, **kwargs: token_received adapter = MistralAdapter(model_kwargs={}, max_length=99) expected_responses = [""] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_not_called() + streaming_callback_mock.assert_not_called() class TestCohereCommandAdapter: @@ -881,7 +864,7 @@ def test_get_responses_multiple_responses(self) -> None: def test_get_stream_responses(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() + streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [ {"chunk": {"bytes": b'{"text": " This"}'}}, @@ -892,36 +875,32 @@ def test_get_stream_responses(self) -> None: {"chunk": {"bytes": b'{"finish_reason": "MAX_TOKENS", "is_finished": true}'}}, ] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = CohereCommandAdapter(model_kwargs={}, max_length=99) expected_responses = ["This is a single response."] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_has_calls( + streaming_callback_mock.assert_has_calls( [ - call(" This", event_data={"text": " This"}), - call(" is", event_data={"text": " is"}), - call(" a", event_data={"text": " a"}), - call(" single", event_data={"text": " single"}), - call(" response.", event_data={"text": " response."}), - call("", event_data={"finish_reason": "MAX_TOKENS", "is_finished": True}), + call(StreamingChunk(content=" This", meta={"text": " This"})), + call(StreamingChunk(content=" is", meta={"text": " is"})), + call(StreamingChunk(content=" a", meta={"text": " a"})), + call(StreamingChunk(content=" single", meta={"text": " single"})), + call(StreamingChunk(content=" response.", meta={"text": " response."})), + call(StreamingChunk(content="", meta={"finish_reason": "MAX_TOKENS", "is_finished": True})), ] ) def test_get_stream_responses_empty(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() + streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = CohereCommandAdapter(model_kwargs={}, max_length=99) expected_responses = [""] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_not_called() + streaming_callback_mock.assert_not_called() class TestCohereCommandRAdapter: @@ -1025,11 +1004,11 @@ def test_extract_completions_from_response(self) -> None: completions = adapter._extract_completions_from_response(response_body=response_body) assert completions == ["response"] - def test_extract_token_from_stream(self) -> None: + def test_build_chunk(self) -> None: adapter = CohereCommandRAdapter(model_kwargs={}, max_length=100) chunk = {"text": "response_token"} - token = adapter._extract_token_from_stream(chunk=chunk) - assert token == "response_token" + streaming_chunk = adapter._build_streaming_chunk(chunk=chunk) + assert streaming_chunk == StreamingChunk(content="response_token", meta=chunk) class TestAI21LabsJurassic2Adapter: @@ -1288,7 +1267,7 @@ def test_get_responses_multiple_responses(self) -> None: def test_get_stream_responses(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() + streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [ {"chunk": {"bytes": b'{"outputText": " This"}'}}, @@ -1298,35 +1277,31 @@ def test_get_stream_responses(self) -> None: {"chunk": {"bytes": b'{"outputText": " response."}'}}, ] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = AmazonTitanAdapter(model_kwargs={}, max_length=99) expected_responses = ["This is a single response."] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_has_calls( + streaming_callback_mock.assert_has_calls( [ - call(" This", event_data={"outputText": " This"}), - call(" is", event_data={"outputText": " is"}), - call(" a", event_data={"outputText": " a"}), - call(" single", event_data={"outputText": " single"}), - call(" response.", event_data={"outputText": " response."}), + call(StreamingChunk(content=" This", meta={"outputText": " This"})), + call(StreamingChunk(content=" is", meta={"outputText": " is"})), + call(StreamingChunk(content=" a", meta={"outputText": " a"})), + call(StreamingChunk(content=" single", meta={"outputText": " single"})), + call(StreamingChunk(content=" response.", meta={"outputText": " response."})), ] ) def test_get_stream_responses_empty(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() + streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = AmazonTitanAdapter(model_kwargs={}, max_length=99) expected_responses = [""] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_not_called() + streaming_callback_mock.assert_not_called() class TestMetaLlamaAdapter: @@ -1417,7 +1392,7 @@ def test_get_responses_leading_whitespace(self) -> None: def test_get_stream_responses(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() + streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [ {"chunk": {"bytes": b'{"generation": " This"}'}}, @@ -1427,32 +1402,28 @@ def test_get_stream_responses(self) -> None: {"chunk": {"bytes": b'{"generation": " response."}'}}, ] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = MetaLlamaAdapter(model_kwargs={}, max_length=99) expected_responses = ["This is a single response."] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_has_calls( + streaming_callback_mock.assert_has_calls( [ - call(" This", event_data={"generation": " This"}), - call(" is", event_data={"generation": " is"}), - call(" a", event_data={"generation": " a"}), - call(" single", event_data={"generation": " single"}), - call(" response.", event_data={"generation": " response."}), + call(StreamingChunk(content=" This", meta={"generation": " This"})), + call(StreamingChunk(content=" is", meta={"generation": " is"})), + call(StreamingChunk(content=" a", meta={"generation": " a"})), + call(StreamingChunk(content=" single", meta={"generation": " single"})), + call(StreamingChunk(content=" response.", meta={"generation": " response."})), ] ) def test_get_stream_responses_empty(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() + streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = MetaLlamaAdapter(model_kwargs={}, max_length=99) expected_responses = [""] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_not_called() + streaming_callback_mock.assert_not_called()