Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adopt Secret to Amazon Bedrock #416

Merged
merged 19 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions integrations/amazon_bedrock/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,6 @@ ignore = [
"C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915",
# Ignore unused params
"ARG001", "ARG002", "ARG005",
# Ignore perform the call within the function
"B008"
]
unfixable = [
# Don't touch unused imports
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from haystack import component, default_from_dict, default_to_dict
from haystack.components.generators.utils import deserialize_callback_handler
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.utils.auth import Secret

from haystack_integrations.components.generators.amazon_bedrock.errors import (
AmazonBedrockConfigurationError,
Expand Down Expand Up @@ -61,11 +62,13 @@ class AmazonBedrockChatGenerator:
def __init__(
self,
model: str,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_access_key_id: Optional[Secret] = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False), # noqa: B008
aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008
["AWS_SECRET_ACCESS_KEY"], strict=False
),
aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008
aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008
aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008
davidsbatista marked this conversation as resolved.
Show resolved Hide resolved
generation_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
Expand Down Expand Up @@ -111,13 +114,16 @@ def __init__(
self.model_adapter = model_adapter_cls(generation_kwargs or {})

# create the AWS session and client
def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
return secret.resolve_value() if secret else None

try:
session = self.get_aws_session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
aws_region_name=aws_region_name,
aws_profile_name=aws_profile_name,
aws_access_key_id=resolve_secret(aws_access_key_id),
aws_secret_access_key=resolve_secret(aws_secret_access_key),
aws_session_token=resolve_secret(aws_session_token),
aws_region_name=resolve_secret(aws_region_name),
aws_profile_name=resolve_secret(aws_profile_name),
)
self.client = session.client("bedrock-runtime")
except Exception as exception:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import boto3
from botocore.exceptions import BotoCoreError, ClientError
from haystack import component, default_from_dict, default_to_dict
from haystack.utils.auth import EnvVarSecret, Secret
from haystack.utils.auth import Secret

from .adapters import (
AI21LabsJurassic2Adapter,
Expand Down Expand Up @@ -73,21 +73,13 @@ class AmazonBedrockGenerator:
def __init__(
self,
model: str,
aws_access_key_id: Optional[Secret] = EnvVarSecret(env_vars=["AWS_ACCESS_KEY_ID"], strict=False).from_env_var(
"AWS_ACCESS_KEY_ID"
),
aws_secret_access_key: Optional[Secret] = EnvVarSecret(
env_vars=["AWS_SECRET_ACCESS_KEY"], strict=False
).from_env_var("AWS_SECRET_ACCESS_KEY"),
aws_session_token: Optional[Secret] = EnvVarSecret(env_vars=["AWS_SESSION_TOKEN"], strict=False).from_env_var(
"AWS_SESSION_TOKEN"
),
aws_region_name: Optional[Secret] = EnvVarSecret(env_vars=["AWS_DEFAULT_REGION"], strict=False).from_env_var(
"AWS_DEFAULT_REGION"
),
aws_profile_name: Optional[str] = EnvVarSecret(env_vars=["AWS_PROFILE"], strict=False).from_env_var(
"AWS_PROFILE"
aws_access_key_id: Optional[Secret] = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False), # noqa: B008
aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008
["AWS_SECRET_ACCESS_KEY"], strict=False
),
aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008
aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008
aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008
max_length: Optional[int] = 100,
**kwargs,
):
Expand Down
14 changes: 2 additions & 12 deletions integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,12 @@ def mock_prompt_handler():
yield mock_prompt_handler


def test_to_dict(mock_auto_tokenizer, mock_boto3_session):
def test_to_dict(mock_auto_tokenizer, mock_boto3_session, set_env_variables):
"""
Test that the to_dict method returns the correct dictionary without aws credentials
"""
generator = AmazonBedrockChatGenerator(
model="anthropic.claude-v2",
aws_access_key_id="some_fake_id",
aws_secret_access_key="some_fake_key",
aws_session_token="some_fake_token",
aws_profile_name="some_fake_profile",
aws_region_name="fake_region",
davidsbatista marked this conversation as resolved.
Show resolved Hide resolved
generation_kwargs={"temperature": 0.7},
streaming_callback=print_streaming_chunk,
)
Expand Down Expand Up @@ -84,18 +79,13 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session):
assert generator.streaming_callback == print_streaming_chunk


def test_default_constructor(mock_auto_tokenizer, mock_boto3_session):
def test_default_constructor(mock_auto_tokenizer, mock_boto3_session, set_env_variables):
"""
Test that the default constructor sets the correct values
"""

layer = AmazonBedrockChatGenerator(
model="anthropic.claude-v2",
aws_access_key_id="some_fake_id",
aws_secret_access_key="some_fake_key",
aws_session_token="some_fake_token",
aws_profile_name="some_fake_profile",
aws_region_name="fake_region",
)

assert layer.model == "anthropic.claude-v2"
Expand Down