Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: avoid bedrock read timeout (add boto3_config param) #1135

Merged
merged 4 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
from typing import Any, Callable, ClassVar, Dict, List, Optional, Type

from botocore.config import Config
from botocore.exceptions import ClientError
from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import StreamingChunk
Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(
max_length: Optional[int] = 100,
truncate: Optional[bool] = True,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
boto3_config: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""
Expand All @@ -102,6 +104,7 @@ def __init__(
:param truncate: Whether to truncate the prompt or not.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
:param boto3_config: The configuration for the boto3 client.
:param kwargs: Additional keyword arguments to be passed to the model.
These arguments are specific to the model. You can find them in the model's documentation.
:raises ValueError: If the model name is empty or None.
Expand All @@ -120,6 +123,7 @@ def __init__(
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name
self.streaming_callback = streaming_callback
self.boto3_config = boto3_config
self.kwargs = kwargs

def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
Expand All @@ -133,7 +137,10 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
aws_region_name=resolve_secret(aws_region_name),
aws_profile_name=resolve_secret(aws_profile_name),
)
self.client = session.client("bedrock-runtime")
config: Optional[Config] = None
if self.boto3_config:
config = Config(**self.boto3_config)
self.client = session.client("bedrock-runtime", config=config)
except Exception as exception:
msg = (
"Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
Expand Down Expand Up @@ -273,6 +280,7 @@ def to_dict(self) -> Dict[str, Any]:
max_length=self.max_length,
truncate=self.truncate,
streaming_callback=callback_name,
boto3_config=self.boto3_config,
**self.kwargs,
)

Expand Down
5 changes: 5 additions & 0 deletions integrations/amazon_bedrock/tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def test_to_dict(mock_boto3_session):
"truncate": False,
"temperature": 10,
"streaming_callback": None,
"boto3_config": None,
},
}

Expand All @@ -57,12 +58,16 @@ def test_from_dict(mock_boto3_session):
"aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
"model": "anthropic.claude-v2",
"max_length": 99,
"boto3_config": {
"read_timeout": 1000,
},
},
}
)

assert generator.max_length == 99
assert generator.model == "anthropic.claude-v2"
assert generator.boto3_config == {"read_timeout": 1000}


def test_default_constructor(mock_boto3_session, set_env_variables):
Expand Down