Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adopt Secret to Amazon Bedrock #416

Merged
merged 19 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions integrations/amazon_bedrock/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,13 @@ style = [
"ruff {args:.}",
"black --check --diff {args:.}",
]

fmt = [
"black {args:.}",
"ruff --fix {args:.}",
"style",
]

all = [
"style",
"typing",
Expand Down Expand Up @@ -135,7 +137,9 @@ ignore = [
# Ignore complexity
"C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915",
# Ignore unused params
"ARG001", "ARG002", "ARG005"
"ARG001", "ARG002", "ARG005",
# Ignore function calls in argument defaults
"B008"
]
unfixable = [
# Don't touch unused imports
Expand All @@ -153,16 +157,13 @@ ban-relative-imports = "parents"
"tests/**/*" = ["PLR2004", "S101", "TID252"]

[tool.coverage.run]
source_pkgs = ["src", "tests"]
source = ["haystack_integrations"]
branch = true
parallel = true


[tool.coverage.paths]
amazon_bedrock_haystack = ["src/*"]
tests = ["tests"]

[tool.coverage.report]
omit = ["*/tests/*", "*/__init__.py"]
show_missing=true
exclude_lines = [
"no cov",
"if __name__ == .__main__.:",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import boto3
from botocore.exceptions import BotoCoreError, ClientError
from haystack import component, default_from_dict, default_to_dict
from haystack.utils.auth import EnvVarSecret, Secret

from .adapters import (
AI21LabsJurassic2Adapter,
Expand Down Expand Up @@ -72,11 +73,21 @@ class AmazonBedrockGenerator:
def __init__(
self,
model: str,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_access_key_id: Optional[Secret] = EnvVarSecret(env_vars=["AWS_ACCESS_KEY_ID"], strict=False).from_env_var(
"AWS_ACCESS_KEY_ID"
),
aws_secret_access_key: Optional[Secret] = EnvVarSecret(
env_vars=["AWS_SECRET_ACCESS_KEY"], strict=False
).from_env_var("AWS_SECRET_ACCESS_KEY"),
aws_session_token: Optional[Secret] = EnvVarSecret(env_vars=["AWS_SESSION_TOKEN"], strict=False).from_env_var(
"AWS_SESSION_TOKEN"
),
aws_region_name: Optional[Secret] = EnvVarSecret(env_vars=["AWS_DEFAULT_REGION"], strict=False).from_env_var(
"AWS_DEFAULT_REGION"
),
davidsbatista marked this conversation as resolved.
Show resolved Hide resolved
aws_profile_name: Optional[str] = EnvVarSecret(env_vars=["AWS_PROFILE"], strict=False).from_env_var(
davidsbatista marked this conversation as resolved.
Show resolved Hide resolved
"AWS_PROFILE"
),
max_length: Optional[int] = 100,
**kwargs,
):
Expand All @@ -86,13 +97,16 @@ def __init__(
self.model = model
self.max_length = max_length

def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
return secret.resolve_value() if secret else None

try:
session = self.get_aws_session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
aws_region_name=aws_region_name,
aws_profile_name=aws_profile_name,
aws_access_key_id=resolve_secret(aws_access_key_id),
aws_secret_access_key=resolve_secret(aws_secret_access_key),
aws_session_token=resolve_secret(aws_session_token),
aws_region_name=resolve_secret(aws_region_name),
aws_profile_name=resolve_secret(aws_profile_name),
)
self.client = session.client("bedrock-runtime")
except Exception as exception:
Expand All @@ -103,8 +117,7 @@ def __init__(
raise AmazonBedrockConfigurationError(msg) from exception

model_input_kwargs = kwargs
# We pop the model_max_length as it is not sent to the model
# but used to truncate the prompt if needed
# We pop the model_max_length as it is not sent to the model but used to truncate the prompt if needed
model_max_length = kwargs.get("model_max_length", 4096)

# Truncate prompt if prompt tokens > model_max_length-max_length
Expand Down
35 changes: 35 additions & 0 deletions integrations/amazon_bedrock/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from unittest.mock import MagicMock, patch

import pytest


@pytest.fixture
def set_env_variables(monkeypatch):
    """Populate the AWS-related environment variables with fake test values.

    Uses pytest's monkeypatch so the variables are restored automatically
    after each test.
    """
    fake_aws_env = {
        "AWS_ACCESS_KEY_ID": "some_fake_id",
        "AWS_SECRET_ACCESS_KEY": "some_fake_key",
        "AWS_SESSION_TOKEN": "some_fake_token",
        "AWS_DEFAULT_REGION": "fake_region",
        "AWS_PROFILE": "some_fake_profile",
    }
    for variable, value in fake_aws_env.items():
        monkeypatch.setenv(variable, value)


@pytest.fixture
def mock_auto_tokenizer():
    """Patch transformers.AutoTokenizer.from_pretrained and yield the mocked tokenizer."""
    with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as patched_from_pretrained:
        tokenizer_stub = MagicMock()
        patched_from_pretrained.return_value = tokenizer_stub
        yield tokenizer_stub


# Fixture that patches boto3.Session and yields the patched object,
# so tests never open a real AWS session.
@pytest.fixture
def mock_boto3_session():
    with patch("boto3.Session") as patched_session:
        yield patched_session


@pytest.fixture
def mock_prompt_handler():
    """Patch the integration's DefaultPromptHandler and yield the patched object."""
    handler_path = "haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler"
    with patch(handler_path) as patched_handler:
        yield patched_handler
46 changes: 10 additions & 36 deletions integrations/amazon_bedrock/tests/test_amazon_bedrock.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Optional, Type
from unittest.mock import MagicMock, call, patch

Expand All @@ -16,42 +17,14 @@
from haystack_integrations.components.generators.amazon_bedrock.errors import AmazonBedrockConfigurationError


@pytest.fixture
def mock_auto_tokenizer():
    """Yield a MagicMock tokenizer in place of transformers.AutoTokenizer.from_pretrained."""
    with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as from_pretrained_mock:
        fake_tokenizer = MagicMock()
        from_pretrained_mock.return_value = fake_tokenizer
        yield fake_tokenizer


# create a fixture with mocked boto3 client and session
@pytest.fixture
def mock_boto3_session():
    """Yield the patched boto3.Session so tests avoid real AWS calls."""
    with patch("boto3.Session") as session_mock:
        yield session_mock


@pytest.fixture
def mock_prompt_handler():
    """Patch DefaultPromptHandler in the amazon_bedrock handlers module and yield the mock."""
    patch_target = "haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler"
    with patch(patch_target) as handler_mock:
        yield handler_mock


@pytest.mark.unit
def test_to_dict(mock_auto_tokenizer, mock_boto3_session):
def test_to_dict(mock_auto_tokenizer, mock_boto3_session, set_env_variables):
"""
Test that the to_dict method returns the correct dictionary without aws credentials
"""
generator = AmazonBedrockGenerator(
model="anthropic.claude-v2",
max_length=99,
aws_access_key_id="some_fake_id",
aws_secret_access_key="some_fake_key",
aws_session_token="some_fake_token",
aws_profile_name="some_fake_profile",
aws_region_name="fake_region",
)

expected_dict = {
Expand All @@ -66,7 +39,7 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session):


@pytest.mark.unit
def test_from_dict(mock_auto_tokenizer, mock_boto3_session):
def test_from_dict(mock_auto_tokenizer, mock_boto3_session, set_env_variables):
"""
Test that the from_dict method returns the correct object
"""
Expand All @@ -90,14 +63,15 @@ def test_default_constructor(mock_auto_tokenizer, mock_boto3_session):
Test that the default constructor sets the correct values
"""

os.environ["AWS_ACCESS_KEY_ID"] = "some_fake_id"
os.environ["AWS_SECRET_ACCESS_KEY"] = "some_fake_key"
os.environ["AWS_SESSION_TOKEN"] = "some_fake_token"
os.environ["AWS_DEFAULT_REGION"] = "fake_region"
os.environ["AWS_PROFILE"] = "some_fake_profile"

layer = AmazonBedrockGenerator(
model="anthropic.claude-v2",
max_length=99,
aws_access_key_id="some_fake_id",
aws_secret_access_key="some_fake_key",
aws_session_token="some_fake_token",
aws_profile_name="some_fake_profile",
aws_region_name="fake_region",
)

assert layer.max_length == 99
Expand All @@ -120,7 +94,7 @@ def test_default_constructor(mock_auto_tokenizer, mock_boto3_session):


@pytest.mark.unit
def test_constructor_prompt_handler_initialized(mock_auto_tokenizer, mock_boto3_session):
def test_constructor_prompt_handler_initialized(mock_auto_tokenizer, mock_boto3_session, mock_prompt_handler):
"""
Test that the constructor sets the prompt_handler correctly, with the correct model_max_length for llama-2
"""
Expand Down