diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index b93ba1d3f..32d1de629 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -75,6 +75,7 @@ def __init__( aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008 aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008 max_length: Optional[int] = 100, + truncate: Optional[bool] = True, **kwargs, ): """ @@ -87,6 +88,7 @@ def __init__( :param aws_region_name: The AWS region name. :param aws_profile_name: The AWS profile name. :param max_length: The maximum length of the generated text. + :param truncate: Whether to truncate the prompt or not. :param kwargs: Additional keyword arguments to be passed to the model. :raises ValueError: If the model name is empty or None. 
:raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly or the model is @@ -97,11 +99,13 @@ def __init__( raise ValueError(msg) self.model = model self.max_length = max_length + self.truncate = truncate self.aws_access_key_id = aws_access_key_id self.aws_secret_access_key = aws_secret_access_key self.aws_session_token = aws_session_token self.aws_region_name = aws_region_name self.aws_profile_name = aws_profile_name + self.kwargs = kwargs def resolve_secret(secret: Optional[Secret]) -> Optional[str]: return secret.resolve_value() if secret else None @@ -129,6 +133,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: # Truncate prompt if prompt tokens > model_max_length-max_length # (max_length is the length of the generated text) # we use GPT2 tokenizer which will likely provide good token count approximation + self.prompt_handler = DefaultPromptHandler( tokenizer="gpt2", model_max_length=model_max_length, @@ -189,6 +194,9 @@ def invoke(self, *args, **kwargs): ) raise ValueError(msg) + if self.truncate: + prompt = self._ensure_token_limit(prompt) + body = self.model_adapter.prepare_body(prompt=prompt, **kwargs) try: if stream: @@ -266,6 +274,8 @@ def to_dict(self) -> Dict[str, Any]: aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None, model=self.model, max_length=self.max_length, + truncate=self.truncate, + **self.kwargs, ) @classmethod diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index e603c8853..65463caae 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -20,10 +20,7 @@ def test_to_dict(mock_boto3_session): """ Test that the to_dict method returns the correct dictionary without aws credentials """ - generator = AmazonBedrockGenerator( - model="anthropic.claude-v2", - max_length=99, - ) + generator = 
AmazonBedrockGenerator(model="anthropic.claude-v2", max_length=99, truncate=False, temperature=10) expected_dict = { "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator", @@ -35,6 +32,8 @@ def test_to_dict(mock_boto3_session): "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, "model": "anthropic.claude-v2", "max_length": 99, + "truncate": False, + "temperature": 10, }, } @@ -194,6 +193,46 @@ def test_long_prompt_is_truncated(mock_boto3_session): assert prompt_after_resize == truncated_prompt_text +def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): + """ + Test that a long prompt is not truncated and _ensure_token_limit is not called when truncate is set to False + """ + long_prompt_text = "I am a tokenized prompt of length eight" + + # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens) + max_length_generated_text = 3 + total_model_max_length = 10 + + with patch("transformers.AutoTokenizer.from_pretrained", return_value=MagicMock()): + generator = AmazonBedrockGenerator( + model="anthropic.claude-v2", + max_length=max_length_generated_text, + model_max_length=total_model_max_length, + truncate=False, + ) + + # Mock the _ensure_token_limit method to track if it is called + with patch.object( + generator, "_ensure_token_limit", wraps=generator._ensure_token_limit + ) as mock_ensure_token_limit: + # Mock the model adapter to avoid actual invocation + generator.model_adapter.prepare_body = MagicMock(return_value={}) + generator.client = MagicMock() + generator.client.invoke_model = MagicMock( + return_value={"body": MagicMock(read=MagicMock(return_value=b'{"generated_text": "response"}'))} + ) + generator.model_adapter.get_responses = MagicMock(return_value=["response"]) + + # Invoke the generator + generator.invoke(prompt=long_prompt_text) + + # Ensure _ensure_token_limit was not called + 
mock_ensure_token_limit.assert_not_called()
+
+    # Check the prompt passed to prepare_body
+    generator.model_adapter.prepare_body.assert_called_with(prompt=long_prompt_text)
+
+
 @pytest.mark.parametrize(
     "model, expected_model_adapter",
     [
diff --git a/integrations/fastembed/CHANGELOG.md b/integrations/fastembed/CHANGELOG.md
new file mode 100644
index 000000000..9ae3da929
--- /dev/null
+++ b/integrations/fastembed/CHANGELOG.md
@@ -0,0 +1,63 @@
+# Changelog
+
+## [unreleased]
+
+### ⚙️ Miscellaneous Tasks
+
+- Retry tests to reduce flakyness (#836)
+- Update ruff invocation to include check parameter (#853)
+
+### Fix
+
+- Typo on Sparse embedders. The parameter should be "progress_bar" … (#814)
+
+## [integrations/fastembed-v1.1.0] - 2024-05-15
+
+## [integrations/fastembed-v1.0.0] - 2024-05-06
+
+## [integrations/fastembed-v0.1.0] - 2024-04-10
+
+### 🚀 Features
+
+- *(FastEmbed)* Support for SPLADE Sparse Embedder (#579)
+
+### 📚 Documentation
+
+- Disable-class-def (#556)
+
+## [integrations/fastembed-v0.0.6] - 2024-03-07
+
+### 📚 Documentation
+
+- Review and normalize docstrings - `integrations.fastembed` (#519)
+- Small consistency improvements (#536)
+
+## [integrations/fastembed-v0.0.5] - 2024-02-20
+
+### 🐛 Bug Fixes
+
+- Fix order of API docs (#447)
+
+This PR will also push the docs to Readme
+
+### 📚 Documentation
+
+- Update category slug (#442)
+
+## [integrations/fastembed-v0.0.4] - 2024-02-16
+
+## [integrations/fastembed-v0.0.3] - 2024-02-12
+
+### 🐛 Bug Fixes
+
+- From numpy float to float (#391)
+
+### 📚 Documentation
+
+- Update paths and titles (#397)
+
+## [integrations/fastembed-v0.0.2] - 2024-02-11
+
+## [integrations/fastembed-v0.0.1] - 2024-02-10
+
+
diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py index 910667db0..ac00be6d3 --- 
a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py @@ -874,7 +874,7 @@ def _handle_duplicate_documents( :param documents: A list of Haystack Document objects. :param index: name of the index - :param duplicate_documents: The duplicate policy to use when writing documents. + :param policy: The duplicate policy to use when writing documents. :returns: A list of Haystack Document objects. """