From d28ffa1b7ab4a6e8363da2644996b70150145faa Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Fri, 16 Feb 2024 12:00:53 +0100 Subject: [PATCH] Adopt Secret to Amazon Bedrock (#416) * initial import * addin Secret and fixing tests * cleaning * using staticmethod directly * removing ignore from B008 config and setting up in-line * addin Secret and fixing tests for chat component * addin Secret and fixing tests for the chat component * fixing to_dict from_dict * fixing to_dict from_dict to the chat component as well --- integrations/amazon_bedrock/pyproject.toml | 13 +++-- .../amazon_bedrock/chat/chat_generator.py | 40 +++++++++++---- .../generators/amazon_bedrock/generator.py | 43 +++++++++++----- integrations/amazon_bedrock/tests/conftest.py | 35 +++++++++++++ .../tests/test_amazon_bedrock.py | 51 +++++-------------- .../tests/test_amazon_chat_bedrock.py | 24 ++++----- 6 files changed, 128 insertions(+), 78 deletions(-) create mode 100644 integrations/amazon_bedrock/tests/conftest.py diff --git a/integrations/amazon_bedrock/pyproject.toml b/integrations/amazon_bedrock/pyproject.toml index 8527d27a1..8366c9449 100644 --- a/integrations/amazon_bedrock/pyproject.toml +++ b/integrations/amazon_bedrock/pyproject.toml @@ -83,11 +83,13 @@ style = [ "ruff {args:.}", "black --check --diff {args:.}", ] + fmt = [ "black {args:.}", "ruff --fix {args:.}", "style", ] + all = [ "style", "typing", @@ -135,7 +137,7 @@ ignore = [ # Ignore complexity "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", # Ignore unused params - "ARG001", "ARG002", "ARG005" + "ARG001", "ARG002", "ARG005", ] unfixable = [ # Don't touch unused imports @@ -153,16 +155,13 @@ ban-relative-imports = "parents" "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["src", "tests"] +source = ["haystack_integrations"] branch = true parallel = true - -[tool.coverage.paths] -amazon_bedrock_haystack = ["src/*"] -tests = ["tests"] - [tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true exclude_lines = [ "no cov", "if __name__ == .__main__.:", diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 804d44413..6ce671e68 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -8,6 +8,7 @@ from haystack import component, default_from_dict, default_to_dict from haystack.components.generators.utils import deserialize_callback_handler from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.utils.auth import Secret, deserialize_secrets_inplace from haystack_integrations.components.generators.amazon_bedrock.errors import ( AmazonBedrockConfigurationError, @@ -61,11 +62,13 @@ class AmazonBedrockChatGenerator: def __init__( self, model: str, - aws_access_key_id: Optional[str] = None, - aws_secret_access_key: Optional[str] = None, - aws_session_token: Optional[str] = None, - aws_region_name: Optional[str] = None, - aws_profile_name: Optional[str] = None, + aws_access_key_id: Optional[Secret] = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False), # noqa: B008 + aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008 + ["AWS_SECRET_ACCESS_KEY"], strict=False + ), + aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008 + aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008 + aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008 generation_kwargs: Optional[Dict[str, Any]] = None, stop_words: Optional[List[str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, @@ -102,6 +105,11 @@ def __init__( msg = "'model' cannot be None or empty string" raise ValueError(msg) self.model = model + self.aws_access_key_id = aws_access_key_id + self.aws_secret_access_key = aws_secret_access_key + self.aws_session_token = aws_session_token + self.aws_region_name = aws_region_name + self.aws_profile_name = aws_profile_name # get the model adapter for the given model model_adapter_cls = self.get_model_adapter(model=model) @@ -111,13 +119,16 @@ def __init__( self.model_adapter = model_adapter_cls(generation_kwargs or {}) # create the AWS session and client + def resolve_secret(secret: Optional[Secret]) -> Optional[str]: + return secret.resolve_value() if secret else None + try: session = self.get_aws_session( - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - aws_session_token=aws_session_token, - aws_region_name=aws_region_name, - aws_profile_name=aws_profile_name, + aws_access_key_id=resolve_secret(aws_access_key_id), + aws_secret_access_key=resolve_secret(aws_secret_access_key), + aws_session_token=resolve_secret(aws_session_token), + aws_region_name=resolve_secret(aws_region_name), + aws_profile_name=resolve_secret(aws_profile_name), ) self.client = session.client("bedrock-runtime") except Exception as exception: @@ -229,6 +240,11 @@ def to_dict(self) -> Dict[str, Any]: """ return default_to_dict( self, + aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None, + aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None, + aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None, + aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None, + aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None, model=self.model, stop_words=self.stop_words, generation_kwargs=self.model_adapter.generation_kwargs, @@ -246,4 +262,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockChatGenerator": serialized_callback_handler = init_params.get("streaming_callback") if serialized_callback_handler: data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler) + deserialize_secrets_inplace( + data["init_parameters"], + ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"], + ) return default_from_dict(cls, data) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 8e89dab59..1a8bb04f0 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -6,6 +6,7 @@ import boto3 from botocore.exceptions import BotoCoreError, ClientError from haystack import component, default_from_dict, default_to_dict +from haystack.utils.auth import Secret, deserialize_secrets_inplace from .adapters import ( AI21LabsJurassic2Adapter, @@ -72,11 +73,13 @@ class AmazonBedrockGenerator: def __init__( self, model: str, - aws_access_key_id: Optional[str] = None, - aws_secret_access_key: Optional[str] = None, - aws_session_token: Optional[str] = None, - aws_region_name: Optional[str] = None, - aws_profile_name: Optional[str] = None, + aws_access_key_id: Optional[Secret] = Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False), # noqa: B008 + aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008 + "AWS_SECRET_ACCESS_KEY", strict=False + ), + aws_session_token: Optional[Secret] = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False), # noqa: B008 + aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008 + aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008 max_length: Optional[int] = 100, **kwargs, ): @@ -85,14 +88,22 @@ def __init__( raise ValueError(msg) self.model = model self.max_length = max_length + self.aws_access_key_id = aws_access_key_id + self.aws_secret_access_key = aws_secret_access_key + self.aws_session_token = aws_session_token + self.aws_region_name = aws_region_name + self.aws_profile_name = aws_profile_name + + def resolve_secret(secret: Optional[Secret]) -> Optional[str]: + return secret.resolve_value() if secret else None try: session = self.get_aws_session( - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - aws_session_token=aws_session_token, - aws_region_name=aws_region_name, - aws_profile_name=aws_profile_name, + aws_access_key_id=resolve_secret(aws_access_key_id), + aws_secret_access_key=resolve_secret(aws_secret_access_key), + aws_session_token=resolve_secret(aws_session_token), + aws_region_name=resolve_secret(aws_region_name), + aws_profile_name=resolve_secret(aws_profile_name), ) self.client = session.client("bedrock-runtime") except Exception as exception: @@ -103,8 +114,7 @@ def __init__( raise AmazonBedrockConfigurationError(msg) from exception model_input_kwargs = kwargs - # We pop the model_max_length as it is not sent to the model - # but used to truncate the prompt if needed + # We pop the model_max_length as it is not sent to the model but used to truncate the prompt if needed model_max_length = kwargs.get("model_max_length", 4096) # Truncate prompt if prompt tokens > model_max_length-max_length @@ -298,6 +308,11 @@ def to_dict(self) -> Dict[str, Any]: """ return default_to_dict( self, + aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None, + aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None, + aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None, + aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None, + aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None, model=self.model, max_length=self.max_length, ) @@ -309,4 +324,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockGenerator": :param data: The dictionary representation of this component. :return: The deserialized component instance. """ + deserialize_secrets_inplace( + data["init_parameters"], + ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"], + ) return default_from_dict(cls, data) diff --git a/integrations/amazon_bedrock/tests/conftest.py b/integrations/amazon_bedrock/tests/conftest.py new file mode 100644 index 000000000..4c8ce688c --- /dev/null +++ b/integrations/amazon_bedrock/tests/conftest.py @@ -0,0 +1,35 @@ +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture +def set_env_variables(monkeypatch): + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "some_fake_id") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "some_fake_key") + monkeypatch.setenv("AWS_SESSION_TOKEN", "some_fake_token") + monkeypatch.setenv("AWS_DEFAULT_REGION", "fake_region") + monkeypatch.setenv("AWS_PROFILE", "some_fake_profile") + + +@pytest.fixture +def mock_auto_tokenizer(): + with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as mock_from_pretrained: + mock_tokenizer = MagicMock() + mock_from_pretrained.return_value = mock_tokenizer + yield mock_tokenizer + + +# create a fixture with mocked boto3 client and session +@pytest.fixture +def mock_boto3_session(): + with patch("boto3.Session") as mock_client: + yield mock_client + + +@pytest.fixture +def mock_prompt_handler(): + with patch( + "haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler" + ) as mock_prompt_handler: + yield mock_prompt_handler diff --git a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py index b08e9dfd5..f4d10a9b2 100644 --- a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py +++ b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py @@ -16,47 +16,24 @@ from haystack_integrations.components.generators.amazon_bedrock.errors import AmazonBedrockConfigurationError -@pytest.fixture -def mock_auto_tokenizer(): - with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as mock_from_pretrained: - mock_tokenizer = MagicMock() - mock_from_pretrained.return_value = mock_tokenizer - yield mock_tokenizer - - -# create a fixture with mocked boto3 client and session -@pytest.fixture -def mock_boto3_session(): - with patch("boto3.Session") as mock_client: - yield mock_client - - -@pytest.fixture -def mock_prompt_handler(): - with patch( - "haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler" - ) as mock_prompt_handler: - yield mock_prompt_handler - - @pytest.mark.unit -def test_to_dict(mock_auto_tokenizer, mock_boto3_session): +def test_to_dict(mock_auto_tokenizer, mock_boto3_session, set_env_variables): """ Test that the to_dict method returns the correct dictionary without aws credentials """ generator = AmazonBedrockGenerator( model="anthropic.claude-v2", max_length=99, - aws_access_key_id="some_fake_id", - aws_secret_access_key="some_fake_key", - aws_session_token="some_fake_token", - aws_profile_name="some_fake_profile", - aws_region_name="fake_region", ) expected_dict = { "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator", "init_parameters": { + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, "model": "anthropic.claude-v2", "max_length": 99, }, @@ -66,7 +43,7 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session): @pytest.mark.unit -def test_from_dict(mock_auto_tokenizer, mock_boto3_session): +def test_from_dict(mock_auto_tokenizer, mock_boto3_session, set_env_variables): """ Test that the from_dict method returns the correct object """ @@ -74,6 +51,11 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session): { "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator", "init_parameters": { + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, "model": "anthropic.claude-v2", "max_length": 99, }, @@ -85,7 +67,7 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session): @pytest.mark.unit -def test_default_constructor(mock_auto_tokenizer, mock_boto3_session): +def test_default_constructor(mock_auto_tokenizer, mock_boto3_session, set_env_variables): """ Test that the default constructor sets the correct values """ @@ -93,11 +75,6 @@ def test_default_constructor(mock_auto_tokenizer, mock_boto3_session): layer = AmazonBedrockGenerator( model="anthropic.claude-v2", max_length=99, - aws_access_key_id="some_fake_id", - aws_secret_access_key="some_fake_key", - aws_session_token="some_fake_token", - aws_profile_name="some_fake_profile", - aws_region_name="fake_region", ) assert layer.max_length == 99 @@ -120,7 +97,7 @@ def test_default_constructor(mock_auto_tokenizer, mock_boto3_session): @pytest.mark.unit -def test_constructor_prompt_handler_initialized(mock_auto_tokenizer, mock_boto3_session): +def test_constructor_prompt_handler_initialized(mock_auto_tokenizer, mock_boto3_session, mock_prompt_handler): """ Test that the constructor sets the prompt_handler correctly, with the correct model_max_length for llama-2 """ diff --git a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py index 9592b5b39..574aab5cc 100644 --- a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py +++ b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py @@ -38,23 +38,23 @@ def mock_prompt_handler(): yield mock_prompt_handler -def test_to_dict(mock_auto_tokenizer, mock_boto3_session): +def test_to_dict(mock_auto_tokenizer, mock_boto3_session, set_env_variables): """ Test that the to_dict method returns the correct dictionary without aws credentials """ generator = AmazonBedrockChatGenerator( model="anthropic.claude-v2", - aws_access_key_id="some_fake_id", - aws_secret_access_key="some_fake_key", - aws_session_token="some_fake_token", - aws_profile_name="some_fake_profile", - aws_region_name="fake_region", generation_kwargs={"temperature": 0.7}, streaming_callback=print_streaming_chunk, ) expected_dict = { "type": clazz, "init_parameters": { + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, "model": "anthropic.claude-v2", "generation_kwargs": {"temperature": 0.7}, "stop_words": [], @@ -73,6 +73,11 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session): { "type": clazz, "init_parameters": { + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, "model": "anthropic.claude-v2", "generation_kwargs": {"temperature": 0.7}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", @@ -84,18 +89,13 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session): assert generator.streaming_callback == print_streaming_chunk -def test_default_constructor(mock_auto_tokenizer, mock_boto3_session): +def test_default_constructor(mock_auto_tokenizer, mock_boto3_session, set_env_variables): """ Test that the default constructor sets the correct values """ layer = AmazonBedrockChatGenerator( model="anthropic.claude-v2", - aws_access_key_id="some_fake_id", - aws_secret_access_key="some_fake_key", - aws_session_token="some_fake_token", - aws_profile_name="some_fake_profile", - aws_region_name="fake_region", ) assert layer.model == "anthropic.claude-v2"