diff --git a/integrations/amazon_bedrock/pyproject.toml b/integrations/amazon_bedrock/pyproject.toml
index 616bb625d..2a4f88960 100644
--- a/integrations/amazon_bedrock/pyproject.toml
+++ b/integrations/amazon_bedrock/pyproject.toml
@@ -49,7 +49,7 @@ dependencies = [
   "haystack-pydoc-tools",
 ]
 [tool.hatch.envs.default.scripts]
-test = "pytest --reruns 0 --reruns-delay 30 -x {args:tests}"
+test = "pytest --reruns 3 --reruns-delay 30 -x {args:tests}"
 test-cov = "coverage run -m pytest --reruns 3 --reruns-delay 30 -x {args:tests}"
 cov-report = ["- coverage combine", "coverage report"]
 cov = ["test-cov", "cov-report"]
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
index cbb7a85f0..715ab7ebf 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
@@ -23,7 +23,7 @@ class BedrockModelChatAdapter(ABC):
 
     def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None:
         """
-        Initializes the chat adapter with the generation kwargs.
+        Initializes the chat adapter with the truncate parameter and generation kwargs.
         """
         self.generation_kwargs = generation_kwargs
         self.truncate = truncate
@@ -172,6 +172,7 @@ def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]):
         """
         Initializes the Anthropic Claude chat adapter.
 
+        :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit.
         :param generation_kwargs: The generation kwargs.
         """
         super().__init__(truncate, generation_kwargs)
@@ -218,7 +219,7 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]:
         Prepares the chat messages for the Anthropic Claude request.
 
         :param messages: The chat messages to prepare.
-        :returns: The prepared chat messages as a string.
+        :returns: The prepared chat messages as a dictionary.
         """
         body: Dict[str, Any] = {}
         system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None
@@ -227,6 +228,11 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]:
         ]
         if system:
             body["system"] = system
+        # Ensure token limit for each message in the body
+        if self.truncate:
+            for message in body["messages"]:
+                for content in message["content"]:
+                    content["text"] = self._ensure_token_limit(content["text"])
         return body
 
     def check_prompt(self, prompt: str) -> Dict[str, Any]:
         """
@@ -321,7 +327,7 @@ class MistralChatAdapter(BedrockModelChatAdapter):
     def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]):
         """
         Initializes the Mistral chat adapter.
-
+        :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit.
         :param generation_kwargs: The generation kwargs.
         """
         super().__init__(truncate, generation_kwargs)
@@ -477,6 +483,7 @@ class MetaLlama2ChatAdapter(BedrockModelChatAdapter):
     def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None:
         """
         Initializes the Meta Llama 2 chat adapter.
+        :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit.
         :param generation_kwargs: The generation kwargs.
""" super().__init__(truncate, generation_kwargs) @@ -523,7 +530,10 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template( conversation=messages, tokenize=False, chat_template=self.chat_template ) - return self._ensure_token_limit(prepared_prompt) + + if self.truncate: + prepared_prompt = self._ensure_token_limit(prepared_prompt) + return prepared_prompt def check_prompt(self, prompt: str) -> Dict[str, Any]: """ diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 8da43b03e..1f2e27c86 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -1,7 +1,7 @@ import logging import os from typing import Optional, Type -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from haystack.components.generators.utils import print_streaming_chunk @@ -141,6 +141,107 @@ def test_invoke_with_no_kwargs(mock_boto3_session): layer.invoke() +def test_short_prompt_is_not_truncated(mock_boto3_session): + """ + Test that a short prompt is not truncated + """ + # Define a short mock prompt and its tokenized version + mock_prompt_text = "I am a tokenized prompt" + mock_prompt_tokens = mock_prompt_text.split() + + # Mock the tokenizer so it returns our predefined tokens + mock_tokenizer = MagicMock() + mock_tokenizer.tokenize.return_value = mock_prompt_tokens + + # We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens + # Since our mock prompt is 5 tokens long, it doesn't exceed the + # total limit (5 prompt tokens + 3 generated tokens < 10 tokens) + max_length_generated_text = 3 + total_model_max_length = 10 + + with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): + layer = AmazonBedrockChatGenerator( + "anthropic.claude-v2", + generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text}, + ) + prompt_after_resize = layer.model_adapter._ensure_token_limit(mock_prompt_text) + + # The prompt doesn't exceed the limit, _ensure_token_limit doesn't truncate it + assert prompt_after_resize == mock_prompt_text + + +def test_long_prompt_is_truncated(mock_boto3_session): + """ + Test that a long prompt is truncated + """ + # Define a long mock prompt and its tokenized version + long_prompt_text = "I am a tokenized prompt of length eight" + long_prompt_tokens = long_prompt_text.split() + + # _ensure_token_limit will truncate the prompt to make it fit into the model's max token limit + truncated_prompt_text = "I am a tokenized prompt of length" + + # Mock the tokenizer to return our predefined tokens + # convert tokens to our predefined truncated text + mock_tokenizer = MagicMock() + mock_tokenizer.tokenize.return_value = long_prompt_tokens + mock_tokenizer.convert_tokens_to_string.return_value = truncated_prompt_text + + # We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens + # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens) + max_length_generated_text = 3 + total_model_max_length = 10 + + with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): + layer = AmazonBedrockChatGenerator( + "anthropic.claude-v2", + generation_kwargs={"model_max_length": total_model_max_length, 
"max_tokens": max_length_generated_text}, + ) + prompt_after_resize = layer.model_adapter._ensure_token_limit(long_prompt_text) + + # The prompt exceeds the limit, _ensure_token_limit truncates it + assert prompt_after_resize == truncated_prompt_text + + +def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): + """ + Test that a long prompt is not truncated and _ensure_token_limit is not called when truncate is set to False + """ + messages = [ChatMessage.from_system("I am a tokenized prompt of length eight")] + + # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens) + max_length_generated_text = 3 + total_model_max_length = 10 + + with patch("transformers.AutoTokenizer.from_pretrained", return_value=MagicMock()): + generator = AmazonBedrockChatGenerator( + model="anthropic.claude-v2", + truncate=False, + generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text}, + ) + + # Mock the _ensure_token_limit method to track if it is called + with patch.object( + generator.model_adapter, "_ensure_token_limit", wraps=generator.model_adapter._ensure_token_limit + ) as mock_ensure_token_limit: + # Mock the model adapter to avoid actual invocation + generator.model_adapter.prepare_body = MagicMock(return_value={}) + generator.client = MagicMock() + generator.client.invoke_model = MagicMock( + return_value={"body": MagicMock(read=MagicMock(return_value=b'{"generated_text": "response"}'))} + ) + generator.model_adapter.get_responses = MagicMock(return_value=["response"]) + + # Invoke the generator + generator.invoke(messages=messages) + + # Ensure _ensure_token_limit was not called + mock_ensure_token_limit.assert_not_called(), + + # Check the prompt passed to prepare_body + generator.model_adapter.prepare_body.assert_called_with(messages=messages, stop_words=[]) + + @pytest.mark.parametrize( "model, expected_model_adapter", [