Adopt Secret to Amazon Bedrock (#416)
* initial import

* adding Secret and fixing tests

* cleaning

* using staticmethod directly

* removing ignore from B008 config and setting up in-line

* adding Secret and fixing tests for chat component

* adding Secret and fixing tests for the chat component

* fixing to_dict/from_dict

* fixing to_dict/from_dict in the chat component as well
davidsbatista authored Feb 16, 2024
1 parent ec58d6f commit d28ffa1
Showing 6 changed files with 128 additions and 78 deletions.
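Before the per-file diffs, a minimal usage sketch of what this change enables, assuming the import path implied by the type strings in the diff and the new constructor signature shown below (the model name and token values are placeholders): AWS credentials are no longer passed as plain strings but as `Secret` objects that default to the usual `AWS_*` environment variables.

```python
from haystack.utils.auth import Secret
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator

# Default: each credential is an env-var-backed Secret (strict=False),
# resolved lazily from AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, etc.
generator = AmazonBedrockGenerator(model="anthropic.claude-v2")

# Explicit override: wrap in-memory values in token-backed Secrets instead.
generator = AmazonBedrockGenerator(
    model="anthropic.claude-v2",
    aws_access_key_id=Secret.from_token("my-access-key-id"),
    aws_secret_access_key=Secret.from_token("my-secret-access-key"),
)
```

Token-backed Secrets cannot be serialized, so `to_dict` only ever emits env-var descriptors, never raw credential values.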
13 changes: 6 additions & 7 deletions integrations/amazon_bedrock/pyproject.toml
@@ -83,11 +83,13 @@ style = [
"ruff {args:.}",
"black --check --diff {args:.}",
]

fmt = [
"black {args:.}",
"ruff --fix {args:.}",
"style",
]

all = [
"style",
"typing",
@@ -135,7 +137,7 @@ ignore = [
# Ignore complexity
"C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915",
# Ignore unused params
"ARG001", "ARG002", "ARG005"
"ARG001", "ARG002", "ARG005",
]
unfixable = [
# Don't touch unused imports
@@ -153,16 +155,13 @@ ban-relative-imports = "parents"
"tests/**/*" = ["PLR2004", "S101", "TID252"]

[tool.coverage.run]
source_pkgs = ["src", "tests"]
source = ["haystack_integrations"]
branch = true
parallel = true


[tool.coverage.paths]
amazon_bedrock_haystack = ["src/*"]
tests = ["tests"]

[tool.coverage.report]
omit = ["*/tests/*", "*/__init__.py"]
show_missing=true
exclude_lines = [
"no cov",
"if __name__ == .__main__.:",
@@ -8,6 +8,7 @@
from haystack import component, default_from_dict, default_to_dict
from haystack.components.generators.utils import deserialize_callback_handler
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.utils.auth import Secret, deserialize_secrets_inplace

from haystack_integrations.components.generators.amazon_bedrock.errors import (
AmazonBedrockConfigurationError,
@@ -61,11 +62,13 @@ class AmazonBedrockChatGenerator:
def __init__(
self,
model: str,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_access_key_id: Optional[Secret] = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False), # noqa: B008
aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008
["AWS_SECRET_ACCESS_KEY"], strict=False
),
aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008
aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008
aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008
generation_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
@@ -102,6 +105,11 @@ def __init__(
msg = "'model' cannot be None or empty string"
raise ValueError(msg)
self.model = model
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name

# get the model adapter for the given model
model_adapter_cls = self.get_model_adapter(model=model)
@@ -111,13 +119,16 @@ self.model_adapter = model_adapter_cls(generation_kwargs or {})
self.model_adapter = model_adapter_cls(generation_kwargs or {})

# create the AWS session and client
def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
return secret.resolve_value() if secret else None

try:
session = self.get_aws_session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
aws_region_name=aws_region_name,
aws_profile_name=aws_profile_name,
aws_access_key_id=resolve_secret(aws_access_key_id),
aws_secret_access_key=resolve_secret(aws_secret_access_key),
aws_session_token=resolve_secret(aws_session_token),
aws_region_name=resolve_secret(aws_region_name),
aws_profile_name=resolve_secret(aws_profile_name),
)
self.client = session.client("bedrock-runtime")
except Exception as exception:
@@ -229,6 +240,11 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None,
aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None,
aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None,
aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None,
aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
model=self.model,
stop_words=self.stop_words,
generation_kwargs=self.model_adapter.generation_kwargs,
@@ -246,4 +262,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockChatGenerator":
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler)
deserialize_secrets_inplace(
data["init_parameters"],
["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"],
)
return default_from_dict(cls, data)
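A minimal sketch of the serialization round trip this enables for the chat generator, assuming the import path implied by the diff's type strings and that the AWS_* environment variables (including AWS_DEFAULT_REGION) are set: `to_dict` stores only descriptors of where the credentials live, and `from_dict` restores the `Secret` instances via `deserialize_secrets_inplace`.

```python
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator

generator = AmazonBedrockChatGenerator(model="anthropic.claude-v2")
data = generator.to_dict()

# No raw credential values in the serialized form, only env-var descriptors, e.g.:
# data["init_parameters"]["aws_access_key_id"]
#   == {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}

restored = AmazonBedrockChatGenerator.from_dict(data)
assert isinstance(restored, AmazonBedrockChatGenerator)
```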
@@ -6,6 +6,7 @@
import boto3
from botocore.exceptions import BotoCoreError, ClientError
from haystack import component, default_from_dict, default_to_dict
from haystack.utils.auth import Secret, deserialize_secrets_inplace

from .adapters import (
AI21LabsJurassic2Adapter,
@@ -72,11 +73,13 @@ class AmazonBedrockGenerator:
def __init__(
self,
model: str,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_access_key_id: Optional[Secret] = Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False), # noqa: B008
aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008
"AWS_SECRET_ACCESS_KEY", strict=False
),
aws_session_token: Optional[Secret] = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False), # noqa: B008
aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008
aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008
max_length: Optional[int] = 100,
**kwargs,
):
@@ -85,14 +88,22 @@ def __init__(
raise ValueError(msg)
self.model = model
self.max_length = max_length
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name

def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
return secret.resolve_value() if secret else None

try:
session = self.get_aws_session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
aws_region_name=aws_region_name,
aws_profile_name=aws_profile_name,
aws_access_key_id=resolve_secret(aws_access_key_id),
aws_secret_access_key=resolve_secret(aws_secret_access_key),
aws_session_token=resolve_secret(aws_session_token),
aws_region_name=resolve_secret(aws_region_name),
aws_profile_name=resolve_secret(aws_profile_name),
)
self.client = session.client("bedrock-runtime")
except Exception as exception:
@@ -103,8 +114,7 @@ def __init__(
raise AmazonBedrockConfigurationError(msg) from exception

model_input_kwargs = kwargs
# We pop the model_max_length as it is not sent to the model
# but used to truncate the prompt if needed
# We pop the model_max_length as it is not sent to the model but used to truncate the prompt if needed
model_max_length = kwargs.get("model_max_length", 4096)

# Truncate prompt if prompt tokens > model_max_length-max_length
@@ -298,6 +308,11 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None,
aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None,
aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None,
aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None,
aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
model=self.model,
max_length=self.max_length,
)
@@ -309,4 +324,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockGenerator":
:param data: The dictionary representation of this component.
:return: The deserialized component instance.
"""
deserialize_secrets_inplace(
data["init_parameters"],
["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"],
)
return default_from_dict(cls, data)
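The `resolve_secret` helper above turns each optional `Secret` back into the plain string that `get_aws_session` expects. Because the defaults use `strict=False`, an unset environment variable resolves to `None` instead of raising, which lets boto3 fall back to its own credential chain. A small standalone illustration of that behavior (not part of the diff):

```python
import os
from typing import Optional

from haystack.utils.auth import Secret


def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
    # Same shape as the helper in the diff: None stays None, otherwise resolve.
    return secret.resolve_value() if secret else None


session_token = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False)

os.environ.pop("AWS_SESSION_TOKEN", None)
assert resolve_secret(session_token) is None           # strict=False: missing var -> None

os.environ["AWS_SESSION_TOKEN"] = "fake-token"
assert resolve_secret(session_token) == "fake-token"   # resolved lazily on each call
```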
35 changes: 35 additions & 0 deletions integrations/amazon_bedrock/tests/conftest.py
@@ -0,0 +1,35 @@
from unittest.mock import MagicMock, patch

import pytest


@pytest.fixture
def set_env_variables(monkeypatch):
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "some_fake_id")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "some_fake_key")
monkeypatch.setenv("AWS_SESSION_TOKEN", "some_fake_token")
monkeypatch.setenv("AWS_DEFAULT_REGION", "fake_region")
monkeypatch.setenv("AWS_PROFILE", "some_fake_profile")


@pytest.fixture
def mock_auto_tokenizer():
with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as mock_from_pretrained:
mock_tokenizer = MagicMock()
mock_from_pretrained.return_value = mock_tokenizer
yield mock_tokenizer


# create a fixture with mocked boto3 client and session
@pytest.fixture
def mock_boto3_session():
with patch("boto3.Session") as mock_client:
yield mock_client


@pytest.fixture
def mock_prompt_handler():
with patch(
"haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler"
) as mock_prompt_handler:
yield mock_prompt_handler
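These fixtures are what the updated tests below rely on. A minimal sketch of how they combine (the test name is hypothetical, and the assertion is limited to what the diff shows, namely `session.client("bedrock-runtime")`): the monkeypatched environment supplies fake values for the env-var-backed Secrets, while the patched `boto3.Session` keeps everything offline.

```python
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator


def test_client_is_created_from_mocked_session(mock_auto_tokenizer, mock_boto3_session, set_env_variables):
    generator = AmazonBedrockGenerator(model="anthropic.claude-v2")

    # The constructor builds a bedrock-runtime client from the (mocked) session.
    mock_boto3_session.return_value.client.assert_called_once_with("bedrock-runtime")
    assert generator.client is mock_boto3_session.return_value.client.return_value
```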
51 changes: 14 additions & 37 deletions integrations/amazon_bedrock/tests/test_amazon_bedrock.py
@@ -16,47 +16,24 @@
from haystack_integrations.components.generators.amazon_bedrock.errors import AmazonBedrockConfigurationError


@pytest.fixture
def mock_auto_tokenizer():
with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as mock_from_pretrained:
mock_tokenizer = MagicMock()
mock_from_pretrained.return_value = mock_tokenizer
yield mock_tokenizer


# create a fixture with mocked boto3 client and session
@pytest.fixture
def mock_boto3_session():
with patch("boto3.Session") as mock_client:
yield mock_client


@pytest.fixture
def mock_prompt_handler():
with patch(
"haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler"
) as mock_prompt_handler:
yield mock_prompt_handler


@pytest.mark.unit
def test_to_dict(mock_auto_tokenizer, mock_boto3_session):
def test_to_dict(mock_auto_tokenizer, mock_boto3_session, set_env_variables):
"""
Test that the to_dict method returns the correct dictionary without aws credentials
"""
generator = AmazonBedrockGenerator(
model="anthropic.claude-v2",
max_length=99,
aws_access_key_id="some_fake_id",
aws_secret_access_key="some_fake_key",
aws_session_token="some_fake_token",
aws_profile_name="some_fake_profile",
aws_region_name="fake_region",
)

expected_dict = {
"type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator",
"init_parameters": {
"aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
"aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
"aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
"aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
"aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
"model": "anthropic.claude-v2",
"max_length": 99,
},
@@ -66,14 +43,19 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session):


@pytest.mark.unit
def test_from_dict(mock_auto_tokenizer, mock_boto3_session):
def test_from_dict(mock_auto_tokenizer, mock_boto3_session, set_env_variables):
"""
Test that the from_dict method returns the correct object
"""
generator = AmazonBedrockGenerator.from_dict(
{
"type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator",
"init_parameters": {
"aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
"aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
"aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
"aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
"aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
"model": "anthropic.claude-v2",
"max_length": 99,
},
@@ -85,19 +67,14 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session):


@pytest.mark.unit
def test_default_constructor(mock_auto_tokenizer, mock_boto3_session):
def test_default_constructor(mock_auto_tokenizer, mock_boto3_session, set_env_variables):
"""
Test that the default constructor sets the correct values
"""

layer = AmazonBedrockGenerator(
model="anthropic.claude-v2",
max_length=99,
aws_access_key_id="some_fake_id",
aws_secret_access_key="some_fake_key",
aws_session_token="some_fake_token",
aws_profile_name="some_fake_profile",
aws_region_name="fake_region",
)

assert layer.max_length == 99
@@ -120,7 +97,7 @@ def test_default_constructor(mock_auto_tokenizer, mock_boto3_session):


@pytest.mark.unit
def test_constructor_prompt_handler_initialized(mock_auto_tokenizer, mock_boto3_session):
def test_constructor_prompt_handler_initialized(mock_auto_tokenizer, mock_boto3_session, mock_prompt_handler):
"""
Test that the constructor sets the prompt_handler correctly, with the correct model_max_length for llama-2
"""