From 9272da82aee14e346760ca6ec6bf8321015e110e Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Wed, 6 Mar 2024 10:48:09 +0100
Subject: [PATCH 01/15] Migrate Claude to messaging API

---
 .../amazon_bedrock/chat/adapters.py           | 100 ++++++++----------
 .../amazon_bedrock/chat/chat_generator.py     |  13 +--
 2 files changed, 53 insertions(+), 60 deletions(-)

diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
index 196a55743..6d8bb8e30 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
@@ -44,7 +44,7 @@ def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
         :param response_body: The response body.
         :returns: The extracted responses.
         """
-        return self._extract_messages_from_response(self.response_body_message_key(), response_body)
+        return self._extract_messages_from_response(response_body)
 
     def get_stream_responses(self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None]) -> List[str]:
         tokens: List[str] = []
@@ -53,11 +53,8 @@ def get_stream_responses(self, stream: EventStream, stream_handler: Callable[[St
             if chunk:
                 decoded_chunk = json.loads(chunk["bytes"].decode("utf-8"))
                 token = self._extract_token_from_stream(decoded_chunk)
-                # take all the rest key/value pairs from the chunk, add them to the metadata
-                stream_metadata = {k: v for (k, v) in decoded_chunk.items() if v != token}
-                stream_chunk = StreamingChunk(content=token, meta=stream_metadata)
-                # callback the stream handler with StreamingChunk
-                stream_handler(stream_chunk)
+                stream_chunk = StreamingChunk(content=token)  # don't extract meta, we care about tokens only
+                stream_handler(stream_chunk)  # callback the stream handler with StreamingChunk
             tokens.append(token)
         responses = ["".join(tokens).lstrip()]
         return responses
@@ -124,25 +121,14 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
 
         :returns: A dictionary containing the resized prompt and additional information.
         """
 
-    def _extract_messages_from_response(self, message_tag: str, response_body: Dict[str, Any]) -> List[ChatMessage]:
+    @abstractmethod
+    def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
         """
         Extracts the messages from the response body.
 
-        :param message_tag: The key for the message in the response body.
         :param response_body: The response body.
         :returns: The extracted ChatMessage list.
         """
