From ac0e4c2f8c8d0dce7a32e8e3a3fe74362b0686dd Mon Sep 17 00:00:00 2001 From: Abraham Yusuf Date: Thu, 17 Oct 2024 10:55:01 +0200 Subject: [PATCH] feat: add prefixes to supported model patterns to allow cross region model ids (#1127) * feat: add prefixes to supported model patterns to allow cross region model ids --- .../amazon_bedrock/chat/chat_generator.py | 6 +++--- .../generators/amazon_bedrock/generator.py | 14 +++++++------- .../amazon_bedrock/tests/test_chat_generator.py | 6 ++++-- .../amazon_bedrock/tests/test_generator.py | 7 ++++++- 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index e1732646a..6bb3cc301 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -58,9 +58,9 @@ class AmazonBedrockChatGenerator: """ SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelChatAdapter]]] = { - r"(.+\.)?anthropic.claude.*": AnthropicClaudeChatAdapter, - r"meta.llama2.*": MetaLlama2ChatAdapter, - r"mistral.*": MistralChatAdapter, + r"([a-z]{2}\.)?anthropic.claude.*": AnthropicClaudeChatAdapter, + r"([a-z]{2}\.)?meta.llama2.*": MetaLlama2ChatAdapter, + r"([a-z]{2}\.)?mistral.*": MistralChatAdapter, } def __init__( diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 193332009..c6c814de4 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -66,13 +66,13 @@ class AmazonBedrockGenerator: """ SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelAdapter]]] = { - r"amazon.titan-text.*": AmazonTitanAdapter, - r"ai21.j2.*": AI21LabsJurassic2Adapter, - r"cohere.command-[^r].*": CohereCommandAdapter, - r"cohere.command-r.*": CohereCommandRAdapter, - r"(.+\.)?anthropic.claude.*": AnthropicClaudeAdapter, - r"meta.llama.*": MetaLlamaAdapter, - r"mistral.*": MistralAdapter, + r"([a-z]{2}\.)?amazon.titan-text.*": AmazonTitanAdapter, + r"([a-z]{2}\.)?ai21.j2.*": AI21LabsJurassic2Adapter, + r"([a-z]{2}\.)?cohere.command-[^r].*": CohereCommandAdapter, + r"([a-z]{2}\.)?cohere.command-r.*": CohereCommandRAdapter, + r"([a-z]{2}\.)?anthropic.claude.*": AnthropicClaudeAdapter, + r"([a-z]{2}\.)?meta.llama.*": MetaLlamaAdapter, + r"([a-z]{2}\.)?mistral.*": MistralAdapter, } def __init__( diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 49abc0979..571e03eb2 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -243,7 +243,7 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): generator.run(messages=messages) # Ensure _ensure_token_limit was not called - mock_ensure_token_limit.assert_not_called(), + mock_ensure_token_limit.assert_not_called() # Check the prompt passed to prepare_body generator.model_adapter.prepare_body.assert_called_with(messages=messages, stop_words=[], stream=False) @@ -261,6 +261,9 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): ("meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter), ("meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter), ("meta.llama2-130b-v5", MetaLlama2ChatAdapter), # artificial + ("us.meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter), # cross-region inference + ("eu.meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter), # cross-region inference + ("de.meta.llama2-130b-v5", MetaLlama2ChatAdapter), # cross-region inference ("unknown_model", None), ], ) @@ -517,7 +520,6 @@ def test_get_responses(self) -> None: @pytest.mark.parametrize("model_name", MODELS_TO_TEST) @pytest.mark.integration def test_default_inference_params(self, model_name, chat_messages): - client = AmazonBedrockChatGenerator(model=model_name) response = client.run(chat_messages) diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index 2ccd5a3fa..79246b4aa 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -225,7 +225,7 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): generator.run(prompt=long_prompt_text) # Ensure _ensure_token_limit was not called - mock_ensure_token_limit.assert_not_called(), + mock_ensure_token_limit.assert_not_called() # Check the prompt passed to prepare_body generator.model_adapter.prepare_body.assert_called_with(prompt=long_prompt_text, stream=False) @@ -251,10 +251,13 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): ("ai21.j2-mega-v5", AI21LabsJurassic2Adapter), # artificial ("amazon.titan-text-lite-v1", AmazonTitanAdapter), ("amazon.titan-text-express-v1", AmazonTitanAdapter), + ("us.amazon.titan-text-express-v1", AmazonTitanAdapter), # cross-region inference ("amazon.titan-text-agile-v1", AmazonTitanAdapter), ("amazon.titan-text-lightning-v8", AmazonTitanAdapter), # artificial ("meta.llama2-13b-chat-v1", MetaLlamaAdapter), ("meta.llama2-70b-chat-v1", MetaLlamaAdapter), + ("eu.meta.llama2-13b-chat-v1", MetaLlamaAdapter), # cross-region inference + ("us.meta.llama2-70b-chat-v1", MetaLlamaAdapter), # cross-region inference ("meta.llama2-130b-v5", MetaLlamaAdapter), # artificial ("meta.llama3-8b-instruct-v1:0", MetaLlamaAdapter), ("meta.llama3-70b-instruct-v1:0", MetaLlamaAdapter), @@ -262,6 +265,8 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): ("mistral.mistral-7b-instruct-v0:2", MistralAdapter), ("mistral.mixtral-8x7b-instruct-v0:1", MistralAdapter), ("mistral.mistral-large-2402-v1:0", MistralAdapter), + ("eu.mistral.mixtral-8x7b-instruct-v0:1", MistralAdapter), # cross-region inference + ("us.mistral.mistral-large-2402-v1:0", MistralAdapter), # cross-region inference ("mistral.mistral-medium-v8:0", MistralAdapter), # artificial ("unknown_model", None), ],