Hatch lint
vblagoje committed Feb 2, 2024
1 parent 5bd2131 commit 7056553
Showing 3 changed files with 25 additions and 19 deletions.
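
The changes below appear to be a mechanical lint/format pass (the commit title suggests Hatch's lint environment, which runs ruff-style checks): a redundant list() wrapper is dropped, exception messages are hoisted into variables before raising, a parent-relative import becomes absolute, single-quoted dict keys are normalized to double quotes, and over-long call sites are re-wrapped.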
integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py

@@ -52,7 +52,7 @@ def _update_params(self, target_dict: Dict[str, Any], updates_dict: Dict[str, Any]
         for key, value in updates_dict.items():
             if key in target_dict and isinstance(target_dict[key], list) and isinstance(value, list):
                 # Merge lists and remove duplicates
-                target_dict[key] = list(sorted(set(target_dict[key] + value)))
+                target_dict[key] = sorted(set(target_dict[key] + value))
             else:
                 # Override the value in target_dict
                 target_dict[key] = value
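
In this hunk the old line wrapped sorted() in list(), but sorted() already returns a new list, so the wrapper was a no-op (ruff flags this pattern via flake8-comprehensions, C414). Below is a minimal standalone sketch of the merge behavior shown above; the update_params name and the sample dicts are illustrative, not the integration's actual API:

from typing import Any, Dict

def update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any]) -> None:
    for key, value in updates_dict.items():
        if key in target_dict and isinstance(target_dict[key], list) and isinstance(value, list):
            # sorted() already returns a new list, so no list() wrapper is needed
            target_dict[key] = sorted(set(target_dict[key] + value))
        else:
            # non-list values and brand-new keys simply override
            target_dict[key] = value

params = {"stop_sequences": ["\n\nHuman:"], "temperature": 0.5}
update_params(params, {"stop_sequences": ["CUSTOM_STOP"], "temperature": 0.7})
print(params)  # {'stop_sequences': ['\n\nHuman:', 'CUSTOM_STOP'], 'temperature': 0.7}

Note that the set-based dedupe assumes the list elements are hashable and mutually comparable, which holds for the string stop sequences used here.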
@@ -129,15 +129,17 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
             elif message.is_from(ChatRole.ASSISTANT):
                 conversation.append(f"{AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN} {message.content.strip()}")
             elif message.is_from(ChatRole.FUNCTION):
-                raise ValueError("anthropic does not support function calls.")
+                error_message = "Anthropic does not support function calls."
+                raise ValueError(error_message)
             elif message.is_from(ChatRole.SYSTEM) and index == 0:
                 # Until we transition to the new chat message format system messages will be ignored
                 # see https://docs.anthropic.com/claude/reference/messages_post for more details
                 logger.warning(
                     "System messages are not fully supported by the current version of Claude and will be ignored."
                 )
             else:
-                raise ValueError(f"Unsupported message role: {message.role}")
+                invalid_role = f"Invalid role {message.role} for message {message.content}"
+                raise ValueError(invalid_role)
 
         return "".join(conversation) + AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN + " "
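
Both raise sites in this hunk now bind the message to a variable before raising. That matches ruff's flake8-errmsg rules (EM101 for string literals, EM102 for f-strings), which exist so a traceback doesn't print the message text twice. A minimal sketch of the pattern, using a hypothetical check_role helper for illustration:

def check_role(role: str, content: str) -> None:
    if role not in ("user", "assistant", "system"):
        # EM102: don't pass an f-string straight to the exception constructor;
        # bind it first so the traceback's source line shows the variable,
        # not a duplicated message literal
        invalid_role = f"Invalid role {role} for message {content}"
        raise ValueError(invalid_role)

The same hunk also capitalizes "Anthropic" in the function-call error and rewords the fallback from "Unsupported message role" to "Invalid role ... for message ...", adding the offending content to the message.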

integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py

@@ -9,7 +9,12 @@
 from haystack.components.generators.utils import deserialize_callback_handler
 from haystack.dataclasses import ChatMessage, StreamingChunk
 
-from ..errors import AmazonBedrockConfigurationError, AmazonBedrockInferenceError, AWSConfigurationError
+from haystack_integrations.components.generators.amazon_bedrock.errors import (
+    AmazonBedrockConfigurationError,
+    AmazonBedrockInferenceError,
+    AWSConfigurationError,
+)
 
 from .adapters import AnthropicClaudeChatAdapter, BedrockModelChatAdapter, MetaLlama2ChatAdapter
 
 logger = logging.getLogger(__name__)
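
The only change to this file swaps the parent-relative import (from ..errors import ...) for the fully qualified haystack_integrations path, wrapped to one name per line; the imported names themselves are unchanged. This is consistent with a lint rule banning relative imports from parent packages (ruff's flake8-tidy-imports, TID252), though the exact rule isn't stated in the commit.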
29 changes: 14 additions & 15 deletions integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py
@@ -3,14 +3,17 @@

 import pytest
 from haystack.components.generators.utils import default_streaming_callback
-from haystack.dataclasses import ChatMessage, StreamingChunk
+from haystack.dataclasses import ChatMessage
 
 from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator
 from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import (
-    AnthropicClaudeChatAdapter, BedrockModelChatAdapter,
+    AnthropicClaudeChatAdapter,
+    BedrockModelChatAdapter,
     MetaLlama2ChatAdapter,
 )
+
+clazz = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator"
 
 
 @pytest.fixture
 def mock_auto_tokenizer():
@@ -30,7 +33,7 @@ def mock_boto3_session():
 @pytest.fixture
 def mock_prompt_handler():
     with patch(
-            "haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler"
+        "haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler"
     ) as mock_prompt_handler:
         yield mock_prompt_handler
@@ -49,9 +52,8 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session):
         generation_kwargs={"temperature": 0.7},
         streaming_callback=default_streaming_callback,
     )
-
     expected_dict = {
-        "type": "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator",
+        "type": clazz,
         "init_parameters": {
             "model": "anthropic.claude-v2",
             "generation_kwargs": {"temperature": 0.7},
@@ -69,7 +71,7 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session):
"""
generator = AmazonBedrockChatGenerator.from_dict(
{
"type": "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator",
"type": clazz,
"init_parameters": {
"model": "anthropic.claude-v2",
"generation_kwargs": {"temperature": 0.7},
@@ -180,13 +182,11 @@ def test_prepare_body_with_default_params(self) -> None:
         assert body == expected_body
 
     def test_prepare_body_with_custom_inference_params(self) -> None:
-        layer = AnthropicClaudeChatAdapter(generation_kwargs={"temperature": 0.7,
-                                                              "top_p": 0.8,
-                                                              "top_k": 4})
+        layer = AnthropicClaudeChatAdapter(generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4})
         prompt = "Hello, how are you?"
         expected_body = {
             "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant: ",
-            'max_tokens_to_sample': 69,
+            "max_tokens_to_sample": 69,
             "stop_sequences": ["\n\nHuman:", "CUSTOM_STOP"],
             "temperature": 0.7,
             "top_p": 0.8,
@@ -218,14 +218,13 @@ def test_prepare_body_with_default_params(self) -> None:
         assert body == expected_body
 
     def test_prepare_body_with_custom_inference_params(self) -> None:
-        layer = MetaLlama2ChatAdapter(generation_kwargs={"temperature": 0.7,
-                                                         "top_p": 0.8,
-                                                         "top_k": 5,
-                                                         "stop_sequences": ["CUSTOM_STOP"]})
+        layer = MetaLlama2ChatAdapter(
+            generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 5, "stop_sequences": ["CUSTOM_STOP"]}
+        )
         prompt = "Hello, how are you?"
         expected_body = {
             "prompt": "<s>[INST] Hello, how are you? [/INST]",
-            'max_gen_len': 69,
+            "max_gen_len": 69,
             "stop_sequences": ["CUSTOM_STOP"],
             "temperature": 0.7,
             "top_p": 0.8,
