Skip to content

Commit

Permalink
Merge branch 'main' into qdrant-scale-score-false
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Jul 2, 2024
2 parents c84fa48 + 06d7776 commit ab02114
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008
aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008
max_length: Optional[int] = 100,
truncate: Optional[bool] = True,
**kwargs,
):
"""
Expand All @@ -87,6 +88,7 @@ def __init__(
:param aws_region_name: The AWS region name.
:param aws_profile_name: The AWS profile name.
:param max_length: The maximum length of the generated text.
:param truncate: Whether to truncate the prompt or not.
:param kwargs: Additional keyword arguments to be passed to the model.
:raises ValueError: If the model name is empty or None.
:raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly or the model is
Expand All @@ -97,11 +99,13 @@ def __init__(
raise ValueError(msg)
self.model = model
self.max_length = max_length
self.truncate = truncate
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name
self.kwargs = kwargs

def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
    # Unwrap a Secret into its plain string value; a missing secret stays None.
    if not secret:
        return None
    return secret.resolve_value()
Expand Down Expand Up @@ -129,6 +133,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
# Truncate prompt if prompt tokens > model_max_length-max_length
# (max_length is the length of the generated text)
# we use GPT2 tokenizer which will likely provide good token count approximation

self.prompt_handler = DefaultPromptHandler(
tokenizer="gpt2",
model_max_length=model_max_length,
Expand Down Expand Up @@ -189,6 +194,9 @@ def invoke(self, *args, **kwargs):
)
raise ValueError(msg)

if self.truncate:
prompt = self._ensure_token_limit(prompt)

body = self.model_adapter.prepare_body(prompt=prompt, **kwargs)
try:
if stream:
Expand Down Expand Up @@ -266,6 +274,8 @@ def to_dict(self) -> Dict[str, Any]:
aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
model=self.model,
max_length=self.max_length,
truncate=self.truncate,
**self.kwargs,
)

@classmethod
Expand Down
47 changes: 43 additions & 4 deletions integrations/amazon_bedrock/tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@ def test_to_dict(mock_boto3_session):
"""
Test that the to_dict method returns the correct dictionary without aws credentials
"""
generator = AmazonBedrockGenerator(
model="anthropic.claude-v2",
max_length=99,
)
generator = AmazonBedrockGenerator(model="anthropic.claude-v2", max_length=99, truncate=False, temperature=10)

expected_dict = {
"type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator",
Expand All @@ -35,6 +32,8 @@ def test_to_dict(mock_boto3_session):
"aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
"model": "anthropic.claude-v2",
"max_length": 99,
"truncate": False,
"temperature": 10,
},
}

Expand Down Expand Up @@ -194,6 +193,46 @@ def test_long_prompt_is_truncated(mock_boto3_session):
assert prompt_after_resize == truncated_prompt_text


def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session):
    """
    Test that a long prompt is not truncated and _ensure_token_limit is not called when truncate is set to False.
    """
    long_prompt_text = "I am a tokenized prompt of length eight"

    # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens)
    max_length_generated_text = 3
    total_model_max_length = 10

    with patch("transformers.AutoTokenizer.from_pretrained", return_value=MagicMock()):
        generator = AmazonBedrockGenerator(
            model="anthropic.claude-v2",
            max_length=max_length_generated_text,
            model_max_length=total_model_max_length,
            truncate=False,
        )

    # Mock the _ensure_token_limit method to track if it is called
    with patch.object(
        generator, "_ensure_token_limit", wraps=generator._ensure_token_limit
    ) as mock_ensure_token_limit:
        # Mock the model adapter and client to avoid an actual AWS invocation
        generator.model_adapter.prepare_body = MagicMock(return_value={})
        generator.client = MagicMock()
        generator.client.invoke_model = MagicMock(
            return_value={"body": MagicMock(read=MagicMock(return_value=b'{"generated_text": "response"}'))}
        )
        generator.model_adapter.get_responses = MagicMock(return_value=["response"])

        # Invoke the generator
        generator.invoke(prompt=long_prompt_text)

        # Ensure _ensure_token_limit was not called
        # (fixed: the original line ended in a stray comma, wrapping the call in a throwaway tuple)
        mock_ensure_token_limit.assert_not_called()

        # Check that the full, untruncated prompt was passed to prepare_body
        generator.model_adapter.prepare_body.assert_called_with(prompt=long_prompt_text)


@pytest.mark.parametrize(
"model, expected_model_adapter",
[
Expand Down
63 changes: 63 additions & 0 deletions integrations/fastembed/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Changelog

## [unreleased]

### ⚙️ Miscellaneous Tasks

- Retry tests to reduce flakiness (#836)
- Update ruff invocation to include check parameter (#853)

### Fix

- Typo on Sparse embedders. The parameter should be "progress_bar" … (#814)

## [integrations/fastembed-v1.1.0] - 2024-05-15

## [integrations/fastembed-v1.0.0] - 2024-05-06

## [integrations/fastembed-v0.1.0] - 2024-04-10

### 🚀 Features

- *(FastEmbed)* Support for SPLADE Sparse Embedder (#579)

### 📚 Documentation

- Disable-class-def (#556)

## [integrations/fastembed-v0.0.6] - 2024-03-07

### 📚 Documentation

- Review and normalize docstrings - `integrations.fastembed` (#519)
- Small consistency improvements (#536)

## [integrations/fastembed-v0.0.5] - 2024-02-20

### 🐛 Bug Fixes

- Fix order of API docs (#447)

This PR will also push the docs to Readme

### 📚 Documentation

- Update category slug (#442)

## [integrations/fastembed-v0.0.4] - 2024-02-16

## [integrations/fastembed-v0.0.3] - 2024-02-12

### 🐛 Bug Fixes

- From numpy float to float (#391)

### 📚 Documentation

- Update paths and titles (#397)

## [integrations/fastembed-v0.0.2] - 2024-02-11

## [integrations/fastembed-v0.0.1] - 2024-02-10

<!-- generated by git-cliff -->
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ def _handle_duplicate_documents(
:param documents: A list of Haystack Document objects.
:param index: name of the index
:param duplicate_documents: The duplicate policy to use when writing documents.
:param policy: The duplicate policy to use when writing documents.
:returns: A list of Haystack Document objects.
"""

Expand Down

0 comments on commit ab02114

Please sign in to comment.