-        metadata = {k: v for (k, v) in response_body.items() if k != message_tag}
-        return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)]
-
-    @abstractmethod
-    def response_body_message_key(self) -> str:
-        """
-        Returns the key for the message in the response body.
-        Subclasses should override this method to return the correct message key
-        where the response is located.
-
-        :returns: The key for the message in the response body.
- """ @abstractmethod def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: @@ -183,7 +169,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]): self.prompt_handler = DefaultPromptHandler( tokenizer="gpt2", model_max_length=model_max_length, - max_length=self.generation_kwargs.get("max_tokens_to_sample") or 512, + max_length=self.generation_kwargs.get("max_tokens") or 512, ) def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: @@ -195,8 +181,8 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[ :returns: The prepared body. """ default_params = { - "max_tokens_to_sample": self.generation_kwargs.get("max_tokens_to_sample") or 512, - "stop_sequences": ["\n\nHuman:"], + "anthropic_version": self.generation_kwargs.get("anthropic_version") or "bedrock-2023-05-31", + "max_tokens": self.generation_kwargs.get("max_tokens") or 512, } # combine stop words with default stop sequences, remove stop_words as Anthropic does not support it @@ -204,37 +190,24 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[ if stop_sequences: inference_kwargs["stop_sequences"] = stop_sequences params = self._get_params(inference_kwargs, default_params) - body = {"prompt": self.prepare_chat_messages(messages=messages), **params} + body = {**self.prepare_chat_messages(messages=messages), **params} return body - def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: + def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]: """ Prepares the chat messages for the Anthropic Claude request. :param messages: The chat messages to prepare. :returns: The prepared chat messages as a string. """ - conversation = [] - for index, message in enumerate(messages): - if message.is_from(ChatRole.USER): - conversation.append(f"{AnthropicClaudeChatAdapter.ANTHROPIC_USER_TOKEN} {message.content.strip()}") - elif message.is_from(ChatRole.ASSISTANT): - conversation.append(f"{AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN} {message.content.strip()}") - elif message.is_from(ChatRole.FUNCTION): - error_message = "Anthropic does not support function calls." - raise ValueError(error_message) - elif message.is_from(ChatRole.SYSTEM) and index == 0: - # Until we transition to the new chat message format system messages will be ignored - # see https://docs.anthropic.com/claude/reference/messages_post for more details - logger.warning( - "System messages are not fully supported by the current version of Claude and will be ignored." 
-                )
-            else:
-                invalid_role = f"Invalid role {message.role} for message {message.content}"
-                raise ValueError(invalid_role)
-
-        prepared_prompt = "".join(conversation) + AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN + " "
-        return self._ensure_token_limit(prepared_prompt)
+        body: Dict[str, Any] = {}
+        system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None
+        body["messages"] = [
+            self._to_anthropic_message(m) for m in messages if m.is_from(ChatRole.USER) or m.is_from(ChatRole.ASSISTANT)
+        ]
+        if system:
+            body["system"] = system
+        return body
 
     def check_prompt(self, prompt: str) -> Dict[str, Any]:
         """
@@ -245,13 +218,19 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
         """
         return self.prompt_handler(prompt)
 
-    def response_body_message_key(self) -> str:
+    def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
        """
-        Returns the key for the message in the response body for Anthropic Claude i.e. "completion".
+        Extracts the messages from the response body.
 
-        :returns: The key for the message in the response body.
+        :param response_body: The response body.
+        :return: The extracted ChatMessage list.
         """
-        return "completion"
+        messages: List[ChatMessage] = []
+        if response_body.get("type") == "message":
+            for content in response_body["content"]:
+                if content.get("type") == "text":
+                    messages.append(ChatMessage.from_assistant(content["text"]))
+        return messages
 
     def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
         """
@@ -260,7 +239,17 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
         :param chunk: The streaming chunk.
         :returns: The extracted token.
         """
-        return chunk.get("completion", "")
+        if chunk.get("type") == "content_block_delta" and chunk.get("delta", {}).get("type") == "text_delta":
+            return chunk.get("delta", {}).get("text", "")
+        return ""
+
+    def _to_anthropic_message(self, m: ChatMessage) -> Dict[str, Any]:
+        """
+        Convert a ChatMessage to a dictionary with the content and role fields.
+        :param m: The ChatMessage to convert.
+        :return: The dictionary with the content and role fields.
+        """
+        return {"content": [{"type": "text", "text": m.content}], "role": m.role.value}
 
 
 class MetaLlama2ChatAdapter(BedrockModelChatAdapter):
@@ -357,13 +346,16 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
         """
         return self.prompt_handler(prompt)
 
-    def response_body_message_key(self) -> str:
+    def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
         """
-        Returns the key for the message in the response body for Meta Llama 2 i.e. "generation".
+        Extracts the messages from the response body.
 
-        :returns: The key for the message in the response body.
+        :param response_body: The response body.
+        :return: The extracted ChatMessage list.
""" - return "generation" + message_tag = "generation" + metadata = {k: v for (k, v) in response_body.items() if k != message_tag} + return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)] def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: """ diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index bea6924f6..5279dc001 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -25,20 +25,21 @@ class AmazonBedrockChatGenerator: """ `AmazonBedrockChatGenerator` enables text generation via Amazon Bedrock hosted chat LLMs. - For example, to use the Anthropic Claude model, simply initialize the `AmazonBedrockChatGenerator` with the - 'anthropic.claude-v2' model name. + For example, to use the Anthropic Claude 3 Sonnet model, simply initialize the `AmazonBedrockChatGenerator` with the + 'anthropic.claude-3-sonnet-20240229-v1:0' model name. ```python from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator from haystack.dataclasses import ChatMessage from haystack.components.generators.utils import print_streaming_chunk - messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"), + messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant, answer in German only"), ChatMessage.from_user("What's Natural Language Processing?")] - client = AmazonBedrockChatGenerator(model="anthropic.claude-v2", streaming_callback=print_streaming_chunk) - client.run(messages, generation_kwargs={"max_tokens_to_sample": 512}) + client = AmazonBedrockChatGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0", + streaming_callback=print_streaming_chunk) + client.run(messages, generation_kwargs={"max_tokens": 512}) ``` @@ -154,7 +155,7 @@ def invoke(self, *args, **kwargs): msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt." raise ValueError(msg) - body = self.model_adapter.prepare_body(messages=messages, stop_words=self.stop_words, **kwargs) + body = self.model_adapter.prepare_body(messages=messages, **{"stop_words": self.stop_words, **kwargs}) try: if self.streaming_callback: response = self.client.invoke_model_with_response_stream( From feff4b4be468b2461aa1c907622a36a33b0f5817 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 6 Mar 2024 11:59:58 +0100 Subject: [PATCH 02/15] Adjust unit tests --- .../tests/test_chat_generator.py | 37 +++++++------------ 1 file changed, 13 insertions(+), 24 deletions(-) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 9ba4d5534..52449a85d 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -146,10 +146,10 @@ def test_prepare_body_with_default_params(self) -> None: layer = AnthropicClaudeChatAdapter(generation_kwargs={}) prompt = "Hello, how are you?" 
         expected_body = {
-            "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant: ",
-            "max_tokens_to_sample": 512,
-            "stop_sequences": ["\n\nHuman:"],
-        }
+            'anthropic_version': 'bedrock-2023-05-31',
+            'max_tokens': 512,
+            'messages': [{'content': [{'text': 'Hello, how are you?', 'type': 'text'}],
+                          'role': 'user'}]}
 
         body = layer.prepare_body([ChatMessage.from_user(prompt)])
 
