Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

fix: Allow passing boto3 config to all AWS Bedrock classes #1166

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from typing import Any, Dict, List, Literal, Optional

from botocore.config import Config
from botocore.exceptions import ClientError
from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import Document
Expand Down Expand Up @@ -73,6 +74,7 @@ def __init__(
progress_bar: bool = True,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
boto3_config: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""
Expand All @@ -98,6 +100,7 @@ def __init__(
to keep the logs clean.
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text.
:param embedding_separator: Separator used to concatenate the meta fields to the Document text.
:param boto3_config: Dictionary of keyword arguments used to build the `botocore.config.Config` for the boto3 client.
:param kwargs: Additional parameters to pass for model inference. For example, `input_type` and `truncate` for
Cohere models.
:raises ValueError: If the model is not supported.
Expand All @@ -110,6 +113,19 @@ def __init__(
)
raise ValueError(msg)

self.model = model
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name
self.batch_size = batch_size
self.progress_bar = progress_bar
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
self.boto3_config = boto3_config
self.kwargs = kwargs

def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
return secret.resolve_value() if secret else None

Expand All @@ -121,26 +137,17 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
aws_region_name=resolve_secret(aws_region_name),
aws_profile_name=resolve_secret(aws_profile_name),
)
self._client = session.client("bedrock-runtime")
config: Optional[Config] = None
if self.boto3_config:
config = Config(**self.boto3_config)
self._client = session.client("bedrock-runtime", config=config)
except Exception as exception:
msg = (
"Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
"See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
)
raise AmazonBedrockConfigurationError(msg) from exception

self.model = model
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name
self.batch_size = batch_size
self.progress_bar = progress_bar
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
self.kwargs = kwargs

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
"""
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
Expand Down Expand Up @@ -269,6 +276,7 @@ def to_dict(self) -> Dict[str, Any]:
progress_bar=self.progress_bar,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
boto3_config=self.boto3_config,
**self.kwargs,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from typing import Any, Dict, List, Literal, Optional

from botocore.config import Config
from botocore.exceptions import ClientError
from haystack import component, default_from_dict, default_to_dict
from haystack.utils.auth import Secret, deserialize_secrets_inplace
Expand Down Expand Up @@ -62,6 +63,7 @@ def __init__(
aws_session_token: Optional[Secret] = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False), # noqa: B008
aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008
aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008
boto3_config: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""
Expand All @@ -81,6 +83,7 @@ def __init__(
:param aws_session_token: AWS session token.
:param aws_region_name: AWS region name.
:param aws_profile_name: AWS profile name.
:param boto3_config: Dictionary of keyword arguments used to build the `botocore.config.Config` for the boto3 client.
:param kwargs: Additional parameters to pass for model inference. For example, `input_type` and `truncate` for
Cohere models.
:raises ValueError: If the model is not supported.
Expand All @@ -92,6 +95,15 @@ def __init__(
)
raise ValueError(msg)

self.model = model
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name
self.boto3_config = boto3_config
self.kwargs = kwargs

def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
return secret.resolve_value() if secret else None

Expand All @@ -103,22 +115,17 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
aws_region_name=resolve_secret(aws_region_name),
aws_profile_name=resolve_secret(aws_profile_name),
)
self._client = session.client("bedrock-runtime")
config: Optional[Config] = None
if self.boto3_config:
config = Config(**self.boto3_config)
self._client = session.client("bedrock-runtime", config=config)
except Exception as exception:
msg = (
"Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
"See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
)
raise AmazonBedrockConfigurationError(msg) from exception

self.model = model
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name
self.kwargs = kwargs

@component.output_types(embedding=List[float])
def run(self, text: str):
"""Embeds the input text using the Amazon Bedrock model.
Expand Down Expand Up @@ -185,6 +192,7 @@ def to_dict(self) -> Dict[str, Any]:
aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None,
aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
model=self.model,
boto3_config=self.boto3_config,
**self.kwargs,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
from typing import Any, Callable, ClassVar, Dict, List, Optional, Type

from botocore.config import Config
from botocore.exceptions import ClientError
from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import ChatMessage, StreamingChunk
Expand Down Expand Up @@ -77,6 +78,7 @@ def __init__(
stop_words: Optional[List[str]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
truncate: Optional[bool] = True,
boto3_config: Optional[Dict[str, Any]] = None,
):
"""
Initializes the `AmazonBedrockChatGenerator` with the provided parameters. The parameters are passed to the
Expand Down Expand Up @@ -110,6 +112,11 @@ def __init__(
[StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk) object and
switches the streaming mode on.
:param truncate: Whether to truncate the prompt messages or not.
:param boto3_config: Dictionary of keyword arguments used to build the `botocore.config.Config` for the boto3 client.

:raises ValueError: If the model name is empty or None.
:raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly or the model is
not supported.
"""
if not model:
msg = "'model' cannot be None or empty string"
Expand All @@ -120,7 +127,10 @@ def __init__(
self.aws_session_token = aws_session_token
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name
self.stop_words = stop_words or []
self.streaming_callback = streaming_callback
self.truncate = truncate
self.boto3_config = boto3_config

# get the model adapter for the given model
model_adapter_cls = self.get_model_adapter(model=model)
Expand All @@ -141,17 +151,17 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
aws_region_name=resolve_secret(aws_region_name),
aws_profile_name=resolve_secret(aws_profile_name),
)
self.client = session.client("bedrock-runtime")
config: Optional[Config] = None
if self.boto3_config:
config = Config(**self.boto3_config)
self.client = session.client("bedrock-runtime", config=config)
except Exception as exception:
msg = (
"Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
"See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
)
raise AmazonBedrockConfigurationError(msg) from exception

self.stop_words = stop_words or []
self.streaming_callback = streaming_callback

@component.output_types(replies=List[ChatMessage])
def run(
self,
Expand Down Expand Up @@ -256,6 +266,7 @@ def to_dict(self) -> Dict[str, Any]:
generation_kwargs=self.model_adapter.generation_kwargs,
streaming_callback=callback_name,
truncate=self.truncate,
boto3_config=self.boto3_config,
)

@classmethod
Expand Down
28 changes: 25 additions & 3 deletions integrations/amazon_bedrock/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
import os
from typing import Optional, Type
from typing import Any, Dict, Optional, Type
from unittest.mock import MagicMock, patch

import pytest
Expand All @@ -26,14 +26,24 @@
]


def test_to_dict(mock_boto3_session):
@pytest.mark.parametrize(
"boto3_config",
[
None,
{
"read_timeout": 1000,
},
],
)
def test_to_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]]):
"""
Test that the to_dict method returns the correct dictionary without AWS credentials
"""
generator = AmazonBedrockChatGenerator(
model="anthropic.claude-v2",
generation_kwargs={"temperature": 0.7},
streaming_callback=print_streaming_chunk,
boto3_config=boto3_config,
)
expected_dict = {
"type": KLASS,
Expand All @@ -48,13 +58,23 @@ def test_to_dict(mock_boto3_session):
"stop_words": [],
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"truncate": True,
"boto3_config": boto3_config,
},
}

