From 705655338e1e77e47a67552a8f8edea65d3c4781 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 2 Feb 2024 16:55:56 +0100 Subject: [PATCH] Hatch lint --- .../amazon_bedrock/chat/adapters.py | 8 +++-- .../amazon_bedrock/chat/chat_generator.py | 7 ++++- .../tests/test_amazon_chat_bedrock.py | 29 +++++++++---------- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py index 04b3e3ded..fb978e092 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py @@ -52,7 +52,7 @@ def _update_params(self, target_dict: Dict[str, Any], updates_dict: Dict[str, An for key, value in updates_dict.items(): if key in target_dict and isinstance(target_dict[key], list) and isinstance(value, list): # Merge lists and remove duplicates - target_dict[key] = list(sorted(set(target_dict[key] + value))) + target_dict[key] = sorted(set(target_dict[key] + value)) else: # Override the value in target_dict target_dict[key] = value @@ -129,7 +129,8 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: elif message.is_from(ChatRole.ASSISTANT): conversation.append(f"{AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN} {message.content.strip()}") elif message.is_from(ChatRole.FUNCTION): - raise ValueError("anthropic does not support function calls.") + error_message = "Anthropic does not support function calls." + raise ValueError(error_message) elif message.is_from(ChatRole.SYSTEM) and index == 0: # Until we transition to the new chat message format system messages will be ignored # see https://docs.anthropic.com/claude/reference/messages_post for more details @@ -137,7 +138,8 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: "System messages are not fully supported by the current version of Claude and will be ignored." ) else: - raise ValueError(f"Unsupported message role: {message.role}") + invalid_role = f"Invalid role {message.role} for message {message.content}" + raise ValueError(invalid_role) return "".join(conversation) + AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN + " " diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 00e77be32..94bec3a72 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -9,7 +9,12 @@ from haystack.components.generators.utils import deserialize_callback_handler from haystack.dataclasses import ChatMessage, StreamingChunk -from ..errors import AmazonBedrockConfigurationError, AmazonBedrockInferenceError, AWSConfigurationError +from haystack_integrations.components.generators.amazon_bedrock.errors import ( + AmazonBedrockConfigurationError, + AmazonBedrockInferenceError, + AWSConfigurationError, +) + from .adapters import AnthropicClaudeChatAdapter, BedrockModelChatAdapter, MetaLlama2ChatAdapter logger = logging.getLogger(__name__) diff --git a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py index 3c0277627..e045fb790 100644 --- a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py +++ b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py @@ -3,14 +3,17 @@ import pytest from haystack.components.generators.utils import default_streaming_callback -from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.dataclasses import ChatMessage from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import ( + AnthropicClaudeChatAdapter, + BedrockModelChatAdapter, MetaLlama2ChatAdapter, - AnthropicClaudeChatAdapter, BedrockModelChatAdapter, ) +clazz = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" + @pytest.fixture def mock_auto_tokenizer(): @@ -30,7 +33,7 @@ def mock_boto3_session(): @pytest.fixture def mock_prompt_handler(): with patch( - "haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler" + "haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler" ) as mock_prompt_handler: yield mock_prompt_handler @@ -49,9 +52,8 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session): generation_kwargs={"temperature": 0.7}, streaming_callback=default_streaming_callback, ) - expected_dict = { - "type": "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator", + "type": clazz, "init_parameters": { "model": "anthropic.claude-v2", "generation_kwargs": {"temperature": 0.7}, @@ -69,7 +71,7 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session): """ generator = AmazonBedrockChatGenerator.from_dict( { - "type": "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator", + "type": clazz, "init_parameters": { "model": "anthropic.claude-v2", "generation_kwargs": {"temperature": 0.7}, @@ -180,13 +182,11 @@ def test_prepare_body_with_default_params(self) -> None: assert body == expected_body def test_prepare_body_with_custom_inference_params(self) -> None: - layer = AnthropicClaudeChatAdapter(generation_kwargs={"temperature": 0.7, - "top_p": 0.8, - "top_k": 4}) + layer = AnthropicClaudeChatAdapter(generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4}) prompt = "Hello, how are you?" expected_body = { "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant: ", - 'max_tokens_to_sample': 69, + "max_tokens_to_sample": 69, "stop_sequences": ["\n\nHuman:", "CUSTOM_STOP"], "temperature": 0.7, "top_p": 0.8, @@ -218,14 +218,13 @@ def test_prepare_body_with_default_params(self) -> None: assert body == expected_body def test_prepare_body_with_custom_inference_params(self) -> None: - layer = MetaLlama2ChatAdapter(generation_kwargs={"temperature": 0.7, - "top_p": 0.8, - "top_k": 5, - "stop_sequences": ["CUSTOM_STOP"]}) + layer = MetaLlama2ChatAdapter( + generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 5, "stop_sequences": ["CUSTOM_STOP"]} + ) prompt = "Hello, how are you?" expected_body = { "prompt": "[INST] Hello, how are you? [/INST]", - 'max_gen_len': 69, + "max_gen_len": 69, "stop_sequences": ["CUSTOM_STOP"], "temperature": 0.7, "top_p": 0.8,