@@ -158,14 +158,15 @@ def test_prepare_body_with_custom_inference_params(self) -> None:
         layer = AnthropicClaudeChatAdapter(generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4})
         prompt = "Hello, how are you?"
-        expected_body = {
-            "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant: ",
-            "max_tokens_to_sample": 69,
-            "stop_sequences": ["\n\nHuman:", "CUSTOM_STOP"],
-            "temperature": 0.7,
-            "top_p": 0.8,
-            "top_k": 5,
-        }
+        expected_body = {'anthropic_version': 'bedrock-2023-05-31',
+                         'max_tokens': 512,
+                         'max_tokens_to_sample': 69,
+                         'messages': [{'content': [{'text': 'Hello, how are you?', 'type': 'text'}],
+                                       'role': 'user'}],
+                         'stop_sequences': ['CUSTOM_STOP'],
+                         'temperature': 0.7,
+                         'top_k': 5,
+                         'top_p': 0.8}
 
         body = layer.prepare_body(
             [ChatMessage.from_user(prompt)], top_p=0.8, top_k=5, max_tokens_to_sample=69, stop_sequences=["CUSTOM_STOP"]
         )
 
         assert body == expected_body
 
-    @pytest.mark.integration
-    def test_get_responses(self) -> None:
-        adapter = AnthropicClaudeChatAdapter(generation_kwargs={})
-        response_body = {"completion": "This is a single response."}
-        expected_response = "This is a single response."
-        response_message = adapter.get_responses(response_body)
-        # assert that the type of each item in the list is a ChatMessage
-        for message in response_message:
-            assert isinstance(message, ChatMessage)
-
-        assert response_message == [ChatMessage.from_assistant(expected_response)]
-
 
 class TestMetaLlama2ChatAdapter:
     @pytest.mark.integration

From 8cb1d4bedaafabeea2fecedb4702c366c2b3aae1 Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Wed, 6 Mar 2024 12:05:29 +0100
Subject: [PATCH 03/15] Replace single quotes with double

---
 .../tests/test_chat_generator.py              | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py
index 52449a85d..c46459a7f 100644
--- a/integrations/amazon_bedrock/tests/test_chat_generator.py
+++ b/integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -146,10 +146,10 @@ def test_prepare_body_with_default_params(self) -> None:
         layer = AnthropicClaudeChatAdapter(generation_kwargs={})
         prompt = "Hello, how are you?"
         expected_body = {
-            'anthropic_version': 'bedrock-2023-05-31',
-            'max_tokens': 512,
-            'messages': [{'content': [{'text': 'Hello, how are you?', 'type': 'text'}],
-                          'role': 'user'}]}
+            "anthropic_version": "bedrock-2023-05-31",
+            "max_tokens": 512,
+            "messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}],
+                          "role": "user"}]}
 
         body = layer.prepare_body([ChatMessage.from_user(prompt)])
 
@@ -158,15 +158,15 @@ def test_prepare_body_with_custom_inference_params(self) -> None:
         layer = AnthropicClaudeChatAdapter(generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4})
         prompt = "Hello, how are you?"
