Skip to content

Commit

Permalink
bedrock - remove supports method (#456)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 authored Feb 21, 2024
1 parent cc32ee6 commit ce060ff
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 166 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -155,48 +155,6 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union
)
return str(resize_info["resized_prompt"])

@classmethod
def supports(cls, model, **kwargs):
model_supported = cls.get_model_adapter(model) is not None
if not model_supported or not cls.aws_configured(**kwargs):
return False

try:
session = cls.get_aws_session(**kwargs)
bedrock = session.client("bedrock")
foundation_models_response = bedrock.list_foundation_models(byOutputModality="TEXT")
available_model_ids = [entry["modelId"] for entry in foundation_models_response.get("modelSummaries", [])]
model_ids_supporting_streaming = [
entry["modelId"]
for entry in foundation_models_response.get("modelSummaries", [])
if entry.get("responseStreamingSupported", False)
]
except AWSConfigurationError as exception:
raise AmazonBedrockConfigurationError(message=exception.message) from exception
except Exception as exception:
msg = (
"Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
"See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
)
raise AmazonBedrockConfigurationError(msg) from exception

model_available = model in available_model_ids
if not model_available:
msg = (
f"The model {model} is not available in Amazon Bedrock. "
f"Make sure the model you want to use is available in the configured AWS region and "
f"you have access."
)
raise AmazonBedrockConfigurationError(msg)

stream: bool = kwargs.get("stream", False)
model_supports_streaming = model in model_ids_supporting_streaming
if stream and not model_supports_streaming:
msg = f"The model {model} doesn't support streaming. Remove the `stream` parameter."
raise AmazonBedrockConfigurationError(msg)

return model_supported

def invoke(self, *args, **kwargs):
kwargs = kwargs.copy()
prompt: str = kwargs.pop("prompt", None)
Expand Down
124 changes: 0 additions & 124 deletions integrations/amazon_bedrock/tests/test_amazon_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from unittest.mock import MagicMock, call, patch

import pytest
from botocore.exceptions import BotoCoreError

from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator
from haystack_integrations.components.generators.amazon_bedrock.adapters import (
Expand All @@ -13,7 +12,6 @@
CohereCommandAdapter,
MetaLlama2ChatAdapter,
)
from haystack_integrations.components.generators.amazon_bedrock.errors import AmazonBedrockConfigurationError


@pytest.mark.unit
Expand Down Expand Up @@ -203,128 +201,6 @@ def test_long_prompt_is_truncated(mock_boto3_session):
assert prompt_after_resize == truncated_prompt_text


@pytest.mark.unit
def test_supports_for_valid_aws_configuration():
mock_session = MagicMock()
mock_session.client("bedrock").list_foundation_models.return_value = {
"modelSummaries": [{"modelId": "anthropic.claude-v2"}]
}

# Patch the class method to return the mock session
with patch(
"haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session",
return_value=mock_session,
):
supported = AmazonBedrockGenerator.supports(
model="anthropic.claude-v2",
aws_profile_name="some_real_profile",
)
args, kwargs = mock_session.client("bedrock").list_foundation_models.call_args
assert kwargs["byOutputModality"] == "TEXT"

assert supported


@pytest.mark.unit
def test_supports_raises_on_invalid_aws_profile_name():
with patch("boto3.Session") as mock_boto3_session:
mock_boto3_session.side_effect = BotoCoreError()
with pytest.raises(AmazonBedrockConfigurationError, match="Failed to initialize the session"):
AmazonBedrockGenerator.supports(
model="anthropic.claude-v2",
aws_profile_name="some_fake_profile",
)


@pytest.mark.unit
def test_supports_for_invalid_bedrock_config():
mock_session = MagicMock()
mock_session.client.side_effect = BotoCoreError()

# Patch the class method to return the mock session
with patch(
"haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session",
return_value=mock_session,
), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."):
AmazonBedrockGenerator.supports(
model="anthropic.claude-v2",
aws_profile_name="some_real_profile",
)


@pytest.mark.unit
def test_supports_for_invalid_bedrock_config_error_on_list_models():
mock_session = MagicMock()
mock_session.client("bedrock").list_foundation_models.side_effect = BotoCoreError()

# Patch the class method to return the mock session
with patch(
"haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session",
return_value=mock_session,
), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."):
AmazonBedrockGenerator.supports(
model="anthropic.claude-v2",
aws_profile_name="some_real_profile",
)


@pytest.mark.unit
def test_supports_for_no_aws_params():
supported = AmazonBedrockGenerator.supports(model="anthropic.claude-v2")

assert supported is False


@pytest.mark.unit
def test_supports_for_unknown_model():
supported = AmazonBedrockGenerator.supports(model="unknown_model", aws_profile_name="some_real_profile")

assert supported is False


@pytest.mark.unit
def test_supports_with_stream_true_for_model_that_supports_streaming():
mock_session = MagicMock()
mock_session.client("bedrock").list_foundation_models.return_value = {
"modelSummaries": [{"modelId": "anthropic.claude-v2", "responseStreamingSupported": True}]
}

# Patch the class method to return the mock session
with patch(
"haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session",
return_value=mock_session,
):
supported = AmazonBedrockGenerator.supports(
model="anthropic.claude-v2",
aws_profile_name="some_real_profile",
stream=True,
)

assert supported


@pytest.mark.unit
def test_supports_with_stream_true_for_model_that_does_not_support_streaming():
mock_session = MagicMock()
mock_session.client("bedrock").list_foundation_models.return_value = {
"modelSummaries": [{"modelId": "ai21.j2-mid-v1", "responseStreamingSupported": False}]
}

# Patch the class method to return the mock session
with patch(
"haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session",
return_value=mock_session,
), pytest.raises(
AmazonBedrockConfigurationError,
match="The model ai21.j2-mid-v1 doesn't support streaming.",
):
AmazonBedrockGenerator.supports(
model="ai21.j2-mid-v1",
aws_profile_name="some_real_profile",
stream=True,
)


@pytest.mark.unit
@pytest.mark.parametrize(
"model, expected_model_adapter",
Expand Down

0 comments on commit ce060ff

Please sign in to comment.