Adopt Secret to Amazon Bedrock #416

Merged (19 commits) on Feb 16, 2024
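In short, this PR replaces the plain-string AWS credential parameters of the Bedrock generators with Haystack `Secret` objects that default to non-strict environment-variable lookups. A minimal sketch of the resulting calling convention, based on the signatures in the diff below (the import path of the generator and the use of `Secret.from_token` are assumptions, not taken from this PR):

```python
from haystack.utils.auth import Secret

# Credentials now default to non-strict env-var Secrets:
access_key = Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False)
print(access_key.resolve_value())  # the variable's value, or None if it is unset

# Assumed explicit override instead of relying on environment variables
# (import path and Secret.from_token are hypothetical in the context of this PR):
# from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator
# generator = AmazonBedrockGenerator(
#     model="anthropic.claude-v2",
#     aws_access_key_id=Secret.from_token("my-access-key-id"),
# )
```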
13 changes: 6 additions & 7 deletions integrations/amazon_bedrock/pyproject.toml
@@ -83,11 +83,13 @@ style = [
"ruff {args:.}",
"black --check --diff {args:.}",
]

fmt = [
"black {args:.}",
"ruff --fix {args:.}",
"style",
]

all = [
"style",
"typing",
@@ -135,7 +137,7 @@ ignore = [
# Ignore complexity
"C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915",
# Ignore unused params
"ARG001", "ARG002", "ARG005"
"ARG001", "ARG002", "ARG005",
]
unfixable = [
# Don't touch unused imports
@@ -153,16 +155,13 @@ ban-relative-imports = "parents"
"tests/**/*" = ["PLR2004", "S101", "TID252"]

[tool.coverage.run]
source_pkgs = ["src", "tests"]
source = ["haystack_integrations"]
branch = true
parallel = true


[tool.coverage.paths]
amazon_bedrock_haystack = ["src/*"]
tests = ["tests"]

[tool.coverage.report]
omit = ["*/tests/*", "*/__init__.py"]
show_missing=true
exclude_lines = [
"no cov",
"if __name__ == .__main__.:",
Changes to the AmazonBedrockChatGenerator module:
@@ -8,6 +8,7 @@
from haystack import component, default_from_dict, default_to_dict
from haystack.components.generators.utils import deserialize_callback_handler
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.utils.auth import Secret, deserialize_secrets_inplace

from haystack_integrations.components.generators.amazon_bedrock.errors import (
AmazonBedrockConfigurationError,
@@ -61,11 +62,13 @@ class AmazonBedrockChatGenerator:
def __init__(
self,
model: str,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_access_key_id: Optional[Secret] = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False), # noqa: B008
aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008
["AWS_SECRET_ACCESS_KEY"], strict=False
),
aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008
aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008
aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008
generation_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
@@ -102,6 +105,11 @@ def __init__(
msg = "'model' cannot be None or empty string"
raise ValueError(msg)
self.model = model
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name

# get the model adapter for the given model
model_adapter_cls = self.get_model_adapter(model=model)
@@ -111,13 +119,16 @@
self.model_adapter = model_adapter_cls(generation_kwargs or {})

# create the AWS session and client
def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
return secret.resolve_value() if secret else None

try:
session = self.get_aws_session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
aws_region_name=aws_region_name,
aws_profile_name=aws_profile_name,
aws_access_key_id=resolve_secret(aws_access_key_id),
aws_secret_access_key=resolve_secret(aws_secret_access_key),
aws_session_token=resolve_secret(aws_session_token),
aws_region_name=resolve_secret(aws_region_name),
aws_profile_name=resolve_secret(aws_profile_name),
)
self.client = session.client("bedrock-runtime")
except Exception as exception:
@@ -229,6 +240,11 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None,
aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None,
aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None,
aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None,
aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
model=self.model,
stop_words=self.stop_words,
generation_kwargs=self.model_adapter.generation_kwargs,
@@ -246,4 +262,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockChatGenerator":
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler)
deserialize_secrets_inplace(
data["init_parameters"],
["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"],
)
return default_from_dict(cls, data)
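A small sketch of the `Secret` round-trip that the new `to_dict`/`from_dict` code relies on, using only calls that appear in the diff above (the dictionary literal shown in the comment matches the serialized form asserted in the tests further down):

```python
from haystack.utils.auth import Secret, deserialize_secrets_inplace

# to_dict() now serializes each credential via Secret.to_dict():
region = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False)
serialized = region.to_dict()
print(serialized)  # {'type': 'env_var', 'env_vars': ['AWS_DEFAULT_REGION'], 'strict': False}

# from_dict() restores the Secret objects in place before the component is rebuilt:
init_params = {"aws_region_name": serialized, "model": "anthropic.claude-v2"}
deserialize_secrets_inplace(init_params, ["aws_region_name"])
print(init_params["aws_region_name"].resolve_value())  # value of AWS_DEFAULT_REGION, or None if unset
```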
Changes to the AmazonBedrockGenerator module:
@@ -6,6 +6,7 @@
import boto3
from botocore.exceptions import BotoCoreError, ClientError
from haystack import component, default_from_dict, default_to_dict
from haystack.utils.auth import Secret, deserialize_secrets_inplace

from .adapters import (
AI21LabsJurassic2Adapter,
@@ -72,11 +73,13 @@ class AmazonBedrockGenerator:
def __init__(
self,
model: str,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_access_key_id: Optional[Secret] = Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False), # noqa: B008
aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008
"AWS_SECRET_ACCESS_KEY", strict=False
),
aws_session_token: Optional[Secret] = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False), # noqa: B008
aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008
aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008
max_length: Optional[int] = 100,
**kwargs,
):
@@ -85,14 +88,22 @@ def __init__(
raise ValueError(msg)
self.model = model
self.max_length = max_length
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name

def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
return secret.resolve_value() if secret else None

try:
session = self.get_aws_session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
aws_region_name=aws_region_name,
aws_profile_name=aws_profile_name,
aws_access_key_id=resolve_secret(aws_access_key_id),
aws_secret_access_key=resolve_secret(aws_secret_access_key),
aws_session_token=resolve_secret(aws_session_token),
aws_region_name=resolve_secret(aws_region_name),
aws_profile_name=resolve_secret(aws_profile_name),
)
self.client = session.client("bedrock-runtime")
except Exception as exception:
@@ -103,8 +114,7 @@ def __init__(
raise AmazonBedrockConfigurationError(msg) from exception

model_input_kwargs = kwargs
# We pop the model_max_length as it is not sent to the model
# but used to truncate the prompt if needed
# We pop the model_max_length as it is not sent to the model but used to truncate the prompt if needed
model_max_length = kwargs.get("model_max_length", 4096)

# Truncate prompt if prompt tokens > model_max_length-max_length
@@ -298,6 +308,11 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None,
aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None,
aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None,
aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None,
aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
model=self.model,
max_length=self.max_length,
)
@@ -309,4 +324,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockGenerator":
:param data: The dictionary representation of this component.
:return: The deserialized component instance.
"""
deserialize_secrets_inplace(
data["init_parameters"],
["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"],
)
return default_from_dict(cls, data)
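One design point worth spelling out: because the new defaults use `strict=False`, an unset environment variable does not make `__init__` fail; the local `resolve_secret` helper simply passes `None` through to `get_aws_session`, so boto3 can fall back to its own credential chain (profiles, instance roles, and so on). A minimal sketch of that behaviour, with the helper copied from the diff and a deliberately unset, hypothetical variable name:

```python
from typing import Optional

from haystack.utils.auth import Secret


def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
    # Same helper as in the diff: resolve when a Secret was given, otherwise pass None through.
    return secret.resolve_value() if secret else None


unset = Secret.from_env_var("SOME_UNSET_AWS_VARIABLE", strict=False)  # hypothetical variable name
print(resolve_secret(unset))  # None: no exception is raised, boto3 would use its default chain
```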
35 changes: 35 additions & 0 deletions integrations/amazon_bedrock/tests/conftest.py
@@ -0,0 +1,35 @@
from unittest.mock import MagicMock, patch

import pytest


@pytest.fixture
def set_env_variables(monkeypatch):
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "some_fake_id")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "some_fake_key")
monkeypatch.setenv("AWS_SESSION_TOKEN", "some_fake_token")
monkeypatch.setenv("AWS_DEFAULT_REGION", "fake_region")
monkeypatch.setenv("AWS_PROFILE", "some_fake_profile")


@pytest.fixture
def mock_auto_tokenizer():
with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as mock_from_pretrained:
mock_tokenizer = MagicMock()
mock_from_pretrained.return_value = mock_tokenizer
yield mock_tokenizer


# create a fixture with mocked boto3 client and session
@pytest.fixture
def mock_boto3_session():
with patch("boto3.Session") as mock_client:
yield mock_client


@pytest.fixture
def mock_prompt_handler():
with patch(
"haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler"
) as mock_prompt_handler:
yield mock_prompt_handler
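A short sketch of why the new `set_env_variables` fixture matters for the updated tests below: with the non-strict env-var defaults, the credential parameters only resolve to real values when these variables are exported, which is exactly what `monkeypatch.setenv` provides for the duration of a test. The standalone `os.environ` assignment here is just for illustration and lacks the fixture's automatic cleanup:

```python
import os

from haystack.utils.auth import Secret

# Roughly what the fixture's monkeypatch.setenv("AWS_ACCESS_KEY_ID", "some_fake_id") does:
os.environ["AWS_ACCESS_KEY_ID"] = "some_fake_id"

secret = Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False)
print(secret.to_dict())        # {'type': 'env_var', 'env_vars': ['AWS_ACCESS_KEY_ID'], 'strict': False}
print(secret.resolve_value())  # 'some_fake_id'
```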
51 changes: 14 additions & 37 deletions integrations/amazon_bedrock/tests/test_amazon_bedrock.py
@@ -16,47 +16,24 @@
from haystack_integrations.components.generators.amazon_bedrock.errors import AmazonBedrockConfigurationError


@pytest.fixture
def mock_auto_tokenizer():
with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as mock_from_pretrained:
mock_tokenizer = MagicMock()
mock_from_pretrained.return_value = mock_tokenizer
yield mock_tokenizer


# create a fixture with mocked boto3 client and session
@pytest.fixture
def mock_boto3_session():
with patch("boto3.Session") as mock_client:
yield mock_client


@pytest.fixture
def mock_prompt_handler():
with patch(
"haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler"
) as mock_prompt_handler:
yield mock_prompt_handler


@pytest.mark.unit
def test_to_dict(mock_auto_tokenizer, mock_boto3_session):
def test_to_dict(mock_auto_tokenizer, mock_boto3_session, set_env_variables):
"""
Test that the to_dict method returns the correct dictionary without aws credentials
"""
generator = AmazonBedrockGenerator(
model="anthropic.claude-v2",
max_length=99,
aws_access_key_id="some_fake_id",
aws_secret_access_key="some_fake_key",
aws_session_token="some_fake_token",
aws_profile_name="some_fake_profile",
aws_region_name="fake_region",
)

expected_dict = {
"type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator",
"init_parameters": {
"aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
"aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
"aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
"aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
"aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
"model": "anthropic.claude-v2",
"max_length": 99,
},
@@ -66,14 +43,19 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session):


@pytest.mark.unit
def test_from_dict(mock_auto_tokenizer, mock_boto3_session):
def test_from_dict(mock_auto_tokenizer, mock_boto3_session, set_env_variables):
"""
Test that the from_dict method returns the correct object
"""
generator = AmazonBedrockGenerator.from_dict(
{
"type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator",
"init_parameters": {
"aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
"aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
"aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
"aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
"aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
"model": "anthropic.claude-v2",
"max_length": 99,
},
@@ -85,19 +67,14 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session):


@pytest.mark.unit
def test_default_constructor(mock_auto_tokenizer, mock_boto3_session):
def test_default_constructor(mock_auto_tokenizer, mock_boto3_session, set_env_variables):
"""
Test that the default constructor sets the correct values
"""

layer = AmazonBedrockGenerator(
model="anthropic.claude-v2",
max_length=99,
aws_access_key_id="some_fake_id",
aws_secret_access_key="some_fake_key",
aws_session_token="some_fake_token",
aws_profile_name="some_fake_profile",
aws_region_name="fake_region",
)

assert layer.max_length == 99
@@ -120,7 +97,7 @@ def test_default_constructor(mock_auto_tokenizer, mock_boto3_session):


@pytest.mark.unit
def test_constructor_prompt_handler_initialized(mock_auto_tokenizer, mock_boto3_session):
def test_constructor_prompt_handler_initialized(mock_auto_tokenizer, mock_boto3_session, mock_prompt_handler):
"""
Test that the constructor sets the prompt_handler correctly, with the correct model_max_length for llama-2
"""