feat: AmazonBedrockChatGenerator - migrate Anthropic chat models to use messaging API #545

Merged: 15 commits, Mar 11, 2024
@@ -44,7 +44,7 @@ def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
:param response_body: The response body.
:returns: The extracted responses.
"""
return self._extract_messages_from_response(self.response_body_message_key(), response_body)
return self._extract_messages_from_response(response_body)

def get_stream_responses(self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None]) -> List[str]:
tokens: List[str] = []
@@ -53,11 +53,8 @@ def get_stream_responses(self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None]) -> List[str]:
if chunk:
decoded_chunk = json.loads(chunk["bytes"].decode("utf-8"))
token = self._extract_token_from_stream(decoded_chunk)
# take all the rest key/value pairs from the chunk, add them to the metadata
stream_metadata = {k: v for (k, v) in decoded_chunk.items() if v != token}
stream_chunk = StreamingChunk(content=token, meta=stream_metadata)
# callback the stream handler with StreamingChunk
stream_handler(stream_chunk)
stream_chunk = StreamingChunk(content=token) # don't extract meta, we care about tokens only
stream_handler(stream_chunk) # callback the stream handler with StreamingChunk
tokens.append(token)
responses = ["".join(tokens).lstrip()]
return responses
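
Reviewer note (not part of the diff): a minimal sketch of how the new token-only streaming path assembles the final response. The chunk dictionaries below are hypothetical examples of the Anthropic messages-API stream shape, and the import path assumes Haystack 2.x.

```python
from haystack.dataclasses import StreamingChunk

# Hypothetical pre-decoded stream chunks in the Anthropic messages-API shape
decoded_chunks = [
    {"type": "content_block_delta", "delta": {"type": "text_delta", "text": "Hello"}},
    {"type": "content_block_delta", "delta": {"type": "text_delta", "text": " world"}},
    {"type": "message_stop"},
]

def extract_token(chunk: dict) -> str:
    # mirrors _extract_token_from_stream: only text deltas carry a token
    if chunk.get("type") == "content_block_delta" and chunk.get("delta", {}).get("type") == "text_delta":
        return chunk.get("delta", {}).get("text", "")
    return ""

tokens = []
for decoded_chunk in decoded_chunks:
    token = extract_token(decoded_chunk)
    # stand-in for the user-supplied streaming callback
    print(StreamingChunk(content=token).content, end="")
    tokens.append(token)

responses = ["".join(tokens).lstrip()]  # -> ["Hello world"]
```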
@@ -124,25 +121,14 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
:returns: A dictionary containing the resized prompt and additional information.
"""

def _extract_messages_from_response(self, message_tag: str, response_body: Dict[str, Any]) -> List[ChatMessage]:
@abstractmethod
def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
"""
Extracts the messages from the response body.

:param message_tag: The key for the message in the response body.
:param response_body: The response body.
:returns: The extracted ChatMessage list.
"""
metadata = {k: v for (k, v) in response_body.items() if k != message_tag}
return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)]

@abstractmethod
def response_body_message_key(self) -> str:
"""
Returns the key for the message in the response body.
Subclasses should override this method to return the correct message key - where the response is located.

:returns: The key for the message in the response body.
"""

@abstractmethod
def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
@@ -183,7 +169,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]):
self.prompt_handler = DefaultPromptHandler(
tokenizer="gpt2",
model_max_length=model_max_length,
max_length=self.generation_kwargs.get("max_tokens_to_sample") or 512,
max_length=self.generation_kwargs.get("max_tokens") or 512,
)

def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
@@ -195,46 +181,33 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
:returns: The prepared body.
"""
default_params = {
"max_tokens_to_sample": self.generation_kwargs.get("max_tokens_to_sample") or 512,
"stop_sequences": ["\n\nHuman:"],
"anthropic_version": self.generation_kwargs.get("anthropic_version") or "bedrock-2023-05-31",
"max_tokens": self.generation_kwargs.get("max_tokens") or 512,
}

# combine stop words with default stop sequences, remove stop_words as Anthropic does not support it
stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", [])
if stop_sequences:
inference_kwargs["stop_sequences"] = stop_sequences
params = self._get_params(inference_kwargs, default_params)
body = {"prompt": self.prepare_chat_messages(messages=messages), **params}
body = {**self.prepare_chat_messages(messages=messages), **params}
return body

def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]:
"""
Prepares the chat messages for the Anthropic Claude request.

:param messages: The chat messages to prepare.
:returns: The prepared chat messages as a string.
"""
conversation = []
for index, message in enumerate(messages):
if message.is_from(ChatRole.USER):
conversation.append(f"{AnthropicClaudeChatAdapter.ANTHROPIC_USER_TOKEN} {message.content.strip()}")
elif message.is_from(ChatRole.ASSISTANT):
conversation.append(f"{AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN} {message.content.strip()}")
elif message.is_from(ChatRole.FUNCTION):
error_message = "Anthropic does not support function calls."
raise ValueError(error_message)
elif message.is_from(ChatRole.SYSTEM) and index == 0:
# Until we transition to the new chat message format system messages will be ignored
# see https://docs.anthropic.com/claude/reference/messages_post for more details
logger.warning(
"System messages are not fully supported by the current version of Claude and will be ignored."
)
else:
invalid_role = f"Invalid role {message.role} for message {message.content}"
raise ValueError(invalid_role)

prepared_prompt = "".join(conversation) + AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN + " "
return self._ensure_token_limit(prepared_prompt)
body: Dict[str, Any] = {}
system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None
body["messages"] = [
self._to_anthropic_message(m) for m in messages if m.is_from(ChatRole.USER) or m.is_from(ChatRole.ASSISTANT)
]
if system:
body["system"] = system
return body
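
Reviewer note (not part of the diff): a rough sketch of what the new `prepare_chat_messages` is expected to produce for a system plus user turn. The expected shape shown in the comments is an assumption based on the Anthropic messages API.

```python
from haystack.dataclasses import ChatMessage

messages = [
    ChatMessage.from_system("You are a helpful assistant."),
    ChatMessage.from_user("What's Natural Language Processing?"),
]

# adapter = AnthropicClaudeChatAdapter(generation_kwargs={})
# body = adapter.prepare_chat_messages(messages=messages)
# Expected shape (the leading system message is promoted to a top-level "system" field):
# {
#     "system": "You are a helpful assistant.",
#     "messages": [
#         {"role": "user", "content": [{"type": "text", "text": "What's Natural Language Processing?"}]},
#     ],
# }
```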

def check_prompt(self, prompt: str) -> Dict[str, Any]:
"""
@@ -245,13 +218,19 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
"""
return self.prompt_handler(prompt)

def response_body_message_key(self) -> str:
def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
"""
Returns the key for the message in the response body for Anthropic Claude i.e. "completion".
Extracts the messages from the response body.

:returns: The key for the message in the response body.
:param response_body: The response body.
:returns: The extracted ChatMessage list.
"""
return "completion"
messages: List[ChatMessage] = []
if response_body.get("type") == "message":
for content in response_body["content"]:
if content.get("type") == "text":
messages.append(ChatMessage.from_assistant(content["text"]))
return messages
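
Reviewer note (not part of the diff): a hypothetical non-streaming messages-API response body and the extraction it would produce under the new adapter; the field values are illustrative only.

```python
from haystack.dataclasses import ChatMessage

# Illustrative Anthropic messages-API response body
response_body = {
    "type": "message",
    "role": "assistant",
    "content": [{"type": "text", "text": "NLP is a field of AI focused on understanding human language."}],
    "stop_reason": "end_turn",
}

# adapter._extract_messages_from_response(response_body) would return:
expected = [ChatMessage.from_assistant("NLP is a field of AI focused on understanding human language.")]
```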

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
"""
@@ -260,7 +239,17 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
:param chunk: The streaming chunk.
:returns: The extracted token.
"""
return chunk.get("completion", "")
if chunk.get("type") == "content_block_delta" and chunk.get("delta", {}).get("type") == "text_delta":
return chunk.get("delta", {}).get("text", "")
return ""

def _to_anthropic_message(self, m: ChatMessage) -> Dict[str, Any]:
"""
Convert a ChatMessage to a dictionary with the content and role fields.

:param m: The ChatMessage to convert.
:returns: The dictionary with the content and role fields.
"""
return {"content": [{"type": "text", "text": m.content}], "role": m.role.value}


class MetaLlama2ChatAdapter(BedrockModelChatAdapter):
@@ -357,13 +346,16 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
"""
return self.prompt_handler(prompt)

def response_body_message_key(self) -> str:
def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
"""
Returns the key for the message in the response body for Meta Llama 2 i.e. "generation".
Extracts the messages from the response body.

:returns: The key for the message in the response body.
:param response_body: The response body.
:returns: The extracted ChatMessage list.
"""
return "generation"
message_tag = "generation"
metadata = {k: v for (k, v) in response_body.items() if k != message_tag}
return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)]
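
Reviewer note (not part of the diff): for comparison, a hypothetical Meta Llama 2 response body; with this adapter everything except the `generation` key ends up in the message metadata. The field names other than `generation` are assumptions.

```python
from haystack.dataclasses import ChatMessage

# Illustrative Meta Llama 2 response body (only "generation" is read as the message text)
response_body = {
    "generation": "NLP studies how computers process human language.",
    "prompt_token_count": 12,
    "generation_token_count": 9,
    "stop_reason": "stop",
}

expected = [
    ChatMessage.from_assistant(
        "NLP studies how computers process human language.",
        meta={"prompt_token_count": 12, "generation_token_count": 9, "stop_reason": "stop"},
    )
]
```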

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
"""
@@ -25,20 +25,21 @@ class AmazonBedrockChatGenerator:
"""
`AmazonBedrockChatGenerator` enables text generation via Amazon Bedrock hosted chat LLMs.

For example, to use the Anthropic Claude model, simply initialize the `AmazonBedrockChatGenerator` with the
'anthropic.claude-v2' model name.
For example, to use the Anthropic Claude 3 Sonnet model, simply initialize the `AmazonBedrockChatGenerator` with the
'anthropic.claude-3-sonnet-20240229-v1:0' model name.

```python
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.components.generators.utils import print_streaming_chunk

messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant, answer in German only"),
ChatMessage.from_user("What's Natural Language Processing?")]


client = AmazonBedrockChatGenerator(model="anthropic.claude-v2", streaming_callback=print_streaming_chunk)
client.run(messages, generation_kwargs={"max_tokens_to_sample": 512})
client = AmazonBedrockChatGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0",
streaming_callback=print_streaming_chunk)
client.run(messages, generation_kwargs={"max_tokens": 512})

```

@@ -154,7 +155,7 @@ def invoke(self, *args, **kwargs):
msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt."
raise ValueError(msg)

body = self.model_adapter.prepare_body(messages=messages, stop_words=self.stop_words, **kwargs)
body = self.model_adapter.prepare_body(messages=messages, **{"stop_words": self.stop_words, **kwargs})
try:
if self.streaming_callback:
response = self.client.invoke_model_with_response_stream(
26 changes: 8 additions & 18 deletions integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -146,9 +146,9 @@ def test_prepare_body_with_default_params(self) -> None:
layer = AnthropicClaudeChatAdapter(generation_kwargs={})
prompt = "Hello, how are you?"
expected_body = {
"prompt": "\n\nHuman: Hello, how are you?\n\nAssistant: ",
"max_tokens_to_sample": 512,
"stop_sequences": ["\n\nHuman:"],
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": 512,
"messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}], "role": "user"}],
}

body = layer.prepare_body([ChatMessage.from_user(prompt)])
@@ -159,12 +159,14 @@ def test_prepare_body_with_custom_inference_params(self) -> None:
layer = AnthropicClaudeChatAdapter(generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4})
prompt = "Hello, how are you?"
expected_body = {
"prompt": "\n\nHuman: Hello, how are you?\n\nAssistant: ",
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": 512,
"max_tokens_to_sample": 69,
"stop_sequences": ["\n\nHuman:", "CUSTOM_STOP"],
"messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}], "role": "user"}],
"stop_sequences": ["CUSTOM_STOP"],
"temperature": 0.7,
"top_p": 0.8,
"top_k": 5,
"top_p": 0.8,
}

body = layer.prepare_body(
@@ -173,18 +175,6 @@

assert body == expected_body

@pytest.mark.integration
def test_get_responses(self) -> None:
adapter = AnthropicClaudeChatAdapter(generation_kwargs={})
response_body = {"completion": "This is a single response."}
expected_response = "This is a single response."
response_message = adapter.get_responses(response_body)
# assert that the type of each item in the list is a ChatMessage
for message in response_message:
assert isinstance(message, ChatMessage)

assert response_message == [ChatMessage.from_assistant(expected_response)]


class TestMetaLlama2ChatAdapter:
@pytest.mark.integration