Skip to content

Commit

Permalink
fix: support streaming_callback param in amazon bedrock generators (#927
Browse files Browse the repository at this point in the history
)

* fix: support streaming_callback param in amazon bedrock generators

* fix chat generator merge

* reformat

---------

Co-authored-by: Thomas Stadelmann <[email protected]>
  • Loading branch information
2 people authored and Amnah199 committed Oct 2, 2024
1 parent 70ba8f6 commit 38c9ee1
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 296 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

from .handlers import TokenStreamingHandler
from haystack.dataclasses import StreamingChunk


class BedrockModelAdapter(ABC):
Expand Down Expand Up @@ -39,22 +39,24 @@ def get_responses(self, response_body: Dict[str, Any]) -> List[str]:
responses = [completion.lstrip() for completion in completions]
return responses

def get_stream_responses(self, stream, stream_handler: TokenStreamingHandler) -> List[str]:
def get_stream_responses(self, stream, streaming_callback: Callable[[StreamingChunk], None]) -> List[str]:
"""
Extracts the responses from the Amazon Bedrock streaming response.
:param stream: The streaming response from the Amazon Bedrock request.
:param stream_handler: The handler for the streaming response.
:param streaming_callback: The handler for the streaming response.
:returns: A list of string responses.
"""
tokens: List[str] = []
streaming_chunks: List[StreamingChunk] = []
for event in stream:
chunk = event.get("chunk")
if chunk:
decoded_chunk = json.loads(chunk["bytes"].decode("utf-8"))
token = self._extract_token_from_stream(decoded_chunk)
tokens.append(stream_handler(token, event_data=decoded_chunk))
responses = ["".join(tokens).lstrip()]
streaming_chunk: StreamingChunk = self._build_streaming_chunk(decoded_chunk)
streaming_chunks.append(streaming_chunk)
streaming_callback(streaming_chunk)

responses = ["".join(streaming_chunk.content for streaming_chunk in streaming_chunks).lstrip()]
return responses

def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any]) -> Dict[str, Any]:
Expand Down Expand Up @@ -84,12 +86,12 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L
"""

@abstractmethod
def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
"""
Extracts the token from a streaming chunk.
Extracts the content and meta from a streaming chunk.
:param chunk: The streaming chunk.
:returns: A string token.
:param chunk: The streaming chunk as dict.
:returns: A StreamingChunk object.
"""


Expand Down Expand Up @@ -150,17 +152,17 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L

return [response_body["completion"]]

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
"""
Extracts the token from a streaming chunk.
Extracts the content and meta from a streaming chunk.
:param chunk: The streaming chunk.
:returns: A string token.
:param chunk: The streaming chunk as dict.
:returns: A StreamingChunk object.
"""
if self.use_messages_api:
return chunk.get("delta", {}).get("text", "")
return StreamingChunk(content=chunk.get("delta", {}).get("text", ""), meta=chunk)

return chunk.get("completion", "")
return StreamingChunk(content=chunk.get("completion", ""), meta=chunk)


class MistralAdapter(BedrockModelAdapter):
Expand Down Expand Up @@ -199,17 +201,18 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L
"""
return [output.get("text", "") for output in response_body.get("outputs", [])]

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
"""
Extracts the token from a streaming chunk.
Extracts the content and meta from a streaming chunk.
:param chunk: The streaming chunk.
:returns: A string token.
:param chunk: The streaming chunk as dict.
:returns: A StreamingChunk object.
"""
content = ""
chunk_list = chunk.get("outputs", [])
if chunk_list:
return chunk_list[0].get("text", "")
return ""
content = chunk_list[0].get("text", "")
return StreamingChunk(content=content, meta=chunk)


class CohereCommandAdapter(BedrockModelAdapter):
Expand Down Expand Up @@ -254,14 +257,14 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L
responses = [generation["text"] for generation in response_body["generations"]]
return responses

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
"""
Extracts the token from a streaming chunk.
Extracts the content and meta from a streaming chunk.
:param chunk: The streaming chunk.
:returns: A string token.
:param chunk: The streaming chunk as dict.
:returns: A StreamingChunk object.
"""
return chunk.get("text", "")
return StreamingChunk(content=chunk.get("text", ""), meta=chunk)


