From ce060ff7ed881ce2f0d926fd65131316139e1069 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 21 Feb 2024 10:04:39 +0100 Subject: [PATCH] bedrock - remove supports method (#456) --- .../generators/amazon_bedrock/generator.py | 42 ------ .../tests/test_amazon_bedrock.py | 124 ------------------ 2 files changed, 166 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 1a8bb04f0..e1820497d 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -155,48 +155,6 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union ) return str(resize_info["resized_prompt"]) - @classmethod - def supports(cls, model, **kwargs): - model_supported = cls.get_model_adapter(model) is not None - if not model_supported or not cls.aws_configured(**kwargs): - return False - - try: - session = cls.get_aws_session(**kwargs) - bedrock = session.client("bedrock") - foundation_models_response = bedrock.list_foundation_models(byOutputModality="TEXT") - available_model_ids = [entry["modelId"] for entry in foundation_models_response.get("modelSummaries", [])] - model_ids_supporting_streaming = [ - entry["modelId"] - for entry in foundation_models_response.get("modelSummaries", []) - if entry.get("responseStreamingSupported", False) - ] - except AWSConfigurationError as exception: - raise AmazonBedrockConfigurationError(message=exception.message) from exception - except Exception as exception: - msg = ( - "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. " - "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration" - ) - raise AmazonBedrockConfigurationError(msg) from exception - - model_available = model in available_model_ids - if not model_available: - msg = ( - f"The model {model} is not available in Amazon Bedrock. " - f"Make sure the model you want to use is available in the configured AWS region and " - f"you have access." - ) - raise AmazonBedrockConfigurationError(msg) - - stream: bool = kwargs.get("stream", False) - model_supports_streaming = model in model_ids_supporting_streaming - if stream and not model_supports_streaming: - msg = f"The model {model} doesn't support streaming. Remove the `stream` parameter." - raise AmazonBedrockConfigurationError(msg) - - return model_supported - def invoke(self, *args, **kwargs): kwargs = kwargs.copy() prompt: str = kwargs.pop("prompt", None) diff --git a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py index f4d10a9b2..e43cc94cf 100644 --- a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py +++ b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py @@ -2,7 +2,6 @@ from unittest.mock import MagicMock, call, patch import pytest -from botocore.exceptions import BotoCoreError from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator from haystack_integrations.components.generators.amazon_bedrock.adapters import ( @@ -13,7 +12,6 @@ CohereCommandAdapter, MetaLlama2ChatAdapter, ) -from haystack_integrations.components.generators.amazon_bedrock.errors import AmazonBedrockConfigurationError @pytest.mark.unit @@ -203,128 +201,6 @@ def test_long_prompt_is_truncated(mock_boto3_session): assert prompt_after_resize == truncated_prompt_text -@pytest.mark.unit -def test_supports_for_valid_aws_configuration(): - mock_session = MagicMock() - mock_session.client("bedrock").list_foundation_models.return_value = { - "modelSummaries": [{"modelId": "anthropic.claude-v2"}] - } - - # Patch the class method to return the mock session - with patch( - "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", - return_value=mock_session, - ): - supported = AmazonBedrockGenerator.supports( - model="anthropic.claude-v2", - aws_profile_name="some_real_profile", - ) - args, kwargs = mock_session.client("bedrock").list_foundation_models.call_args - assert kwargs["byOutputModality"] == "TEXT" - - assert supported - - -@pytest.mark.unit -def test_supports_raises_on_invalid_aws_profile_name(): - with patch("boto3.Session") as mock_boto3_session: - mock_boto3_session.side_effect = BotoCoreError() - with pytest.raises(AmazonBedrockConfigurationError, match="Failed to initialize the session"): - AmazonBedrockGenerator.supports( - model="anthropic.claude-v2", - aws_profile_name="some_fake_profile", - ) - - -@pytest.mark.unit -def test_supports_for_invalid_bedrock_config(): - mock_session = MagicMock() - mock_session.client.side_effect = BotoCoreError() - - # Patch the class method to return the mock session - with patch( - "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", - return_value=mock_session, - ), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."): - AmazonBedrockGenerator.supports( - model="anthropic.claude-v2", - aws_profile_name="some_real_profile", - ) - - -@pytest.mark.unit -def test_supports_for_invalid_bedrock_config_error_on_list_models(): - mock_session = MagicMock() - mock_session.client("bedrock").list_foundation_models.side_effect = BotoCoreError() - - # Patch the class method to return the mock session - with patch( - "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", - return_value=mock_session, - ), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."): - AmazonBedrockGenerator.supports( - model="anthropic.claude-v2", - aws_profile_name="some_real_profile", - ) - - -@pytest.mark.unit -def test_supports_for_no_aws_params(): - supported = AmazonBedrockGenerator.supports(model="anthropic.claude-v2") - - assert supported is False - - -@pytest.mark.unit -def test_supports_for_unknown_model(): - supported = AmazonBedrockGenerator.supports(model="unknown_model", aws_profile_name="some_real_profile") - - assert supported is False - - -@pytest.mark.unit -def test_supports_with_stream_true_for_model_that_supports_streaming(): - mock_session = MagicMock() - mock_session.client("bedrock").list_foundation_models.return_value = { - "modelSummaries": [{"modelId": "anthropic.claude-v2", "responseStreamingSupported": True}] - } - - # Patch the class method to return the mock session - with patch( - "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", - return_value=mock_session, - ): - supported = AmazonBedrockGenerator.supports( - model="anthropic.claude-v2", - aws_profile_name="some_real_profile", - stream=True, - ) - - assert supported - - -@pytest.mark.unit -def test_supports_with_stream_true_for_model_that_does_not_support_streaming(): - mock_session = MagicMock() - mock_session.client("bedrock").list_foundation_models.return_value = { - "modelSummaries": [{"modelId": "ai21.j2-mid-v1", "responseStreamingSupported": False}] - } - - # Patch the class method to return the mock session - with patch( - "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", - return_value=mock_session, - ), pytest.raises( - AmazonBedrockConfigurationError, - match="The model ai21.j2-mid-v1 doesn't support streaming.", - ): - AmazonBedrockGenerator.supports( - model="ai21.j2-mid-v1", - aws_profile_name="some_real_profile", - stream=True, - ) - - @pytest.mark.unit @pytest.mark.parametrize( "model, expected_model_adapter",