-        expected_body = {'anthropic_version': 'bedrock-2023-05-31',
-                         'max_tokens': 512,
-                         'max_tokens_to_sample': 69,
-                         'messages': [{'content': [{'text': 'Hello, how are you?', 'type': 'text'}],
-                                       'role': 'user'}],
-                         'stop_sequences': ['CUSTOM_STOP'],
-                         'temperature': 0.7,
-                         'top_k': 5,
-                         'top_p': 0.8}
+        expected_body = {"anthropic_version": "bedrock-2023-05-31",
+                         "max_tokens": 512,
+                         "max_tokens_to_sample": 69,
+                         "messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}],
+                                       "role": "user"}],
+                         "stop_sequences": ["CUSTOM_STOP"],
+                         "temperature": 0.7,
+                         "top_k": 5,
+                         "top_p": 0.8}
 
         body = layer.prepare_body(
             [ChatMessage.from_user(prompt)], top_p=0.8, top_k=5, max_tokens_to_sample=69, stop_sequences=["CUSTOM_STOP"]

From 112199ef71c2df42aa66395de4c0e1b5d33dcf9c Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Wed, 6 Mar 2024 12:15:53 +0100
Subject: [PATCH 04/15] pylint tests

---
 .../tests/test_chat_generator.py              | 23 ++++++++++---------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py
index c46459a7f..196ddb36c 100644
--- a/integrations/amazon_bedrock/tests/test_chat_generator.py
+++ b/integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -148,8 +148,8 @@ def test_prepare_body_with_default_params(self) -> None:
         expected_body = {
             "anthropic_version": "bedrock-2023-05-31",
             "max_tokens": 512,
-            "messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}],
-                          "role": "user"}]}
+            "messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}], "role": "user"}],
+        }
 
         body = layer.prepare_body([ChatMessage.from_user(prompt)])
 
@@ -158,15 +158,16 @@ def test_prepare_body_with_custom_inference_params(self) -> None:
         layer = AnthropicClaudeChatAdapter(generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4})
         prompt = "Hello, how are you?"
