Added tests to check truncation
Amna Mubashar authored and committed Aug 8, 2024
1 parent b98359d commit 80cb9d0
Showing 3 changed files with 117 additions and 6 deletions.
2 changes: 1 addition & 1 deletion integrations/amazon_bedrock/pyproject.toml
@@ -49,7 +49,7 @@ dependencies = [
"haystack-pydoc-tools",
]
[tool.hatch.envs.default.scripts]
test = "pytest --reruns 0 --reruns-delay 30 -x {args:tests}"
test = "pytest --reruns 3 --reruns-delay 30 -x {args:tests}"
test-cov = "coverage run -m pytest --reruns 3 --reruns-delay 30 -x {args:tests}"
cov-report = ["- coverage combine", "coverage report"]
cov = ["test-cov", "cov-report"]
@@ -23,7 +23,7 @@ class BedrockModelChatAdapter(ABC):

def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None:
"""
Initializes the chat adapter with the generation kwargs.
Initializes the chat adapter with the truncate parameter and generation kwargs.
"""
self.generation_kwargs = generation_kwargs
self.truncate = truncate
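
For orientation, here is a minimal sketch of what a truncate-aware token-limit helper such as _ensure_token_limit might do, assuming a Hugging Face tokenizer and a prompt budget of model_max_length minus max_tokens (the behaviour the new tests below rely on). This is an illustration, not the repository's actual implementation:

def ensure_token_limit(prompt: str, tokenizer, model_max_length: int, max_tokens: int) -> str:
    # Tokens available for the prompt after reserving room for the generated text.
    budget = model_max_length - max_tokens
    tokens = tokenizer.tokenize(prompt)
    if len(tokens) <= budget:
        # Short prompts pass through unchanged.
        return prompt
    # Drop the tail and stitch the remaining tokens back into a string.
    return tokenizer.convert_tokens_to_string(tokens[:budget])

The tests added in this commit mock exactly these two tokenizer calls (tokenize and convert_tokens_to_string) so the truncation path can be exercised without loading a real model.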
@@ -172,6 +172,7 @@ def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]):
"""
Initializes the Anthropic Claude chat adapter.
:param truncate: Whether to truncate the prompt if it exceeds the model's max token limit.
:param generation_kwargs: The generation kwargs.
"""
super().__init__(truncate, generation_kwargs)
@@ -218,7 +219,7 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]:
Prepares the chat messages for the Anthropic Claude request.
:param messages: The chat messages to prepare.
:returns: The prepared chat messages as a string.
:returns: The prepared chat messages as a dictionary.
"""
body: Dict[str, Any] = {}
system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None
@@ -227,6 +228,11 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]:
]
if system:
body["system"] = system
# Ensure token limit for each message in the body
if self.truncate:
for message in body["messages"]:
for content in message["content"]:
content["text"] = self._ensure_token_limit(content["text"])
return body

def check_prompt(self, prompt: str) -> Dict[str, Any]:
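
For illustration, the dictionary this method returns for a short conversation might look roughly like the following. This is a hedged example: the nested content/text layout is what the truncation loop above iterates over, while the "role" and "type" values are assumed from the Anthropic Messages format:

body = {
    "system": "You are a helpful assistant.",  # only present when the first message is a system message
    "messages": [
        {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
    ],
}

With truncate enabled, each "text" value is passed through _ensure_token_limit before the body is returned.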
@@ -321,7 +327,7 @@ class MistralChatAdapter(BedrockModelChatAdapter):
def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]):
"""
Initializes the Mistral chat adapter.
:param truncate: Whether to truncate the prompt if it exceeds the model's max token limit.
:param generation_kwargs: The generation kwargs.
"""
super().__init__(truncate, generation_kwargs)
@@ -477,6 +483,7 @@ class MetaLlama2ChatAdapter(BedrockModelChatAdapter):
def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None:
"""
Initializes the Meta Llama 2 chat adapter.
:param truncate: Whether to truncate the prompt if it exceeds the model's max token limit.
:param generation_kwargs: The generation kwargs.
"""
super().__init__(truncate, generation_kwargs)
@@ -523,7 +530,10 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template(
conversation=messages, tokenize=False, chat_template=self.chat_template
)
return self._ensure_token_limit(prepared_prompt)

if self.truncate:
prepared_prompt = self._ensure_token_limit(prepared_prompt)
return prepared_prompt

def check_prompt(self, prompt: str) -> Dict[str, Any]:
"""
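As a rough illustration of the Llama 2 path (hedged: the exact markers depend on self.chat_template and the tokenizer in use), apply_chat_template flattens the chat into a single prompt string, which is now only clipped when truncate is enabled:

from haystack.dataclasses import ChatMessage  # import path assumed from Haystack 2.x

messages = [
    ChatMessage.from_system("You are helpful."),
    ChatMessage.from_user("Hello!"),
]
# With a Llama-2-style chat template, tokenize=False yields a single string, roughly:
# "<s>[INST] <<SYS>>\nYou are helpful.\n<</SYS>>\n\nHello! [/INST]"
# When truncate is True (the default), that string is then passed to _ensure_token_limit.
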
103 changes: 102 additions & 1 deletion integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -1,7 +1,7 @@
import logging
import os
from typing import Optional, Type
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest
from haystack.components.generators.utils import print_streaming_chunk
@@ -141,6 +141,107 @@ def test_invoke_with_no_kwargs(mock_boto3_session):
layer.invoke()


def test_short_prompt_is_not_truncated(mock_boto3_session):
"""
Test that a short prompt is not truncated
"""
# Define a short mock prompt and its tokenized version
mock_prompt_text = "I am a tokenized prompt"
mock_prompt_tokens = mock_prompt_text.split()

# Mock the tokenizer so it returns our predefined tokens
mock_tokenizer = MagicMock()
mock_tokenizer.tokenize.return_value = mock_prompt_tokens

# We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens
# Since our mock prompt is 5 tokens long, it doesn't exceed the
# total limit (5 prompt tokens + 3 generated tokens < 10 tokens)
max_length_generated_text = 3
total_model_max_length = 10

with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer):
layer = AmazonBedrockChatGenerator(
"anthropic.claude-v2",
generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text},
)
prompt_after_resize = layer.model_adapter._ensure_token_limit(mock_prompt_text)

# The prompt doesn't exceed the limit, _ensure_token_limit doesn't truncate it
assert prompt_after_resize == mock_prompt_text


def test_long_prompt_is_truncated(mock_boto3_session):
"""
Test that a long prompt is truncated
"""
# Define a long mock prompt and its tokenized version
long_prompt_text = "I am a tokenized prompt of length eight"
long_prompt_tokens = long_prompt_text.split()

# _ensure_token_limit will truncate the prompt to make it fit into the model's max token limit
truncated_prompt_text = "I am a tokenized prompt of length"

# Mock the tokenizer to return our predefined tokens and to convert
# them back into our predefined truncated text
mock_tokenizer = MagicMock()
mock_tokenizer.tokenize.return_value = long_prompt_tokens
mock_tokenizer.convert_tokens_to_string.return_value = truncated_prompt_text

# We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens
# Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens)
max_length_generated_text = 3
total_model_max_length = 10

with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer):
layer = AmazonBedrockChatGenerator(
"anthropic.claude-v2",
generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text},
)
prompt_after_resize = layer.model_adapter._ensure_token_limit(long_prompt_text)

# The prompt exceeds the limit, _ensure_token_limit truncates it
assert prompt_after_resize == truncated_prompt_text


def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session):
"""
Test that a long prompt is not truncated and _ensure_token_limit is not called when truncate is set to False
"""
messages = [ChatMessage.from_system("I am a tokenized prompt of length eight")]

# Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens)
max_length_generated_text = 3
total_model_max_length = 10

with patch("transformers.AutoTokenizer.from_pretrained", return_value=MagicMock()):
generator = AmazonBedrockChatGenerator(
model="anthropic.claude-v2",
truncate=False,
generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text},
)

# Mock the _ensure_token_limit method to track if it is called
with patch.object(
generator.model_adapter, "_ensure_token_limit", wraps=generator.model_adapter._ensure_token_limit
) as mock_ensure_token_limit:
# Mock the model adapter to avoid actual invocation
generator.model_adapter.prepare_body = MagicMock(return_value={})
generator.client = MagicMock()
generator.client.invoke_model = MagicMock(
return_value={"body": MagicMock(read=MagicMock(return_value=b'{"generated_text": "response"}'))}
)
generator.model_adapter.get_responses = MagicMock(return_value=["response"])

# Invoke the generator
generator.invoke(messages=messages)

# Ensure _ensure_token_limit was not called
mock_ensure_token_limit.assert_not_called()

# Check the prompt passed to prepare_body
generator.model_adapter.prepare_body.assert_called_with(messages=messages, stop_words=[])


@pytest.mark.parametrize(
"model, expected_model_adapter",
[
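Taken together, the new tests pin down what a caller opts into with the truncate flag. A hedged usage sketch follows (import paths assumed from the integration's package layout; run is the usual Haystack chat generator interface and requires a configured AWS session):

from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator

# truncate=True (the default): prompts longer than model_max_length - max_tokens are clipped.
# truncate=False: prompts are sent as-is and _ensure_token_limit is never called.
generator = AmazonBedrockChatGenerator(
    model="anthropic.claude-v2",
    truncate=False,
    generation_kwargs={"model_max_length": 10, "max_tokens": 3},
)
result = generator.run(messages=[ChatMessage.from_user("I am a tokenized prompt of length eight")])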