class CohereCommandRAdapter(BedrockModelAdapter):
Expand Down Expand Up @@ -313,15 +316,15 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L
responses = [response_body["text"]]
return responses

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
"""
Extracts the token from a streaming chunk.
Extracts the content and meta from a streaming chunk.
:param chunk: The streaming chunk.
:returns: A string token.
:param chunk: The streaming chunk as dict.
:returns: A StreamingChunk object.
"""
token: str = chunk.get("text", "")
return token
return StreamingChunk(content=token, meta=chunk)


class AI21LabsJurassic2Adapter(BedrockModelAdapter):
Expand Down Expand Up @@ -357,7 +360,7 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L
responses = [completion["data"]["text"] for completion in response_body["completions"]]
return responses

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
msg = "Streaming is not supported for AI21 Jurassic 2 models."
raise NotImplementedError(msg)

Expand Down Expand Up @@ -398,14 +401,14 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L
responses = [result["outputText"] for result in response_body["results"]]
return responses

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
"""
Extracts the token from a streaming chunk.
Extracts the content and meta from a streaming chunk.
:param chunk: The streaming chunk.
:returns: A string token.
:param chunk: The streaming chunk as dict.
:returns: A StreamingChunk object.
"""
return chunk.get("outputText", "")
return StreamingChunk(content=chunk.get("outputText", ""), meta=chunk)


class MetaLlamaAdapter(BedrockModelAdapter):
Expand Down Expand Up @@ -442,11 +445,11 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L
"""
return [response_body["generation"]]

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
"""
Extracts the token from a streaming chunk.
Extracts the content and meta from a streaming chunk.
:param chunk: The streaming chunk.
:returns: A string token.
:param chunk: The streaming chunk as dict.
:returns: A StreamingChunk object.
"""
return chunk.get("generation", "")
return StreamingChunk(content=chunk.get("generation", ""), meta=chunk)
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,18 @@ def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
return self._extract_messages_from_response(response_body)

def get_stream_responses(
self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None]
self, stream: EventStream, streaming_callback: Callable[[StreamingChunk], None]
) -> List[ChatMessage]:
tokens: List[str] = []
streaming_chunks: List[StreamingChunk] = []
last_decoded_chunk: Dict[str, Any] = {}
for event in stream:
chunk = event.get("chunk")
if chunk:
last_decoded_chunk = json.loads(chunk["bytes"].decode("utf-8"))
token = self._extract_token_from_stream(last_decoded_chunk)
stream_chunk = StreamingChunk(content=token) # don't extract meta, we care about tokens only
stream_handler(stream_chunk) # callback the stream handler with StreamingChunk
tokens.append(token)
responses = ["".join(tokens).lstrip()]
streaming_chunk = self._build_streaming_chunk(last_decoded_chunk)
streaming_callback(streaming_chunk) # callback the stream handler with StreamingChunk
streaming_chunks.append(streaming_chunk)
responses = ["".join(chunk.content for chunk in streaming_chunks).lstrip()]
return [ChatMessage.from_assistant(response, meta=last_decoded_chunk) for response in responses]

@staticmethod
Expand Down Expand Up @@ -142,12 +141,12 @@ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List
"""

@abstractmethod
def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
"""
Extracts the token from a streaming chunk.
Extracts the content and meta from a streaming chunk.
:param chunk: The streaming chunk.
:returns: The extracted token.
:param chunk: The streaming chunk as dict.
:returns: A StreamingChunk object.
"""


Expand Down Expand Up @@ -252,16 +251,16 @@ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List
messages.append(ChatMessage.from_assistant(content["text"], meta=meta))
return messages

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
"""
Extracts the token from a streaming chunk.
Extracts the content and meta from a streaming chunk.
:param chunk: The streaming chunk.
:returns: The extracted token.
:param chunk: The streaming chunk as dict.
:returns: A StreamingChunk object.
"""
if chunk.get("type") == "content_block_delta" and chunk.get("delta", {}).get("type") == "text_delta":
return chunk.get("delta", {}).get("text", "")
return ""
return StreamingChunk(content=chunk.get("delta", {}).get("text", ""), meta=chunk)
return StreamingChunk(content="", meta=chunk)

