diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 1edde3526..193332009 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -3,6 +3,7 @@ import re from typing import Any, Callable, ClassVar, Dict, List, Optional, Type +from botocore.config import Config from botocore.exceptions import ClientError from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import StreamingChunk @@ -87,6 +88,7 @@ def __init__( max_length: Optional[int] = 100, truncate: Optional[bool] = True, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + boto3_config: Optional[Dict[str, Any]] = None, **kwargs, ): """ @@ -102,6 +104,7 @@ def __init__( :param truncate: Whether to truncate the prompt or not. :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. + :param boto3_config: The configuration for the boto3 client. :param kwargs: Additional keyword arguments to be passed to the model. These arguments are specific to the model. You can find them in the model's documentation. :raises ValueError: If the model name is empty or None. @@ -120,6 +123,7 @@ def __init__( self.aws_region_name = aws_region_name self.aws_profile_name = aws_profile_name self.streaming_callback = streaming_callback + self.boto3_config = boto3_config self.kwargs = kwargs def resolve_secret(secret: Optional[Secret]) -> Optional[str]: @@ -133,7 +137,10 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: aws_region_name=resolve_secret(aws_region_name), aws_profile_name=resolve_secret(aws_profile_name), ) - self.client = session.client("bedrock-runtime") + config: Optional[Config] = None + if self.boto3_config: + config = Config(**self.boto3_config) + self.client = session.client("bedrock-runtime", config=config) except Exception as exception: msg = ( "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. " @@ -273,6 +280,7 @@ def to_dict(self) -> Dict[str, Any]: max_length=self.max_length, truncate=self.truncate, streaming_callback=callback_name, + boto3_config=self.boto3_config, **self.kwargs, ) diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index 61ae9d6b4..2ccd5a3fa 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -36,6 +36,7 @@ def test_to_dict(mock_boto3_session): "truncate": False, "temperature": 10, "streaming_callback": None, + "boto3_config": None, }, } @@ -57,12 +58,16 @@ def test_from_dict(mock_boto3_session): "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, "model": "anthropic.claude-v2", "max_length": 99, + "boto3_config": { + "read_timeout": 1000, + }, }, } ) assert generator.max_length == 99 assert generator.model == "anthropic.claude-v2" + assert generator.boto3_config == {"read_timeout": 1000} def test_default_constructor(mock_boto3_session, set_env_variables):