fix: support streaming_callback param in amazon bedrock generators #927

Merged · 4 commits · Aug 12, 2024
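This PR replaces the `TokenStreamingHandler`-based `stream_handler` argument with Haystack's standard `streaming_callback` (a `Callable[[StreamingChunk], None]`) in the Amazon Bedrock generators, and renames the adapters' `_extract_token_from_stream` to `_build_streaming_chunk`, which returns a `StreamingChunk` carrying the token as `content` and the raw decoded Bedrock event as `meta`. The diff touches three files: the generator's model adapters, the chat model adapters, and the chat generator component. A minimal usage sketch of the new parameter follows; the model id and the callback body are illustrative assumptions, not taken from this diff.

```python
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator

def print_chunk(chunk: StreamingChunk) -> None:
    # Each chunk carries the streamed token in `content`; after this PR,
    # the raw decoded Bedrock event is available in `meta`.
    print(chunk.content, end="", flush=True)

# The model id below is an illustrative assumption.
generator = AmazonBedrockChatGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0")
result = generator.run(
    messages=[ChatMessage.from_user("What is Amazon Bedrock?")],
    streaming_callback=print_chunk,
)
print(result["replies"][0].content)
```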
[File 1 of 3: model adapters for the Bedrock generator]
@@ -1,8 +1,8 @@
 import json
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
-from .handlers import TokenStreamingHandler
+from haystack.dataclasses import StreamingChunk
 
 
 class BedrockModelAdapter(ABC):
@@ -39,22 +39,24 @@ def get_responses(self, response_body: Dict[str, Any]) -> List[str]:
         responses = [completion.lstrip() for completion in completions]
         return responses
 
-    def get_stream_responses(self, stream, stream_handler: TokenStreamingHandler) -> List[str]:
+    def get_stream_responses(self, stream, streaming_callback: Callable[[StreamingChunk], None]) -> List[str]:
         """
         Extracts the responses from the Amazon Bedrock streaming response.
 
         :param stream: The streaming response from the Amazon Bedrock request.
-        :param stream_handler: The handler for the streaming response.
+        :param streaming_callback: The handler for the streaming response.
         :returns: A list of string responses.
         """
-        tokens: List[str] = []
+        streaming_chunks: List[StreamingChunk] = []
         for event in stream:
             chunk = event.get("chunk")
             if chunk:
                 decoded_chunk = json.loads(chunk["bytes"].decode("utf-8"))
-                token = self._extract_token_from_stream(decoded_chunk)
-                tokens.append(stream_handler(token, event_data=decoded_chunk))
-        responses = ["".join(tokens).lstrip()]
+                streaming_chunk: StreamingChunk = self._build_streaming_chunk(decoded_chunk)
+                streaming_chunks.append(streaming_chunk)
+                streaming_callback(streaming_chunk)
+
+        responses = ["".join(streaming_chunk.content for streaming_chunk in streaming_chunks).lstrip()]
         return responses
 
     def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any]) -> Dict[str, Any]:
@@ -84,12 +86,12 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
         """
 
     @abstractmethod
-    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
+    def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
         """
-        Extracts the token from a streaming chunk.
+        Extracts the content and meta from a streaming chunk.
 
-        :param chunk: The streaming chunk.
-        :returns: A string token.
+        :param chunk: The streaming chunk as dict.
+        :returns: A StreamingChunk object.
         """
 
 
@@ -150,17 +152,17 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
 
         return [response_body["completion"]]
 
-    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
+    def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
         """
-        Extracts the token from a streaming chunk.
+        Extracts the content and meta from a streaming chunk.
 
-        :param chunk: The streaming chunk.
-        :returns: A string token.
+        :param chunk: The streaming chunk as dict.
+        :returns: A StreamingChunk object.
         """
         if self.use_messages_api:
-            return chunk.get("delta", {}).get("text", "")
+            return StreamingChunk(content=chunk.get("delta", {}).get("text", ""), meta=chunk)
 
-        return chunk.get("completion", "")
+        return StreamingChunk(content=chunk.get("completion", ""), meta=chunk)
 
 
 class MistralAdapter(BedrockModelAdapter):
@@ -199,17 +201,18 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
         """
         return [output.get("text", "") for output in response_body.get("outputs", [])]
 
-    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
+    def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
         """
-        Extracts the token from a streaming chunk.
+        Extracts the content and meta from a streaming chunk.
 
-        :param chunk: The streaming chunk.
-        :returns: A string token.
+        :param chunk: The streaming chunk as dict.
+        :returns: A StreamingChunk object.
         """
+        content = ""
         chunk_list = chunk.get("outputs", [])
         if chunk_list:
-            return chunk_list[0].get("text", "")
-        return ""
+            content = chunk_list[0].get("text", "")
+        return StreamingChunk(content=content, meta=chunk)
 
 
 class CohereCommandAdapter(BedrockModelAdapter):
@@ -254,14 +257,14 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
         responses = [generation["text"] for generation in response_body["generations"]]
         return responses
 
-    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
+    def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
         """
-        Extracts the token from a streaming chunk.
+        Extracts the content and meta from a streaming chunk.
 
-        :param chunk: The streaming chunk.
-        :returns: A string token.
+        :param chunk: The streaming chunk as dict.
+        :returns: A StreamingChunk object.
         """
-        return chunk.get("text", "")
+        return StreamingChunk(content=chunk.get("text", ""), meta=chunk)
 
 
 class CohereCommandRAdapter(BedrockModelAdapter):
@@ -313,15 +316,15 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
         responses = [response_body["text"]]
         return responses
 
-    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
+    def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
         """
-        Extracts the token from a streaming chunk.
+        Extracts the content and meta from a streaming chunk.
 
-        :param chunk: The streaming chunk.
-        :returns: A string token.
+        :param chunk: The streaming chunk as dict.
+        :returns: A StreamingChunk object.
         """
         token: str = chunk.get("text", "")
-        return token
+        return StreamingChunk(content=token, meta=chunk)
 
 
 class AI21LabsJurassic2Adapter(BedrockModelAdapter):
@@ -357,7 +360,7 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
         responses = [completion["data"]["text"] for completion in response_body["completions"]]
         return responses
 
-    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
+    def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
         msg = "Streaming is not supported for AI21 Jurassic 2 models."
         raise NotImplementedError(msg)
 
@@ -398,14 +401,14 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
         responses = [result["outputText"] for result in response_body["results"]]
         return responses
 
-    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
+    def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
         """
-        Extracts the token from a streaming chunk.
+        Extracts the content and meta from a streaming chunk.
 
-        :param chunk: The streaming chunk.
-        :returns: A string token.
+        :param chunk: The streaming chunk as dict.
+        :returns: A StreamingChunk object.
         """
-        return chunk.get("outputText", "")
+        return StreamingChunk(content=chunk.get("outputText", ""), meta=chunk)
 
 
 class MetaLlamaAdapter(BedrockModelAdapter):
@@ -442,11 +445,11 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
         """
         return [response_body["generation"]]
 
-    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
+    def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
         """
-        Extracts the token from a streaming chunk.
+        Extracts the content and meta from a streaming chunk.
 
-        :param chunk: The streaming chunk.
-        :returns: A string token.
+        :param chunk: The streaming chunk as dict.
+        :returns: A StreamingChunk object.
         """
-        return chunk.get("generation", "")
+        return StreamingChunk(content=chunk.get("generation", ""), meta=chunk)
[File 2 of 3: model adapters for the Bedrock chat generator]
@@ -48,19 +48,18 @@ def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
         return self._extract_messages_from_response(response_body)
 
     def get_stream_responses(
-        self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None]
+        self, stream: EventStream, streaming_callback: Callable[[StreamingChunk], None]
     ) -> List[ChatMessage]:
-        tokens: List[str] = []
+        streaming_chunks: List[StreamingChunk] = []
         last_decoded_chunk: Dict[str, Any] = {}
         for event in stream:
             chunk = event.get("chunk")
             if chunk:
                 last_decoded_chunk = json.loads(chunk["bytes"].decode("utf-8"))
-                token = self._extract_token_from_stream(last_decoded_chunk)
-                stream_chunk = StreamingChunk(content=token)  # don't extract meta, we care about tokens only
-                stream_handler(stream_chunk)  # callback the stream handler with StreamingChunk
-                tokens.append(token)
-        responses = ["".join(tokens).lstrip()]
+                streaming_chunk = self._build_streaming_chunk(last_decoded_chunk)
+                streaming_callback(streaming_chunk)  # callback the stream handler with StreamingChunk
+                streaming_chunks.append(streaming_chunk)
+        responses = ["".join(chunk.content for chunk in streaming_chunks).lstrip()]
         return [ChatMessage.from_assistant(response, meta=last_decoded_chunk) for response in responses]
 
     @staticmethod
@@ -142,12 +141,12 @@ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
         """
 
     @abstractmethod
-    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
+    def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
         """
-        Extracts the token from a streaming chunk.
+        Extracts the content and meta from a streaming chunk.
 
-        :param chunk: The streaming chunk.
-        :returns: The extracted token.
+        :param chunk: The streaming chunk as dict.
+        :returns: A StreamingChunk object.
         """
 
 
@@ -252,16 +251,16 @@ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
             messages.append(ChatMessage.from_assistant(content["text"], meta=meta))
         return messages
 
-    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
+    def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
         """
-        Extracts the token from a streaming chunk.
+        Extracts the content and meta from a streaming chunk.
 
-        :param chunk: The streaming chunk.
-        :returns: The extracted token.
+        :param chunk: The streaming chunk as dict.
+        :returns: A StreamingChunk object.
         """
         if chunk.get("type") == "content_block_delta" and chunk.get("delta", {}).get("type") == "text_delta":
-            return chunk.get("delta", {}).get("text", "")
-        return ""
+            return StreamingChunk(content=chunk.get("delta", {}).get("text", ""), meta=chunk)
+        return StreamingChunk(content="", meta=chunk)
 
     def _to_anthropic_message(self, m: ChatMessage) -> Dict[str, Any]:
         """
@@ -425,17 +424,17 @@ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
             messages.append(ChatMessage.from_assistant(response["text"], meta=meta))
         return messages
 
-    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
+    def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
         """
-        Extracts the token from a streaming chunk.
+        Extracts the content and meta from a streaming chunk.
 
-        :param chunk: The streaming chunk.
-        :returns: The extracted token.
+        :param chunk: The streaming chunk as dict.
+        :returns: A StreamingChunk object.
         """
         response_chunk = chunk.get("outputs", [])
         if response_chunk:
-            return response_chunk[0].get("text", "")
-        return ""
+            return StreamingChunk(content=response_chunk[0].get("text", ""), meta=chunk)
+        return StreamingChunk(content="", meta=chunk)
 
 
 class MetaLlama2ChatAdapter(BedrockModelChatAdapter):
@@ -543,11 +542,11 @@ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
         metadata = {k: v for (k, v) in response_body.items() if k != message_tag}
         return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)]
 
-    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
+    def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
         """
-        Extracts the token from a streaming chunk.
+        Extracts the content and meta from a streaming chunk.
 
-        :param chunk: The streaming chunk.
-        :returns: The extracted token.
+        :param chunk: The streaming chunk as dict.
+        :returns: A StreamingChunk object.
         """
-        return chunk.get("generation", "")
+        return StreamingChunk(content=chunk.get("generation", ""), meta=chunk)
[File 3 of 3: the Bedrock chat generator component]
@@ -156,18 +156,29 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
                 stacklevel=2,
             )
 
-    def invoke(self, *args, **kwargs):
+    @component.output_types(replies=List[ChatMessage])
+    def run(
+        self,
+        messages: List[ChatMessage],
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+    ):
         """
-        Invokes the Amazon Bedrock LLM with the given parameters. The parameters are passed to the Amazon Bedrock
-        client.
+        Generates a list of `ChatMessage` responses to the given messages using the Amazon Bedrock LLM.
 
-        :param args: The positional arguments passed to the generator.
-        :param kwargs: The keyword arguments passed to the generator.
-        :returns: List of `ChatMessage` generated by LLM.
+        :param messages: The messages to generate a response to.
+        :param streaming_callback:
+            A callback function that is called when a new token is received from the stream.
+        :param generation_kwargs: Additional generation keyword arguments passed to the model.
+        :returns: A dictionary with the following keys:
+            - `replies`: The generated list of `ChatMessage` objects.
         """
+        generation_kwargs = generation_kwargs or {}
+        generation_kwargs = generation_kwargs.copy()
+
+        streaming_callback = streaming_callback or self.streaming_callback
+        generation_kwargs["stream"] = streaming_callback is not None
 
-        kwargs = kwargs.copy()
-        messages: List[ChatMessage] = kwargs.pop("messages", [])
         # check if the prompt is a list of ChatMessage objects
         if not (
             isinstance(messages, list)
@@ -177,39 +188,29 @@ def invoke(self, *args, **kwargs):
             msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt."
             raise ValueError(msg)
 
-        body = self.model_adapter.prepare_body(messages=messages, **{"stop_words": self.stop_words, **kwargs})
+        body = self.model_adapter.prepare_body(
+            messages=messages, **{"stop_words": self.stop_words, **generation_kwargs}
+        )
         try:
-            if self.streaming_callback:
+            if streaming_callback:
                 response = self.client.invoke_model_with_response_stream(
                     body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json"
                 )
                 response_stream = response["body"]
-                responses = self.model_adapter.get_stream_responses(
-                    stream=response_stream, stream_handler=self.streaming_callback
+                replies = self.model_adapter.get_stream_responses(
+                    stream=response_stream, streaming_callback=streaming_callback
                 )
             else:
                 response = self.client.invoke_model(
                     body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json"
                 )
                 response_body = json.loads(response.get("body").read().decode("utf-8"))
-                responses = self.model_adapter.get_responses(response_body=response_body)
+                replies = self.model_adapter.get_responses(response_body=response_body)
         except ClientError as exception:
             msg = f"Could not inference Amazon Bedrock model {self.model} due: {exception}"
             raise AmazonBedrockInferenceError(msg) from exception
 
-        return responses
-
-    @component.output_types(replies=List[ChatMessage])
-    def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
-        """
-        Generates a list of `ChatMessage` responses to the given messages using the Amazon Bedrock LLM.
-
-        :param messages: The messages to generate a response to.
-        :param generation_kwargs: Additional generation keyword arguments passed to the model.
-        :returns: A dictionary with the following keys:
-            - `replies`: The generated list of `ChatMessage` objects.
-        """
-        return {"replies": self.invoke(messages=messages, **(generation_kwargs or {}))}
+        return {"replies": replies}
 
     @classmethod
     def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelChatAdapter]]:
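Finally, note the precedence rule in the new `run`: `streaming_callback = streaming_callback or self.streaming_callback`, so a per-call callback overrides the one set at construction time, and streaming is enabled whenever either is present. A sketch of both paths, with an assumed model id and invented callback names:

```python
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator

def default_cb(chunk: StreamingChunk) -> None:
    print("[default]", chunk.content)

def per_call_cb(chunk: StreamingChunk) -> None:
    print("[override]", chunk.content)

# The model id is an illustrative assumption.
generator = AmazonBedrockChatGenerator(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    streaming_callback=default_cb,
)
messages = [ChatMessage.from_user("Hi there")]

generator.run(messages=messages)                                  # streams through default_cb
generator.run(messages=messages, streaming_callback=per_call_cb)  # per-call callback wins
```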