def _to_anthropic_message(self, m: ChatMessage) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -425,17 +424,17 @@ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List
messages.append(ChatMessage.from_assistant(response["text"], meta=meta))
return messages

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
"""
Extracts the token from a streaming chunk.
Extracts the content and meta from a streaming chunk.
:param chunk: The streaming chunk.
:returns: The extracted token.
:param chunk: The streaming chunk as dict.
:returns: A StreamingChunk object.
"""
response_chunk = chunk.get("outputs", [])
if response_chunk:
return response_chunk[0].get("text", "")
return ""
return StreamingChunk(content=response_chunk[0].get("text", ""), meta=chunk)
return StreamingChunk(content="", meta=chunk)


class MetaLlama2ChatAdapter(BedrockModelChatAdapter):
Expand Down Expand Up @@ -543,11 +542,11 @@ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List
metadata = {k: v for (k, v) in response_body.items() if k != message_tag}
return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)]

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
"""
Extracts the token from a streaming chunk.
Extracts the content and meta from a streaming chunk.
:param chunk: The streaming chunk.
:returns: The extracted token.
:param chunk: The streaming chunk as dict.
:returns: A StreamingChunk object.
"""
return chunk.get("generation", "")
return StreamingChunk(content=chunk.get("generation", ""), meta=chunk)
Original file line number Diff line number Diff line change
Expand Up @@ -156,18 +156,29 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
stacklevel=2,
)

def invoke(self, *args, **kwargs):
@component.output_types(replies=List[ChatMessage])
def run(
self,
messages: List[ChatMessage],
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Invokes the Amazon Bedrock LLM with the given parameters. The parameters are passed to the Amazon Bedrock
client.
Generates a list of `ChatMessage` response to the given messages using the Amazon Bedrock LLM.
:param args: The positional arguments passed to the generator.
:param kwargs: The keyword arguments passed to the generator.
:returns: List of `ChatMessage` generated by LLM.
:param messages: The messages to generate a response to.
:param streaming_callback:
A callback function that is called when a new token is received from the stream.
:param generation_kwargs: Additional generation keyword arguments passed to the model.
:returns: A dictionary with the following keys:
- `replies`: The generated List of `ChatMessage` objects.
"""
generation_kwargs = generation_kwargs or {}
generation_kwargs = generation_kwargs.copy()

streaming_callback = streaming_callback or self.streaming_callback
generation_kwargs["stream"] = streaming_callback is not None

kwargs = kwargs.copy()
messages: List[ChatMessage] = kwargs.pop("messages", [])
# check if the prompt is a list of ChatMessage objects
if not (
isinstance(messages, list)
Expand All @@ -177,39 +188,29 @@ def invoke(self, *args, **kwargs):
msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt."
raise ValueError(msg)

body = self.model_adapter.prepare_body(messages=messages, **{"stop_words": self.stop_words, **kwargs})
body = self.model_adapter.prepare_body(
messages=messages, **{"stop_words": self.stop_words, **generation_kwargs}
)
try:
if self.streaming_callback:
if streaming_callback:
response = self.client.invoke_model_with_response_stream(
body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json"
)
response_stream = response["body"]
responses = self.model_adapter.get_stream_responses(
stream=response_stream, stream_handler=self.streaming_callback
replies = self.model_adapter.get_stream_responses(
stream=response_stream, streaming_callback=streaming_callback
)
else:
response = self.client.invoke_model(
body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json"
)
response_body = json.loads(response.get("body").read().decode("utf-8"))
responses = self.model_adapter.get_responses(response_body=response_body)
replies = self.model_adapter.get_responses(response_body=response_body)
except ClientError as exception:
msg = f"Could not inference Amazon Bedrock model {self.model} due: {exception}"
raise AmazonBedrockInferenceError(msg) from exception

return responses

@component.output_types(replies=List[ChatMessage])
def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
"""
Generates a list of `ChatMessage` responses to the given messages using the Amazon Bedrock LLM.
:param messages: The messages to generate a response to.
:param generation_kwargs: Additional generation keyword arguments passed to the model.
:returns: A dictionary with the following keys:
- `replies`: The generated list of `ChatMessage` objects.
"""
return {"replies": self.invoke(messages=messages, **(generation_kwargs or {}))}
return {"replies": replies}

@classmethod
def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelChatAdapter]]:
Expand Down
Loading

0 comments on commit 38c9ee1

Please sign in to comment.