assert generator.to_dict() == expected_dict


def test_from_dict(mock_boto3_session):
@pytest.mark.parametrize(
"boto3_config",
[
None,
{
"read_timeout": 1000,
},
],
)
def test_from_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]]):
"""
Test that the from_dict method returns the correct object
"""
Expand All @@ -71,12 +91,14 @@ def test_from_dict(mock_boto3_session):
"generation_kwargs": {"temperature": 0.7},
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"truncate": True,
"boto3_config": boto3_config,
},
}
)
assert generator.model == "anthropic.claude-v2"
assert generator.model_adapter.generation_kwargs == {"temperature": 0.7}
assert generator.streaming_callback == print_streaming_chunk
assert generator.boto3_config == boto3_config


def test_default_constructor(mock_boto3_session, set_env_variables):
Expand Down
27 changes: 25 additions & 2 deletions integrations/amazon_bedrock/tests/test_document_embedder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import io
from typing import Any, Dict, Optional
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -66,10 +67,20 @@ def test_connection_error(self, mock_boto3_session):
input_type="fake_input_type",
)

def test_to_dict(self, mock_boto3_session):
@pytest.mark.parametrize(
"boto3_config",
[
None,
{
"read_timeout": 1000,
},
],
)
def test_to_dict(self, mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]]):
embedder = AmazonBedrockDocumentEmbedder(
model="cohere.embed-english-v3",
input_type="search_document",
boto3_config=boto3_config,
)

expected_dict = {
Expand All @@ -86,12 +97,22 @@ def test_to_dict(self, mock_boto3_session):
"progress_bar": True,
"meta_fields_to_embed": [],
"embedding_separator": "\n",
"boto3_config": boto3_config,
},
}

assert embedder.to_dict() == expected_dict

def test_from_dict(self, mock_boto3_session):
@pytest.mark.parametrize(
"boto3_config",
[
None,
{
"read_timeout": 1000,
},
],
)
def test_from_dict(self, mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]]):
data = {
"type": TYPE,
"init_parameters": {
Expand All @@ -106,6 +127,7 @@ def test_from_dict(self, mock_boto3_session):
"progress_bar": True,
"meta_fields_to_embed": [],
"embedding_separator": "\n",
"boto3_config": boto3_config,
},
}

Expand All @@ -117,6 +139,7 @@ def test_from_dict(self, mock_boto3_session):
assert embedder.progress_bar
assert embedder.meta_fields_to_embed == []
assert embedder.embedding_separator == "\n"
assert embedder.boto3_config == boto3_config

def test_init_invalid_model(self):
with pytest.raises(ValueError):
Expand Down
Loading
Loading