Skip to content

Commit

Permalink
fix: avoid bedrock read timeout (add boto3_config param) (#1135)
Browse files Browse the repository at this point in the history
* fix: avoid bedrock read timeout

* fix lint

* fix test

* add from_dict test
  • Loading branch information
tstadel authored Oct 16, 2024
1 parent a04dae9 commit 9a3c2e0
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
from typing import Any, Callable, ClassVar, Dict, List, Optional, Type

from botocore.config import Config
from botocore.exceptions import ClientError
from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import StreamingChunk
Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(
max_length: Optional[int] = 100,
truncate: Optional[bool] = True,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
boto3_config: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""
Expand All @@ -102,6 +104,7 @@ def __init__(
:param truncate: Whether to truncate the prompt or not.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
:param boto3_config: The configuration for the boto3 client.
:param kwargs: Additional keyword arguments to be passed to the model.
These arguments are specific to the model. You can find them in the model's documentation.
:raises ValueError: If the model name is empty or None.
Expand All @@ -120,6 +123,7 @@ def __init__(
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name
self.streaming_callback = streaming_callback
self.boto3_config = boto3_config
self.kwargs = kwargs

def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
Expand All @@ -133,7 +137,10 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
aws_region_name=resolve_secret(aws_region_name),
aws_profile_name=resolve_secret(aws_profile_name),
)
self.client = session.client("bedrock-runtime")
config: Optional[Config] = None
if self.boto3_config:
config = Config(**self.boto3_config)
self.client = session.client("bedrock-runtime", config=config)
except Exception as exception:
msg = (
"Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
Expand Down Expand Up @@ -273,6 +280,7 @@ def to_dict(self) -> Dict[str, Any]:
max_length=self.max_length,
truncate=self.truncate,
streaming_callback=callback_name,
boto3_config=self.boto3_config,
**self.kwargs,
)

Expand Down
5 changes: 5 additions & 0 deletions integrations/amazon_bedrock/tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def test_to_dict(mock_boto3_session):
"truncate": False,
"temperature": 10,
"streaming_callback": None,
"boto3_config": None,
},
}

Expand All @@ -57,12 +58,16 @@ def test_from_dict(mock_boto3_session):
"aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
"model": "anthropic.claude-v2",
"max_length": 99,
"boto3_config": {
"read_timeout": 1000,
},
},
}
)

assert generator.max_length == 99
assert generator.model == "anthropic.claude-v2"
assert generator.boto3_config == {"read_timeout": 1000}


def test_default_constructor(mock_boto3_session, set_env_variables):
Expand Down

0 comments on commit 9a3c2e0

Please sign in to comment.