- expected_body = {"anthropic_version": "bedrock-2023-05-31", - "max_tokens": 512, - "max_tokens_to_sample": 69, - "messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}], - "role": "user"}], - "stop_sequences": ["CUSTOM_STOP"], - "temperature": 0.7, - "top_k": 5, - "top_p": 0.8} + expected_body = { + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": 512, + "max_tokens_to_sample": 69, + "messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}], "role": "user"}], + "stop_sequences": ["CUSTOM_STOP"], + "temperature": 0.7, + "top_k": 5, + "top_p": 0.8, + } body = layer.prepare_body( [ChatMessage.from_user(prompt)], top_p=0.8, top_k=5, max_tokens_to_sample=69, stop_sequences=["CUSTOM_STOP"] From d56f47285edb9b993e630969c3c7c6c2a760c432 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 7 Mar 2024 14:51:55 +0100 Subject: [PATCH 05/15] Add first live integration test --- .../tests/test_chat_generator.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 196ddb36c..04e3cebd9 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -2,7 +2,7 @@ import pytest from haystack.components.generators.utils import print_streaming_chunk -from haystack.dataclasses import ChatMessage +from haystack.dataclasses import ChatMessage, ChatRole from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import ( @@ -228,3 +228,25 @@ def test_get_responses(self) -> None: assert isinstance(message, ChatMessage) assert response_message == [ChatMessage.from_assistant(expected_response)] + + @pytest.mark.parametrize("model_name", [ + "anthropic.claude-3-sonnet-20240229-v1:0", + "anthropic.claude-v2:1", + "meta.llama2-13b-chat-v1" + ]) + @pytest.mark.integration + def test_default_inference_params(self, model_name): + messages = [ + ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."), + ChatMessage.from_user("What's the capital of France?"), + ] + + client = AmazonBedrockChatGenerator(model=model_name) + response = client.run(messages) + assert response["replies"] + assert isinstance(response["replies"], list) + assert len(response["replies"]) > 0 + assert isinstance(response["replies"][0], ChatMessage) + assert response["replies"][0].content + assert ChatMessage.is_from(response["replies"][0], ChatRole.ASSISTANT) + assert "paris" in response["replies"][0].content.lower() \ No newline at end of file From 21f17ea9d9b83ede50f48ce0daa39683210a9f85 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 7 Mar 2024 15:09:01 +0100 Subject: [PATCH 06/15] Fix pylint --- .../amazon_bedrock/tests/test_chat_generator.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 04e3cebd9..0ad78b74b 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -229,11 +229,9 @@ def test_get_responses(self) -> None: assert response_message == [ChatMessage.from_assistant(expected_response)] - @pytest.mark.parametrize("model_name", [ - "anthropic.claude-3-sonnet-20240229-v1:0", 
- "anthropic.claude-v2:1", - "meta.llama2-13b-chat-v1" - ]) + @pytest.mark.parametrize( + "model_name", ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"] + ) @pytest.mark.integration def test_default_inference_params(self, model_name): messages = [ @@ -249,4 +247,4 @@ def test_default_inference_params(self, model_name): assert isinstance(response["replies"][0], ChatMessage) assert response["replies"][0].content assert ChatMessage.is_from(response["replies"][0], ChatRole.ASSISTANT) - assert "paris" in response["replies"][0].content.lower() \ No newline at end of file + assert "paris" in response["replies"][0].content.lower() From 67e9e5d13751ad0ae1cdf790c0c37347477f4728 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 8 Mar 2024 12:28:33 +0100 Subject: [PATCH 07/15] Add streaming callback test --- .../tests/test_chat_generator.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 0ad78b74b..b6b358a7e 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -248,3 +248,24 @@ def test_default_inference_params(self, model_name): assert response["replies"][0].content assert ChatMessage.is_from(response["replies"][0], ChatRole.ASSISTANT) assert "paris" in response["replies"][0].content.lower() + + @pytest.mark.parametrize( + "model_name", ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"] + ) + @pytest.mark.integration + def test_default_inference_with_streaming(self, model_name): + + callback_called = False + + def streaming_callback_verifier(chunk): + nonlocal callback_called + callback_called = True + + messages = [ + ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."), + ChatMessage.from_user("What's the capital of France?"), + ] + + client = AmazonBedrockChatGenerator(model=model_name, streaming_callback=streaming_callback_verifier) + client.run(messages) + assert callback_called From 4035d50b5f628f09f16a817c3ef64fbbda749852 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 8 Mar 2024 15:54:03 +0100 Subject: [PATCH 08/15] Add AWS credentials --- .github/workflows/amazon_bedrock.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/amazon_bedrock.yml b/.github/workflows/amazon_bedrock.yml index 75f881a50..dd03db545 100644 --- a/.github/workflows/amazon_bedrock.yml +++ b/.github/workflows/amazon_bedrock.yml @@ -21,6 +21,7 @@ concurrency: env: PYTHONUNBUFFERED: "1" FORCE_COLOR: "1" + AWS_REGION: eu-central-1 jobs: run: @@ -56,5 +57,11 @@ jobs: if: matrix.python-version == '3.9' && runner.os == 'Linux' run: hatch run docs + - name: AWS authentication + uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 + with: + aws-region: ${{ env.AWS_REGION }} + role-to-assume: ${{ secrets.AWS_CI_ROLE_ARN }} + - name: Run tests run: hatch run cov From d6cd42665d474fac279c0717d5aeb22c33f70923 Mon Sep 17 00:00:00 2001 From: Paul Steppacher Date: Fri, 8 Mar 2024 15:11:52 +0000 Subject: [PATCH 09/15] add permissions --- .github/workflows/amazon_bedrock.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/amazon_bedrock.yml b/.github/workflows/amazon_bedrock.yml index dd03db545..34f1b2a8c 100644 --- a/.github/workflows/amazon_bedrock.yml +++ 
b/.github/workflows/amazon_bedrock.yml @@ -18,6 +18,10 @@ concurrency: group: amazon-bedrock-${{ github.head_ref }} cancel-in-progress: true +permissions: + id-token: write + contents: read + env: PYTHONUNBUFFERED: "1" FORCE_COLOR: "1" From b97c79321223b99e171cb6a976c30e0fe1fcc530 Mon Sep 17 00:00:00 2001 From: Paul Steppacher Date: Fri, 8 Mar 2024 15:17:48 +0000 Subject: [PATCH 10/15] adapt region for bedrock tests --- .github/workflows/amazon_bedrock.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/amazon_bedrock.yml b/.github/workflows/amazon_bedrock.yml index 34f1b2a8c..8b1651764 100644 --- a/.github/workflows/amazon_bedrock.yml +++ b/.github/workflows/amazon_bedrock.yml @@ -25,7 +25,7 @@ permissions: env: PYTHONUNBUFFERED: "1" FORCE_COLOR: "1" - AWS_REGION: eu-central-1 + AWS_REGION: us-east-1 jobs: run: From 88171149ee764a54ba3c589306d92e35db08c7d6 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 8 Mar 2024 17:18:39 +0100 Subject: [PATCH 11/15] Add meta for claude models --- .../components/generators/amazon_bedrock/chat/adapters.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py index 6d8bb8e30..eaa064a3b 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py @@ -229,7 +229,8 @@ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List if response_body.get("type") == "message": for content in response_body["content"]: if content.get("type") == "text": - messages.append(ChatMessage.from_assistant(content["text"])) + meta = {k: v for k, v in response_body.items() if k not in ["type", "content", "role"]} + messages.append(ChatMessage.from_assistant(content["text"], meta=meta)) return messages def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: From 16f08b657f05a27a180d496d4c6c1beb1f30309f Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 8 Mar 2024 17:22:28 +0100 Subject: [PATCH 12/15] Assert non-empty meta --- integrations/amazon_bedrock/tests/test_chat_generator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index b6b358a7e..8eda7da39 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -248,6 +248,7 @@ def test_default_inference_params(self, model_name): assert response["replies"][0].content assert ChatMessage.is_from(response["replies"][0], ChatRole.ASSISTANT) assert "paris" in response["replies"][0].content.lower() + assert len(response["replies"][0].meta) > 0 @pytest.mark.parametrize( "model_name", ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"] From f07d768e8b6630725b2fbf326a7526a62987a2f7 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 8 Mar 2024 17:26:57 +0100 Subject: [PATCH 13/15] Cosmetics --- .../amazon_bedrock/tests/test_chat_generator.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py 
b/integrations/amazon_bedrock/tests/test_chat_generator.py index 8eda7da39..b40cf09f1 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -11,7 +11,8 @@ MetaLlama2ChatAdapter, ) -clazz = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" +KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" +MODELS_TO_TEST = ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"] def test_to_dict(mock_boto3_session): @@ -24,7 +25,7 @@ def test_to_dict(mock_boto3_session): streaming_callback=print_streaming_chunk, ) expected_dict = { - "type": clazz, + "type": KLASS, "init_parameters": { "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, @@ -47,7 +48,7 @@ def test_from_dict(mock_boto3_session): """ generator = AmazonBedrockChatGenerator.from_dict( { - "type": clazz, + "type": KLASS, "init_parameters": { "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, @@ -229,9 +230,7 @@ def test_get_responses(self) -> None: assert response_message == [ChatMessage.from_assistant(expected_response)] - @pytest.mark.parametrize( - "model_name", ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"] - ) + @pytest.mark.parametrize("model_name", MODELS_TO_TEST) @pytest.mark.integration def test_default_inference_params(self, model_name): messages = [ @@ -248,11 +247,11 @@ def test_default_inference_params(self, model_name): assert response["replies"][0].content assert ChatMessage.is_from(response["replies"][0], ChatRole.ASSISTANT) assert "paris" in response["replies"][0].content.lower() + + # validate meta assert len(response["replies"][0].meta) > 0 - @pytest.mark.parametrize( - "model_name", ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"] - ) + @pytest.mark.parametrize("model_name", MODELS_TO_TEST) @pytest.mark.integration def test_default_inference_with_streaming(self, model_name): From 2e16d68b721cce4e2b70e16ae3e3ba60dc48e8a2 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Sat, 9 Mar 2024 09:43:38 +0100 Subject: [PATCH 14/15] Further improve tests, align streaming response contracts to other chat generators --- .../amazon_bedrock/chat/adapters.py | 11 ++- .../tests/test_chat_generator.py | 84 +++++++++++-------- 2 files changed, 58 insertions(+), 37 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py index eaa064a3b..36fc3ab48 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py @@ -46,18 +46,21 @@ def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]: """ return self._extract_messages_from_response(response_body) - def get_stream_responses(self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None]) 
-> List[str]: + def get_stream_responses( + self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None] + ) -> List[ChatMessage]: tokens: List[str] = [] + last_decoded_chunk: Dict[str, Any] = {} for event in stream: chunk = event.get("chunk") if chunk: - decoded_chunk = json.loads(chunk["bytes"].decode("utf-8")) - token = self._extract_token_from_stream(decoded_chunk) + last_decoded_chunk = json.loads(chunk["bytes"].decode("utf-8")) + token = self._extract_token_from_stream(last_decoded_chunk) stream_chunk = StreamingChunk(content=token) # don't extract meta, we care about tokens only stream_handler(stream_chunk) # callback the stream handler with StreamingChunk tokens.append(token) responses = ["".join(tokens).lstrip()] - return responses + return [ChatMessage.from_assistant(response, meta=last_decoded_chunk) for response in responses] @staticmethod def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any]) -> None: diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index b40cf09f1..de21ea76e 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -2,7 +2,7 @@ import pytest from haystack.components.generators.utils import print_streaming_chunk -from haystack.dataclasses import ChatMessage, ChatRole +from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import ( @@ -177,6 +177,15 @@ def test_prepare_body_with_custom_inference_params(self) -> None: assert body == expected_body +@pytest.fixture +def chat_messages(): + messages = [ + ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."), + ChatMessage.from_user("What's the capital of France?"), + ] + return messages + + class TestMetaLlama2ChatAdapter: @pytest.mark.integration def test_prepare_body_with_default_params(self) -> None: @@ -232,40 +241,49 @@ def test_get_responses(self) -> None: @pytest.mark.parametrize("model_name", MODELS_TO_TEST) @pytest.mark.integration - def test_default_inference_params(self, model_name): - messages = [ - ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."), - ChatMessage.from_user("What's the capital of France?"), - ] + def test_default_inference_params(self, model_name, chat_messages): client = AmazonBedrockChatGenerator(model=model_name) - response = client.run(messages) - assert response["replies"] - assert isinstance(response["replies"], list) - assert len(response["replies"]) > 0 - assert isinstance(response["replies"][0], ChatMessage) - assert response["replies"][0].content - assert ChatMessage.is_from(response["replies"][0], ChatRole.ASSISTANT) - assert "paris" in response["replies"][0].content.lower() - - # validate meta - assert len(response["replies"][0].meta) > 0 - - @pytest.mark.parametrize("model_name", MODELS_TO_TEST) - @pytest.mark.integration - def test_default_inference_with_streaming(self, model_name): - - callback_called = False + response = client.run(chat_messages) - def streaming_callback_verifier(chunk): - nonlocal callback_called - callback_called = True + assert "replies" in response, "Response does not contain 'replies' key" + replies = response["replies"] + assert isinstance(replies, list), "Replies is not a list" + assert 
len(replies) > 0, "No replies received" - messages = [ - ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."), - ChatMessage.from_user("What's the capital of France?"), - ] + first_reply = replies[0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert first_reply.content, "First reply has no content" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" + assert first_reply.meta, "First reply has no metadata" - client = AmazonBedrockChatGenerator(model=model_name, streaming_callback=streaming_callback_verifier) - client.run(messages) - assert callback_called + @pytest.mark.parametrize("model_name", MODELS_TO_TEST) + @pytest.mark.integration + def test_default_inference_with_streaming(self, model_name, chat_messages): + streaming_callback_called = False + paris_found_in_response = False + + def streaming_callback(chunk: StreamingChunk): + nonlocal streaming_callback_called, paris_found_in_response + streaming_callback_called = True + assert isinstance(chunk, StreamingChunk) + assert chunk.content is not None + if not paris_found_in_response: + paris_found_in_response = "paris" in chunk.content.lower() + + client = AmazonBedrockChatGenerator(model=model_name, streaming_callback=streaming_callback) + response = client.run(chat_messages) + + assert streaming_callback_called, "Streaming callback was not called" + assert paris_found_in_response, "The streaming callback response did not contain 'paris'" + replies = response["replies"] + assert isinstance(replies, list), "Replies is not a list" + assert len(replies) > 0, "No replies received" + + first_reply = replies[0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert first_reply.content, "First reply has no content" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" + assert first_reply.meta, "First reply has no metadata" From 8683af04bfbff32fc1f02b3bb2c3ecc5bcc8f545 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 11 Mar 2024 11:08:55 +0100 Subject: [PATCH 15/15] Track allowed model params, log warning --- .../amazon_bedrock/chat/adapters.py | 43 +++++++++++++------ .../tests/test_chat_generator.py | 5 +-- 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py index 36fc3ab48..cdb871f40 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py @@ -1,7 +1,7 @@ import json import logging from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List +from typing import Any, Callable, ClassVar, Dict, List from botocore.eventstream import EventStream from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk @@ -63,14 +63,18 @@ def get_stream_responses( return [ChatMessage.from_assistant(response, meta=last_decoded_chunk) for response in responses] @staticmethod - def _update_params(target_dict: 
Dict[str, Any], updates_dict: Dict[str, Any]) -> None: + def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any], allowed_params: List[str]) -> None: """ Updates target_dict with values from updates_dict. Merges lists instead of overriding them. :param target_dict: The dictionary to update. :param updates_dict: The dictionary with updates. + :param allowed_params: The list of allowed params to use. """ for key, value in updates_dict.items(): + if key not in allowed_params: + logger.warning(f"Parameter '{key}' is not allowed and will be ignored.") + continue if key in target_dict and isinstance(target_dict[key], list) and isinstance(value, list): # Merge lists and remove duplicates target_dict[key] = sorted(set(target_dict[key] + value)) @@ -78,21 +82,24 @@ def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any]) -> # Override the value in target_dict target_dict[key] = value - def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any]) -> Dict[str, Any]: + def _get_params( + self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any], allowed_params: List[str] + ) -> Dict[str, Any]: """ Merges params from inference_kwargs with the default params and self.generation_kwargs. Uses a helper function to merge lists or override values as necessary. :param inference_kwargs: The inference kwargs to merge. :param default_params: The default params to start with. + :param allowed_params: The list of allowed params to use. :returns: The merged params. """ # Start with a copy of default_params kwargs = default_params.copy() # Update the default params with self.generation_kwargs and finally inference_kwargs - self._update_params(kwargs, self.generation_kwargs) - self._update_params(kwargs, inference_kwargs) + self._update_params(kwargs, self.generation_kwargs, allowed_params) + self._update_params(kwargs, inference_kwargs, allowed_params) return kwargs @@ -148,8 +155,16 @@ class AnthropicClaudeChatAdapter(BedrockModelChatAdapter): Model adapter for the Anthropic Claude chat model. """ - ANTHROPIC_USER_TOKEN = "\n\nHuman:" - ANTHROPIC_ASSISTANT_TOKEN = "\n\nAssistant:" + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html + ALLOWED_PARAMS: ClassVar[List[str]] = [ + "anthropic_version", + "max_tokens", + "stop_sequences", + "temperature", + "top_p", + "top_k", + "system", + ] def __init__(self, generation_kwargs: Dict[str, Any]): """ @@ -185,14 +200,14 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[ """ default_params = { "anthropic_version": self.generation_kwargs.get("anthropic_version") or "bedrock-2023-05-31", - "max_tokens": self.generation_kwargs.get("max_tokens") or 512, + "max_tokens": self.generation_kwargs.get("max_tokens") or 512, # max_tokens is required } # combine stop words with default stop sequences, remove stop_words as Anthropic does not support it stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", []) if stop_sequences: inference_kwargs["stop_sequences"] = stop_sequences - params = self._get_params(inference_kwargs, default_params) + params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) body = {**self.prepare_chat_messages(messages=messages), **params} return body @@ -261,6 +276,9 @@ class MetaLlama2ChatAdapter(BedrockModelChatAdapter): Model adapter for the Meta Llama 2 models. 
""" + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html + ALLOWED_PARAMS: ClassVar[List[str]] = ["max_gen_len", "temperature", "top_p"] + chat_template = ( "{% if messages[0]['role'] == 'system' %}" "{% set loop_messages = messages[1:] %}" @@ -320,11 +338,8 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[ """ default_params = {"max_gen_len": self.generation_kwargs.get("max_gen_len") or 512} - # combine stop words with default stop sequences, remove stop_words as MetaLlama2 does not support it - stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", []) - if stop_sequences: - inference_kwargs["stop_sequences"] = stop_sequences - params = self._get_params(inference_kwargs, default_params) + # no support for stop words in Meta Llama 2 + params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) body = {"prompt": self.prepare_chat_messages(messages=messages), **params} return body diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index de21ea76e..6e0356d42 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -162,7 +162,6 @@ def test_prepare_body_with_custom_inference_params(self) -> None: expected_body = { "anthropic_version": "bedrock-2023-05-31", "max_tokens": 512, - "max_tokens_to_sample": 69, "messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}], "role": "user"}], "stop_sequences": ["CUSTOM_STOP"], "temperature": 0.7, @@ -207,13 +206,13 @@ def test_prepare_body_with_custom_inference_params(self) -> None: generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 5, "stop_sequences": ["CUSTOM_STOP"]} ) prompt = "Hello, how are you?" + + # expected body is different because stop_sequences and top_k are not supported by MetaLlama2 expected_body = { "prompt": "[INST] Hello, how are you? [/INST]", "max_gen_len": 69, - "stop_sequences": ["CUSTOM_STOP"], "temperature": 0.7, "top_p": 0.8, - "top_k": 5, } body = layer.prepare_body(