From 252d27df70f90c933fd9a870b5293da04370f29d Mon Sep 17 00:00:00 2001 From: Trivan Menezes <47679108+ttmenezes@users.noreply.github.com> Date: Mon, 11 Nov 2024 15:16:07 -0800 Subject: [PATCH 01/16] new azure retrievers --- .../retrievers/azure_ai_search/__init__.py | 4 +- .../azure_ai_search/bm25_retriever.py | 116 +++++++++++++++++ .../azure_ai_search/hybrid_retriever.py | 120 ++++++++++++++++++ .../azure_ai_search/document_store.py | 52 ++++++++ 4 files changed, 291 insertions(+), 1 deletion(-) create mode 100644 integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py create mode 100644 integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py index eb75ffa6c..eebe990f3 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py @@ -1,3 +1,5 @@ from .embedding_retriever import AzureAISearchEmbeddingRetriever +from .bm25_retriever import AzureAISearchBM25Retriever +from .hybrid_retriever import AzureAISearchHybridRetriever -__all__ = ["AzureAISearchEmbeddingRetriever"] +__all__ = ["AzureAISearchEmbeddingRetriever", "AzureAISearchBM25Retriever", "AzureAISearchHybridRetriever"] diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py new file mode 100644 index 000000000..65e273b73 --- /dev/null +++ 
b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py @@ -0,0 +1,116 @@ +import logging +from typing import Any, Dict, List, Optional, Union + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters + +logger = logging.getLogger(__name__) + + +@component +class AzureAISearchBM25Retriever: + """ + Retrieves documents from the AzureAISearchDocumentStore using BM25 retrieval. + Must be connected to the AzureAISearchDocumentStore to run. + + """ + + def __init__( + self, + *, + document_store: AzureAISearchDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + ): + """ + Create the AzureAISearchBM25Retriever component. + + :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever. + :param filters: Filters applied when fetching documents from the Document Store. + Filters are applied during the BM25 search to ensure the Retriever returns + `top_k` matching documents. + :param top_k: Maximum number of documents to return. + :filter_policy: Policy to determine how filters are applied. Possible options: + + """ + self._filters = filters or {} + self._top_k = top_k + self._document_store = document_store + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + + if not isinstance(document_store, AzureAISearchDocumentStore): + message = "document_store must be an instance of AzureAISearchDocumentStore" + raise Exception(message) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. 
+ + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self._filters, + top_k=self._top_k, + document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchBM25Retriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. + """ + data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if "filter_policy" in data["init_parameters"]: + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): + """Retrieve documents from the AzureAISearchDocumentStore. + + :param query: Text of the query. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. + :param top_k: the maximum number of documents to retrieve. + :returns: a dictionary with the following keys: + - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. 
+ """ + + top_k = top_k or self._top_k + if filters is not None: + applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) + normalized_filters = normalize_filters(applied_filters) + else: + normalized_filters = "" + + try: + docs = self._document_store._bm25_retrieval( + query=query, + filters=normalized_filters, + top_k=top_k, + ) + except Exception as e: + raise e + + return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py new file mode 100644 index 000000000..eb28c4e73 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py @@ -0,0 +1,120 @@ +import logging +from typing import Any, Dict, List, Optional, Union + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters + +logger = logging.getLogger(__name__) + + +@component +class AzureAISearchHybridRetriever: + """ + Retrieves documents from the AzureAISearchDocumentStore using a hybrid (vector + BM25) retrieval. + Must be connected to the AzureAISearchDocumentStore to run. + + """ + + def __init__( + self, + *, + document_store: AzureAISearchDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + ): + """ + Create the AzureAISearchHybridRetriever component. + + :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever. + :param filters: Filters applied when fetching documents from the Document Store. 
+ Filters are applied during the hybrid search to ensure the Retriever returns + `top_k` matching documents. + :param top_k: Maximum number of documents to return. + :filter_policy: Policy to determine how filters are applied. Possible options: + + """ + self._filters = filters or {} + self._top_k = top_k + self._document_store = document_store + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + + if not isinstance(document_store, AzureAISearchDocumentStore): + message = "document_store must be an instance of AzureAISearchDocumentStore" + raise Exception(message) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self._filters, + top_k=self._top_k, + document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchHybridRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. + """ + data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if "filter_policy" in data["init_parameters"]: + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, query: str, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None + ): + """Retrieve documents from the AzureAISearchDocumentStore. + + :param query: Text of the query. 
+ :param query_embedding: floats representing the query embedding + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. + :param top_k: the maximum number of documents to retrieve. + :returns: a dictionary with the following keys: + - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. + """ + + top_k = top_k or self._top_k + if filters is not None: + applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) + normalized_filters = normalize_filters(applied_filters) + else: + normalized_filters = "" + + try: + docs = self._document_store._hybrid_retrieval( + query=query, + query_embedding=query_embedding, + filters=normalized_filters, + top_k=top_k, + ) + except Exception as e: + raise e + + return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 0b59b6e37..4e3b3b44a 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -438,3 +438,55 @@ def _embedding_retrieval( result = self.client.search(search_text=None, vector_queries=[vector_query], select=fields, filter=filters) azure_docs = list(result) return self._convert_search_result_to_documents(azure_docs) + + def _bm25_retrieval( + self, + query: str, + top_k: int = 10, + fields: Optional[List[str]] = None, + filters: Optional[Dict[str, Any]] = None, + ) -> List[Document]: + """ + Retrieves documents that are most similar to `query`, using the BM25 algorithm + + This method is not meant to be part of the public interface of + 
`AzureAISearchDocumentStore` nor called directly. + `AzureAISearchBM25Retriever` uses this method directly and is the public interface for it. + + :param query: Text of the query. + :param filters: Filters applied to the retrieved Documents. Defaults to None. + Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. + :param top_k: Maximum number of Documents to return, defaults to 10 + + :raises ValueError: If `query` is an empty string + :returns: List of Document that are most similar to `query` + """ + + if query is None: + msg = "query must not be None" + raise ValueError(msg) + + result = self.client.search(search_text=query, select=fields, filter=filters, top=top_k, query_type="simple") + azure_docs = list(result) + return self._convert_search_result_to_documents(azure_docs) + + def _hybrid_retrieval( + self, + query: str, + query_embedding: List[float], + top_k: int = 10, + fields: Optional[List[str]] = None, + filters: Optional[Dict[str, Any]] = None, + ) -> List[Document]: + + if query is None: + msg = "query must not be None" + raise ValueError(msg) + if not query_embedding: + msg = "query_embedding must be a non-empty list of floats" + raise ValueError(msg) + + vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding") + result = self.client.search(search_text=query, vector_queries=[vector_query], select=fields, filter=filters, top=top_k, query_type="simple") + azure_docs = list(result) + return self._convert_search_result_to_documents(azure_docs) \ No newline at end of file From 0f38ea8186f1316e9f121b35c2705ed3a6c69882 Mon Sep 17 00:00:00 2001 From: Trivan Menezes <47679108+ttmenezes@users.noreply.github.com> Date: Mon, 11 Nov 2024 15:34:18 -0800 Subject: [PATCH 02/16] fix styling --- .../retrievers/azure_ai_search/__init__.py | 4 ++-- .../retrievers/azure_ai_search/hybrid_retriever.py | 6 +++++- .../azure_ai_search/document_store.py | 12 +++++++++--- 3 files 
changed, 16 insertions(+), 6 deletions(-) diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py index eebe990f3..56dc30db4 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py @@ -1,5 +1,5 @@ -from .embedding_retriever import AzureAISearchEmbeddingRetriever from .bm25_retriever import AzureAISearchBM25Retriever +from .embedding_retriever import AzureAISearchEmbeddingRetriever from .hybrid_retriever import AzureAISearchHybridRetriever -__all__ = ["AzureAISearchEmbeddingRetriever", "AzureAISearchBM25Retriever", "AzureAISearchHybridRetriever"] +__all__ = ["AzureAISearchBM25Retriever", "AzureAISearchEmbeddingRetriever", "AzureAISearchHybridRetriever"] diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py index eb28c4e73..77cc0c586 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py @@ -86,7 +86,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchHybridRetriever": @component.output_types(documents=List[Document]) def run( - self, query: str, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None + self, + query: str, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None ): """Retrieve documents from the AzureAISearchDocumentStore. 
diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 4e3b3b44a..737a0bd11 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -438,7 +438,7 @@ def _embedding_retrieval( result = self.client.search(search_text=None, vector_queries=[vector_query], select=fields, filter=filters) azure_docs = list(result) return self._convert_search_result_to_documents(azure_docs) - + def _bm25_retrieval( self, query: str, @@ -487,6 +487,12 @@ def _hybrid_retrieval( raise ValueError(msg) vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding") - result = self.client.search(search_text=query, vector_queries=[vector_query], select=fields, filter=filters, top=top_k, query_type="simple") + result = self.client.search( + search_text=query, + vector_queries=[vector_query], + select=fields, filter=filters, + top=top_k, query_type="simple" + ) azure_docs = list(result) - return self._convert_search_result_to_documents(azure_docs) \ No newline at end of file + return self._convert_search_result_to_documents(azure_docs) + \ No newline at end of file From 41672f7a81e685489dc5942cef8afb5b0b674fcc Mon Sep 17 00:00:00 2001 From: Trivan Menezes <47679108+ttmenezes@users.noreply.github.com> Date: Mon, 11 Nov 2024 15:39:16 -0800 Subject: [PATCH 03/16] fix whitespace --- .../retrievers/azure_ai_search/hybrid_retriever.py | 8 ++++---- .../document_stores/azure_ai_search/document_store.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py 
b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py index 77cc0c586..ace032a8d 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py @@ -86,10 +86,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchHybridRetriever": @component.output_types(documents=List[Document]) def run( - self, - query: str, - query_embedding: List[float], - filters: Optional[Dict[str, Any]] = None, + self, + query: str, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None ): """Retrieve documents from the AzureAISearchDocumentStore. diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 737a0bd11..3f4427909 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -488,9 +488,9 @@ def _hybrid_retrieval( vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding") result = self.client.search( - search_text=query, - vector_queries=[vector_query], - select=fields, filter=filters, + search_text=query, + vector_queries=[vector_query], + select=fields, filter=filters, top=top_k, query_type="simple" ) azure_docs = list(result) From 914d27fbd77d94eb7a5151d2c3d8557800d85bcd Mon Sep 17 00:00:00 2001 From: Trivan Menezes <47679108+ttmenezes@users.noreply.github.com> Date: Mon, 11 Nov 2024 15:40:56 -0800 Subject: [PATCH 04/16] fix whitespace --- .../document_stores/azure_ai_search/document_store.py | 1 - 1 
file changed, 1 deletion(-) diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 3f4427909..05d92f523 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -495,4 +495,3 @@ def _hybrid_retrieval( ) azure_docs = list(result) return self._convert_search_result_to_documents(azure_docs) - \ No newline at end of file From ee41bf0e3660457a2ef22f8919f6f4fdf7d7f68d Mon Sep 17 00:00:00 2001 From: Trivan Menezes <47679108+ttmenezes@users.noreply.github.com> Date: Mon, 11 Nov 2024 15:47:10 -0800 Subject: [PATCH 05/16] fix linting --- .../retrievers/azure_ai_search/hybrid_retriever.py | 2 +- .../document_stores/azure_ai_search/document_store.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py index ace032a8d..fbe7752aa 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py @@ -90,7 +90,7 @@ def run( query: str, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, - top_k: Optional[int] = None + top_k: Optional[int] = None, ): """Retrieve documents from the AzureAISearchDocumentStore. 
diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 05d92f523..68a2db22e 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -490,8 +490,10 @@ def _hybrid_retrieval( result = self.client.search( search_text=query, vector_queries=[vector_query], - select=fields, filter=filters, - top=top_k, query_type="simple" + select=fields, + filter=filters, + top=top_k, + query_type="simple", ) azure_docs = list(result) return self._convert_search_result_to_documents(azure_docs) From 1cb63f489b4a5be63b1140f37bf2e4e3abe7896e Mon Sep 17 00:00:00 2001 From: Trivan Menezes <47679108+ttmenezes@users.noreply.github.com> Date: Wed, 13 Nov 2024 11:49:41 -0800 Subject: [PATCH 06/16] tests for bm25 and hybrid retrievers --- .../tests/test_bm25_retriever.py | 128 ++++++++++++++++ .../tests/test_hybrid_retriever.py | 145 ++++++++++++++++++ 2 files changed, 273 insertions(+) create mode 100644 integrations/azure_ai_search/tests/test_bm25_retriever.py create mode 100644 integrations/azure_ai_search/tests/test_hybrid_retriever.py diff --git a/integrations/azure_ai_search/tests/test_bm25_retriever.py b/integrations/azure_ai_search/tests/test_bm25_retriever.py new file mode 100644 index 000000000..d0c6d0da9 --- /dev/null +++ b/integrations/azure_ai_search/tests/test_bm25_retriever.py @@ -0,0 +1,128 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from typing import List +from unittest.mock import Mock + +import pytest +from azure.core.exceptions import HttpResponseError +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from 
numpy.random import rand # type: ignore + +from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchBM25Retriever +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore + + +def test_init_default(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + retriever = AzureAISearchBM25Retriever(document_store=mock_store) + assert retriever._document_store == mock_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + retriever = AzureAISearchBM25Retriever(document_store=mock_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + AzureAISearchBM25Retriever(document_store=mock_store, filter_policy="unknown") + + +def test_to_dict(): + document_store = AzureAISearchDocumentStore(hosts="some fake host") + retriever = AzureAISearchBM25Retriever(document_store=document_store) + res = retriever.to_dict() + assert res == { + "type": "haystack_integrations.components.retrievers.azure_ai_search.bm25_retriever.AzureAISearchBM25Retriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": { + "profiles": [ + {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"} + ], + "algorithms": [ + { + "name": "cosine-algorithm-config", + "kind": "hnsw", + "parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"}, + 
} + ], + }, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + + +def test_from_dict(): + data = { + "type": "haystack_integrations.components.retrievers.azure_ai_search.bm25_retriever.AzureAISearchBM25Retriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "metadata_fields": None, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + retriever = AzureAISearchBM25Retriever.from_dict(data) + assert isinstance(retriever._document_store, AzureAISearchDocumentStore) + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +@pytest.mark.integration +class TestRetriever: + + def test_run(self, document_store: AzureAISearchDocumentStore): + docs = [Document(id="1", content="Test document")] + document_store.write_documents(docs) + retriever = AzureAISearchBM25Retriever(document_store=document_store) + res = retriever.run(query="Test document") + assert res["documents"] == docs + + def test_document_retrieval(self, document_store: AzureAISearchDocumentStore): + docs = [ + Document(content="This is first document"), + Document(content="This is second document"), + Document(content="This is third document"), + ] + + document_store.write_documents(docs) + retriever = AzureAISearchBM25Retriever(document_store=document_store) + 
results = retriever.run(query="This is first document") + assert results["documents"][0].content == "This is first document" diff --git a/integrations/azure_ai_search/tests/test_hybrid_retriever.py b/integrations/azure_ai_search/tests/test_hybrid_retriever.py new file mode 100644 index 000000000..2447949fd --- /dev/null +++ b/integrations/azure_ai_search/tests/test_hybrid_retriever.py @@ -0,0 +1,145 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from typing import List +from unittest.mock import Mock + +import pytest +from azure.core.exceptions import HttpResponseError +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from numpy.random import rand # type: ignore + +from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchHybridRetriever +from haystack_integrations.document_stores.azure_ai_search import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore + + +def test_init_default(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + retriever = AzureAISearchHybridRetriever(document_store=mock_store) + assert retriever._document_store == mock_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + retriever = AzureAISearchHybridRetriever(document_store=mock_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + AzureAISearchHybridRetriever(document_store=mock_store, filter_policy="unknown") + + +def test_to_dict(): + document_store = AzureAISearchDocumentStore(hosts="some fake host") + retriever = AzureAISearchHybridRetriever(document_store=document_store) + res = retriever.to_dict() + assert res == { + "type": "haystack_integrations.components.retrievers.azure_ai_search.hybrid_retriever.AzureAISearchHybridRetriever", # noqa: E501 + "init_parameters": { + "filters": {}, + 
"top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": { + "profiles": [ + {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"} + ], + "algorithms": [ + { + "name": "cosine-algorithm-config", + "kind": "hnsw", + "parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"}, + } + ], + }, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + + +def test_from_dict(): + data = { + "type": "haystack_integrations.components.retrievers.azure_ai_search.hybrid_retriever.AzureAISearchHybridRetriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": DEFAULT_VECTOR_SEARCH, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + retriever = AzureAISearchHybridRetriever.from_dict(data) + assert isinstance(retriever._document_store, AzureAISearchDocumentStore) + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + +@pytest.mark.skipif( + not 
os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +@pytest.mark.integration +class TestRetriever: + + def test_run(self, document_store: AzureAISearchDocumentStore): + docs = [Document(id="1")] + document_store.write_documents(docs) + retriever = AzureAISearchHybridRetriever(document_store=document_store) + res = retriever.run(query="Test document", query_embedding=[0.1] * 768) + assert res["documents"] == docs + + def test_embedding_retrieval(self, document_store: AzureAISearchDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.8] * 768 + second_best_embedding = [0.8] * 200 + [0.1] * 300 + [0.2] * 268 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="This is first document", embedding=most_similar_embedding), + Document(content="This is second document", embedding=second_best_embedding), + Document(content="This is third document", embedding=another_embedding), + ] + + document_store.write_documents(docs) + retriever = AzureAISearchHybridRetriever(document_store=document_store) + results = retriever.run(query="This is first document", query_embedding=query_embedding) + assert results["documents"][0].content == "This is first document" + + def test_empty_query_embedding(self, document_store: AzureAISearchDocumentStore): + query_embedding: List[float] = [] + with pytest.raises(ValueError): + document_store._hybrid_retrieval(query="", query_embedding=query_embedding) + + def test_query_embedding_wrong_dimension(self, document_store: AzureAISearchDocumentStore): + query_embedding = [0.1] * 4 + with pytest.raises(HttpResponseError): + document_store._hybrid_retrieval(query="", query_embedding=query_embedding) From 147e6e12b767527958cca8f8c6a6b75f1670f981 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 14 Nov 2024 10:48:22 +0100 Subject: [PATCH 07/16] fix: deepeval - pin 
indirect dependencies based on python version (#1187) * try pinning pydantic * retry * again * more precise pin * fix * better --- integrations/deepeval/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/deepeval/pyproject.toml b/integrations/deepeval/pyproject.toml index 6ef64387b..78cc2542a 100644 --- a/integrations/deepeval/pyproject.toml +++ b/integrations/deepeval/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "deepeval==0.20.57"] +dependencies = ["haystack-ai", "deepeval==0.20.57", "langchain<0.3; python_version < '3.10'"] [project.urls] Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/deepeval" From 7342a8943afa6fb57705fd9ad8cdddbb53d3b382 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 14 Nov 2024 09:50:16 +0000 Subject: [PATCH 08/16] Update the changelog --- integrations/deepeval/CHANGELOG.md | 35 ++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 integrations/deepeval/CHANGELOG.md diff --git a/integrations/deepeval/CHANGELOG.md b/integrations/deepeval/CHANGELOG.md new file mode 100644 index 000000000..a296c7cfa --- /dev/null +++ b/integrations/deepeval/CHANGELOG.md @@ -0,0 +1,35 @@ +# Changelog + +## [integrations/deepeval-v0.1.2] - 2024-11-14 + +### ๐Ÿš€ Features + +- Implement `DeepEvalEvaluator` (#346) + +### ๐Ÿ› Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme +- Deepeval - pin indirect dependencies based on python version (#1187) + +### ๐Ÿ“š Documentation + +- Update paths and titles (#397) +- Update category slug (#442) +- Update `deepeval-haystack` docstrings (#527) +- Disable-class-def (#556) + +### ๐Ÿงช Testing + +- Do not retry tests in `hatch run test` command (#954) + +### โš™๏ธ Miscellaneous Tasks + +- Exculde evaluator 
private classes in API docs (#392) +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) +- Update ruff linting scripts and settings (#1105) +- Adopt uv as installer (#1142) + + From 2e0cddb6b74613a2be1e44c8e804539e971d4d04 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 14 Nov 2024 12:23:23 +0100 Subject: [PATCH 09/16] fix: VertexAIGeminiGenerator - remove support for tools and change output type (#1180) --- .../generators/google_vertex/gemini.py | 63 ++------- .../google_vertex/tests/test_gemini.py | 125 ++---------------- 2 files changed, 21 insertions(+), 167 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py index 737f2e668..c9473b428 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py @@ -15,8 +15,6 @@ HarmBlockThreshold, HarmCategory, Part, - Tool, - ToolConfig, ) logger = logging.getLogger(__name__) @@ -50,6 +48,16 @@ class VertexAIGeminiGenerator: ``` """ + def __new__(cls, *_, **kwargs): + if "tools" in kwargs or "tool_config" in kwargs: + msg = ( + "VertexAIGeminiGenerator does not support `tools` and `tool_config` parameters. " + "Use VertexAIGeminiChatGenerator instead." 
+ ) + raise TypeError(msg) + return super(VertexAIGeminiGenerator, cls).__new__(cls) # noqa: UP008 + # super(__class__, cls) is needed because of the component decorator + def __init__( self, *, @@ -58,8 +66,6 @@ def __init__( location: Optional[str] = None, generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, - tools: Optional[List[Tool]] = None, - tool_config: Optional[ToolConfig] = None, system_instruction: Optional[Union[str, ByteStream, Part]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): @@ -86,10 +92,6 @@ def __init__( for [HarmBlockThreshold](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.HarmBlockThreshold) and [HarmCategory](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.HarmCategory) for more details. - :param tools: List of tools to use when generating content. See the documentation for - [Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.Tool) - the list of supported arguments. - :param tool_config: The tool config to use. See the documentation for [ToolConfig](https://cloud.google.com/vertex-ai/generative-ai/docs/reference/python/latest/vertexai.generative_models.ToolConfig) :param system_instruction: Default system instruction to use for generating content. :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. 
@@ -105,8 +107,6 @@ def __init__( # model parameters self._generation_config = generation_config self._safety_settings = safety_settings - self._tools = tools - self._tool_config = tool_config self._system_instruction = system_instruction self._streaming_callback = streaming_callback @@ -115,8 +115,6 @@ def __init__( self._model_name, generation_config=self._generation_config, safety_settings=self._safety_settings, - tools=self._tools, - tool_config=self._tool_config, system_instruction=self._system_instruction, ) @@ -132,18 +130,6 @@ def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, A "stop_sequences": config._raw_generation_config.stop_sequences, } - def _tool_config_to_dict(self, tool_config: ToolConfig) -> Dict[str, Any]: - """Serializes the ToolConfig object into a dictionary.""" - - mode = tool_config._gapic_tool_config.function_calling_config.mode - allowed_function_names = tool_config._gapic_tool_config.function_calling_config.allowed_function_names - config_dict = {"function_calling_config": {"mode": mode}} - - if allowed_function_names: - config_dict["function_calling_config"]["allowed_function_names"] = allowed_function_names - - return config_dict - def to_dict(self) -> Dict[str, Any]: """ Serializes the component to a dictionary. 
@@ -160,15 +146,10 @@ def to_dict(self) -> Dict[str, Any]: location=self._location, generation_config=self._generation_config, safety_settings=self._safety_settings, - tools=self._tools, - tool_config=self._tool_config, system_instruction=self._system_instruction, streaming_callback=callback_name, ) - if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools] - if (tool_config := data["init_parameters"].get("tool_config")) is not None: - data["init_parameters"]["tool_config"] = self._tool_config_to_dict(tool_config) + if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) return data @@ -184,22 +165,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiGenerator": Deserialized component. """ - def _tool_config_from_dict(config_dict: Dict[str, Any]) -> ToolConfig: - """Deserializes the ToolConfig object from a dictionary.""" - function_calling_config = config_dict["function_calling_config"] - return ToolConfig( - function_calling_config=ToolConfig.FunctionCallingConfig( - mode=function_calling_config["mode"], - allowed_function_names=function_calling_config.get("allowed_function_names"), - ) - ) - - if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config) - if (tool_config := data["init_parameters"].get("tool_config")) is not None: - data["init_parameters"]["tool_config"] = _tool_config_from_dict(tool_config) if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None: data["init_parameters"]["streaming_callback"] = 
deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) @@ -215,7 +182,7 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: msg = f"Unsupported type {type(part)} for part {part}" raise ValueError(msg) - @component.output_types(replies=List[Union[str, Dict[str, str]]]) + @component.output_types(replies=List[str]) def run( self, parts: Variadic[Union[str, ByteStream, Part]], @@ -257,12 +224,6 @@ def _get_response(self, response_body: GenerationResponse) -> List[str]: for part in candidate.content.parts: if part._raw_part.text != "": replies.append(part.text) - elif part.function_call is not None: - function_call = { - "name": part.function_call.name, - "args": dict(part.function_call.args.items()), - } - replies.append(function_call) return replies def _get_stream_response( diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index 277851224..ff692c6f4 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -1,38 +1,17 @@ from unittest.mock import MagicMock, Mock, patch +import pytest from haystack import Pipeline from haystack.components.builders import PromptBuilder from haystack.dataclasses import StreamingChunk from vertexai.generative_models import ( - FunctionDeclaration, GenerationConfig, HarmBlockThreshold, HarmCategory, - Tool, - ToolConfig, ) from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator -GET_CURRENT_WEATHER_FUNC = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type": "object", - "properties": { - "location": {"type": "string", "description": "The city and state, e.g. 
San Francisco, CA"}, - "unit": { - "type": "string", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, -) - @patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") @@ -48,32 +27,28 @@ def test_init(mock_vertexai_init, _mock_generative_model): ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) - tool_config = ToolConfig( - function_calling_config=ToolConfig.FunctionCallingConfig( - mode=ToolConfig.FunctionCallingConfig.Mode.ANY, - allowed_function_names=["get_current_weather_func"], - ) - ) - gemini = VertexAIGeminiGenerator( project_id="TestID123", location="TestLocation", generation_config=generation_config, safety_settings=safety_settings, - tools=[tool], - tool_config=tool_config, system_instruction="Please provide brief answers.", ) mock_vertexai_init.assert_called() assert gemini._model_name == "gemini-1.5-flash" assert gemini._generation_config == generation_config assert gemini._safety_settings == safety_settings - assert gemini._tools == [tool] - assert gemini._tool_config == tool_config assert gemini._system_instruction == "Please provide brief answers." 
+def test_init_fails_with_tools_or_tool_config(): + with pytest.raises(TypeError, match="VertexAIGeminiGenerator does not support `tools`"): + VertexAIGeminiGenerator(tools=["tool1", "tool2"]) + + with pytest.raises(TypeError, match="VertexAIGeminiGenerator does not support `tools`"): + VertexAIGeminiGenerator(tool_config={"custom": "config"}) + + @patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") def test_to_dict(_mock_vertexai_init, _mock_generative_model): @@ -88,8 +63,6 @@ def test_to_dict(_mock_vertexai_init, _mock_generative_model): "generation_config": None, "safety_settings": None, "streaming_callback": None, - "tools": None, - "tool_config": None, "system_instruction": None, }, } @@ -108,21 +81,11 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) - tool_config = ToolConfig( - function_calling_config=ToolConfig.FunctionCallingConfig( - mode=ToolConfig.FunctionCallingConfig.Mode.ANY, - allowed_function_names=["get_current_weather_func"], - ) - ) - gemini = VertexAIGeminiGenerator( project_id="TestID123", location="TestLocation", generation_config=generation_config, safety_settings=safety_settings, - tools=[tool], - tool_config=tool_config, system_instruction="Please provide brief answers.", ) assert gemini.to_dict() == { @@ -141,34 +104,6 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): }, "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, "streaming_callback": None, - "tools": [ - { - "function_declarations": [ - { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type_": "OBJECT", - 
"properties": { - "location": { - "type_": "STRING", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, - }, - "required": ["location"], - "property_ordering": ["location", "unit"], - }, - } - ] - } - ], - "tool_config": { - "function_calling_config": { - "mode": ToolConfig.FunctionCallingConfig.Mode.ANY, - "allowed_function_names": ["get_current_weather_func"], - } - }, "system_instruction": "Please provide brief answers.", }, } @@ -186,9 +121,7 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): "model": "gemini-1.5-flash", "generation_config": None, "safety_settings": None, - "tools": None, "streaming_callback": None, - "tool_config": None, "system_instruction": None, }, } @@ -198,8 +131,6 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): assert gemini._project_id is None assert gemini._location is None assert gemini._safety_settings is None - assert gemini._tools is None - assert gemini._tool_config is None assert gemini._system_instruction is None assert gemini._generation_config is None @@ -223,40 +154,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): "stop_sequences": ["stop"], }, "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, - "tools": [ - { - "function_declarations": [ - { - "name": "get_current_weather", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. 
San Francisco, CA", - }, - "unit": { - "type": "string", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, - "description": "Get the current weather in a given location", - } - ] - } - ], "streaming_callback": None, - "tool_config": { - "function_calling_config": { - "mode": ToolConfig.FunctionCallingConfig.Mode.ANY, - "allowed_function_names": ["get_current_weather_func"], - } - }, "system_instruction": "Please provide brief answers.", }, } @@ -266,13 +164,8 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): assert gemini._project_id == "TestID123" assert gemini._location == "TestLocation" assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) assert isinstance(gemini._generation_config, GenerationConfig) - assert isinstance(gemini._tool_config, ToolConfig) assert gemini._system_instruction == "Please provide brief answers." 
- assert ( - gemini._tool_config._gapic_tool_config.function_calling_config.mode == ToolConfig.FunctionCallingConfig.Mode.ANY - ) @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") From 018b13e0bb58a13a8cb36b86d1c56aa0cb837ef8 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 14 Nov 2024 11:26:14 +0000 Subject: [PATCH 10/16] Update the changelog --- integrations/google_vertex/CHANGELOG.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/integrations/google_vertex/CHANGELOG.md b/integrations/google_vertex/CHANGELOG.md index ed2cc3c3b..ea2a8fb18 100644 --- a/integrations/google_vertex/CHANGELOG.md +++ b/integrations/google_vertex/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## [integrations/google_vertex-v3.0.0] - 2024-11-14 + +### ๐Ÿ› Bug Fixes + +- VertexAIGeminiGenerator - remove support for tools and change output type (#1180) + +### โš™๏ธ Miscellaneous Tasks + +- Fix Vertex tests (#1163) + ## [integrations/google_vertex-v2.2.0] - 2024-10-23 ### ๐Ÿ› Bug Fixes From f1fd742998f2f47bec8eac9785596e358fd0b9c5 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Thu, 14 Nov 2024 14:53:54 +0100 Subject: [PATCH 11/16] fix: Fixes to NvidiaRanker (#1191) * Fixes to NvidiaRanker * Add inits and headers * More headers * updates * Reactivate test * Fix tests * Reenable test and add test --- integrations/nvidia/pyproject.toml | 2 +- .../src/haystack_integrations/__init__.py | 3 ++ .../components/__init__.py | 3 ++ .../components/embedders/__init__.py | 3 ++ .../components/embedders/nvidia/__init__.py | 4 ++ .../embedders/nvidia/document_embedder.py | 11 +++-- .../embedders/nvidia/text_embedder.py | 7 ++- .../components/embedders/nvidia/truncate.py | 4 ++ .../components/generators/__init__.py | 3 ++ .../components/generators/nvidia/__init__.py | 1 + .../components/generators/nvidia/generator.py | 1 + .../components/rankers/__init__.py | 3 ++ .../components/rankers/nvidia/__init__.py | 4 ++ 
.../components/rankers/nvidia/ranker.py | 27 ++++++---- .../components/rankers/nvidia/truncate.py | 4 ++ .../haystack_integrations/utils/__init__.py | 3 ++ .../utils/nvidia/__init__.py | 4 ++ .../utils/nvidia/nim_backend.py | 4 ++ .../utils/nvidia/utils.py | 8 ++- integrations/nvidia/tests/__init__.py | 1 + integrations/nvidia/tests/conftest.py | 4 ++ integrations/nvidia/tests/test_base_url.py | 4 ++ .../nvidia/tests/test_document_embedder.py | 27 ++++++++-- .../tests/test_embedding_truncate_mode.py | 4 ++ integrations/nvidia/tests/test_generator.py | 1 + integrations/nvidia/tests/test_ranker.py | 49 +++++++++++++++++++ .../nvidia/tests/test_text_embedder.py | 21 +++++++- 27 files changed, 188 insertions(+), 22 deletions(-) create mode 100644 integrations/nvidia/src/haystack_integrations/__init__.py create mode 100644 integrations/nvidia/src/haystack_integrations/components/__init__.py create mode 100644 integrations/nvidia/src/haystack_integrations/components/embedders/__init__.py create mode 100644 integrations/nvidia/src/haystack_integrations/components/generators/__init__.py create mode 100644 integrations/nvidia/src/haystack_integrations/components/rankers/__init__.py create mode 100644 integrations/nvidia/src/haystack_integrations/utils/__init__.py diff --git a/integrations/nvidia/pyproject.toml b/integrations/nvidia/pyproject.toml index 7f0048c1b..586b50848 100644 --- a/integrations/nvidia/pyproject.toml +++ b/integrations/nvidia/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "requests"] +dependencies = ["haystack-ai", "requests", "tqdm"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/nvidia#readme" diff --git a/integrations/nvidia/src/haystack_integrations/__init__.py b/integrations/nvidia/src/haystack_integrations/__init__.py new file 
mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/components/__init__.py b/integrations/nvidia/src/haystack_integrations/components/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/__init__.py b/integrations/nvidia/src/haystack_integrations/components/embedders/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py index bc2d9372c..827ad7dc6 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from .document_embedder import NvidiaDocumentEmbedder from .text_embedder import NvidiaTextEmbedder from .truncate import EmbeddingTruncateMode diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py index d746a75f4..606ec78fd 100644 --- 
a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import warnings from typing import Any, Dict, List, Optional, Tuple, Union @@ -5,10 +9,9 @@ from haystack.utils import Secret, deserialize_secrets_inplace from tqdm import tqdm +from haystack_integrations.components.embedders.nvidia.truncate import EmbeddingTruncateMode from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation -from .truncate import EmbeddingTruncateMode - _DEFAULT_API_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia" @@ -167,7 +170,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "NvidiaDocumentEmbedder": :returns: The deserialized component. """ - deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) + init_parameters = data.get("init_parameters", {}) + if init_parameters: + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) return default_from_dict(cls, data) def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py index 22bed8197..4b7072f33 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py @@ -1,13 +1,16 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import warnings from typing import Any, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict from haystack.utils import Secret, deserialize_secrets_inplace +from 
haystack_integrations.components.embedders.nvidia.truncate import EmbeddingTruncateMode from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation -from .truncate import EmbeddingTruncateMode - _DEFAULT_API_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia" diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py index 3a8eb9d07..931c3cce3 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from enum import Enum diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/__init__.py b/integrations/nvidia/src/haystack_integrations/components/generators/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/generators/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py index 18354ea17..b809d83b9 100644 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2024-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + from .generator import NvidiaGenerator __all__ = ["NvidiaGenerator"] diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py 
b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py index 3eadcc5df..5bf71a9e1 100644 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2024-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + import warnings from typing import Any, Dict, List, Optional diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/__init__.py b/integrations/nvidia/src/haystack_integrations/components/rankers/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/__init__.py index 29cb2f7f5..05daa1c54 100644 --- a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/__init__.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from .ranker import NvidiaRanker __all__ = ["NvidiaRanker"] diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py index 1553d1ac3..9938b37d1 100644 --- a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py @@ -1,12 +1,17 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import warnings 
from typing import Any, Dict, List, Optional, Union -from haystack import Document, component, default_from_dict, default_to_dict +from haystack import Document, component, default_from_dict, default_to_dict, logging from haystack.utils import Secret, deserialize_secrets_inplace +from haystack_integrations.components.rankers.nvidia.truncate import RankerTruncateMode from haystack_integrations.utils.nvidia import NimBackend, url_validation -from .truncate import RankerTruncateMode +logger = logging.getLogger(__name__) _DEFAULT_MODEL = "nvidia/nv-rerankqa-mistral-4b-v3" @@ -51,7 +56,7 @@ def __init__( model: Optional[str] = None, truncate: Optional[Union[RankerTruncateMode, str]] = None, api_url: Optional[str] = None, - api_key: Optional[Secret] = None, + api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"), top_k: int = 5, ): """ @@ -100,6 +105,7 @@ def __init__( self._api_key = Secret.from_env_var("NVIDIA_API_KEY") self._top_k = top_k self._initialized = False + self._backend: Optional[Any] = None def to_dict(self) -> Dict[str, Any]: """ @@ -113,7 +119,7 @@ def to_dict(self) -> Dict[str, Any]: top_k=self._top_k, truncate=self._truncate, api_url=self._api_url, - api_key=self._api_key, + api_key=self._api_key.to_dict() if self._api_key else None, ) @classmethod @@ -124,7 +130,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "NvidiaRanker": :param data: A dictionary containing the ranker's attributes. :returns: The deserialized ranker. """ - deserialize_secrets_inplace(data, keys=["api_key"]) + init_parameters = data.get("init_parameters", {}) + if init_parameters: + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) return default_from_dict(cls, data) def warm_up(self): @@ -170,16 +178,16 @@ def run( msg = "The ranker has not been loaded. Please call warm_up() before running." raise RuntimeError(msg) if not isinstance(query, str): - msg = "Ranker expects the `query` parameter to be a string." 
+ msg = "NvidiaRanker expects the `query` parameter to be a string." raise TypeError(msg) if not isinstance(documents, list): - msg = "Ranker expects the `documents` parameter to be a list." + msg = "NvidiaRanker expects the `documents` parameter to be a list." raise TypeError(msg) if not all(isinstance(doc, Document) for doc in documents): - msg = "Ranker expects the `documents` parameter to be a list of Document objects." + msg = "NvidiaRanker expects the `documents` parameter to be a list of Document objects." raise TypeError(msg) if top_k is not None and not isinstance(top_k, int): - msg = "Ranker expects the `top_k` parameter to be an integer." + msg = "NvidiaRanker expects the `top_k` parameter to be an integer." raise TypeError(msg) if len(documents) == 0: @@ -187,6 +195,7 @@ def run( top_k = top_k if top_k is not None else self._top_k if top_k < 1: + logger.warning("top_k should be at least 1, returning nothing") warnings.warn("top_k should be at least 1, returning nothing", stacklevel=2) return {"documents": []} diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/truncate.py b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/truncate.py index 3b5d7f40a..649ceaf9d 100644 --- a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/truncate.py +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/truncate.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from enum import Enum diff --git a/integrations/nvidia/src/haystack_integrations/utils/__init__.py b/integrations/nvidia/src/haystack_integrations/utils/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/utils/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git 
a/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py index da301d29d..f08cda6cd 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from .nim_backend import Model, NimBackend from .utils import is_hosted, url_validation diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py index cbb6b7c3f..0279cf608 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py index 7d4dfc3b4..f07989405 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py @@ -1,9 +1,13 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import warnings -from typing import List +from typing import List, Optional from urllib.parse import urlparse, urlunparse -def url_validation(api_url: str, default_api_url: str, allowed_paths: List[str]) -> str: +def url_validation(api_url: str, default_api_url: Optional[str], allowed_paths: List[str]) -> str: """ Validate and normalize an API URL. 
diff --git a/integrations/nvidia/tests/__init__.py b/integrations/nvidia/tests/__init__.py index 47611e0b9..38adc654d 100644 --- a/integrations/nvidia/tests/__init__.py +++ b/integrations/nvidia/tests/__init__.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + from .conftest import MockBackend __all__ = ["MockBackend"] diff --git a/integrations/nvidia/tests/conftest.py b/integrations/nvidia/tests/conftest.py index a6c78ba4e..b6346c672 100644 --- a/integrations/nvidia/tests/conftest.py +++ b/integrations/nvidia/tests/conftest.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from typing import Any, Dict, List, Optional, Tuple import pytest diff --git a/integrations/nvidia/tests/test_base_url.py b/integrations/nvidia/tests/test_base_url.py index 426bacc25..506fbc385 100644 --- a/integrations/nvidia/tests/test_base_url.py +++ b/integrations/nvidia/tests/test_base_url.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import pytest from haystack_integrations.components.embedders.nvidia import NvidiaDocumentEmbedder, NvidiaTextEmbedder diff --git a/integrations/nvidia/tests/test_document_embedder.py b/integrations/nvidia/tests/test_document_embedder.py index db69053e7..7e0e02f3d 100644 --- a/integrations/nvidia/tests/test_document_embedder.py +++ b/integrations/nvidia/tests/test_document_embedder.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import os import pytest @@ -104,7 +108,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): }, } - def from_dict(self, monkeypatch): + def test_from_dict(self, monkeypatch): monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") data = { "type": "haystack_integrations.components.embedders.nvidia.document_embedder.NvidiaDocumentEmbedder", @@ -122,15 +126,32 @@ 
def from_dict(self, monkeypatch): }, } component = NvidiaDocumentEmbedder.from_dict(data) - assert component.model == "nvolveqa_40k" + assert component.model == "playground_nvolveqa_40k" assert component.api_url == "https://example.com/v1" assert component.prefix == "prefix" assert component.suffix == "suffix" + assert component.batch_size == 10 + assert component.progress_bar is False + assert component.meta_fields_to_embed == ["test_field"] + assert component.embedding_separator == " | " + assert component.truncate == EmbeddingTruncateMode.START + + def test_from_dict_defaults(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + data = { + "type": "haystack_integrations.components.embedders.nvidia.document_embedder.NvidiaDocumentEmbedder", + "init_parameters": {}, + } + component = NvidiaDocumentEmbedder.from_dict(data) + assert component.model == "nvidia/nv-embedqa-e5-v5" + assert component.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia" + assert component.prefix == "" + assert component.suffix == "" assert component.batch_size == 32 assert component.progress_bar assert component.meta_fields_to_embed == [] assert component.embedding_separator == "\n" - assert component.truncate == EmbeddingTruncateMode.START + assert component.truncate is None def test_prepare_texts_to_embed_w_metadata(self): documents = [ diff --git a/integrations/nvidia/tests/test_embedding_truncate_mode.py b/integrations/nvidia/tests/test_embedding_truncate_mode.py index e74d0308c..16f9112ea 100644 --- a/integrations/nvidia/tests/test_embedding_truncate_mode.py +++ b/integrations/nvidia/tests/test_embedding_truncate_mode.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import pytest from haystack_integrations.components.embedders.nvidia import EmbeddingTruncateMode diff --git a/integrations/nvidia/tests/test_generator.py b/integrations/nvidia/tests/test_generator.py index 0bd8b1fc6..055830ae5 
100644 --- a/integrations/nvidia/tests/test_generator.py +++ b/integrations/nvidia/tests/test_generator.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2024-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + import os import pytest diff --git a/integrations/nvidia/tests/test_ranker.py b/integrations/nvidia/tests/test_ranker.py index 566fd18a8..d66bb0f65 100644 --- a/integrations/nvidia/tests/test_ranker.py +++ b/integrations/nvidia/tests/test_ranker.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import os import re from typing import Any, Optional, Union @@ -256,3 +260,48 @@ def test_warm_up_once(self, monkeypatch) -> None: backend = client._backend client.warm_up() assert backend == client._backend + + def test_to_dict(self) -> None: + client = NvidiaRanker() + assert client.to_dict() == { + "type": "haystack_integrations.components.rankers.nvidia.ranker.NvidiaRanker", + "init_parameters": { + "model": "nvidia/nv-rerankqa-mistral-4b-v3", + "top_k": 5, + "truncate": None, + "api_url": None, + "api_key": {"type": "env_var", "env_vars": ["NVIDIA_API_KEY"], "strict": True}, + }, + } + + def test_from_dict(self) -> None: + client = NvidiaRanker.from_dict( + { + "type": "haystack_integrations.components.rankers.nvidia.ranker.NvidiaRanker", + "init_parameters": { + "model": "nvidia/nv-rerankqa-mistral-4b-v3", + "top_k": 5, + "truncate": None, + "api_url": None, + "api_key": {"type": "env_var", "env_vars": ["NVIDIA_API_KEY"], "strict": True}, + }, + } + ) + assert client._model == "nvidia/nv-rerankqa-mistral-4b-v3" + assert client._top_k == 5 + assert client._truncate is None + assert client._api_url is None + assert client._api_key == Secret.from_env_var("NVIDIA_API_KEY") + + def test_from_dict_defaults(self) -> None: + client = NvidiaRanker.from_dict( + { + "type": "haystack_integrations.components.rankers.nvidia.ranker.NvidiaRanker", + "init_parameters": {}, + } + ) + assert client._model == 
"nvidia/nv-rerankqa-mistral-4b-v3" + assert client._top_k == 5 + assert client._truncate is None + assert client._api_url is None + assert client._api_key == Secret.from_env_var("NVIDIA_API_KEY") diff --git a/integrations/nvidia/tests/test_text_embedder.py b/integrations/nvidia/tests/test_text_embedder.py index 8690de6b1..278fa5191 100644 --- a/integrations/nvidia/tests/test_text_embedder.py +++ b/integrations/nvidia/tests/test_text_embedder.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import os import pytest @@ -77,7 +81,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): }, } - def from_dict(self, monkeypatch): + def test_from_dict(self, monkeypatch): monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") data = { "type": "haystack_integrations.components.embedders.nvidia.text_embedder.NvidiaTextEmbedder", @@ -95,7 +99,20 @@ def from_dict(self, monkeypatch): assert component.api_url == "https://example.com/v1" assert component.prefix == "prefix" assert component.suffix == "suffix" - assert component.truncate == "START" + assert component.truncate == EmbeddingTruncateMode.START + + def test_from_dict_defaults(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + data = { + "type": "haystack_integrations.components.embedders.nvidia.text_embedder.NvidiaTextEmbedder", + "init_parameters": {}, + } + component = NvidiaTextEmbedder.from_dict(data) + assert component.model == "nvidia/nv-embedqa-e5-v5" + assert component.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia" + assert component.prefix == "" + assert component.suffix == "" + assert component.truncate is None @pytest.mark.usefixtures("mock_local_models") def test_run_default_model(self): From 82256bbf24d6f1c6ca2b1dc574390f75a4374539 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 14 Nov 2024 14:00:34 +0000 Subject: [PATCH 12/16] Update the changelog --- integrations/nvidia/CHANGELOG.md | 6 
++++++ 1 file changed, 6 insertions(+) diff --git a/integrations/nvidia/CHANGELOG.md b/integrations/nvidia/CHANGELOG.md index 75b31d033..a536e431d 100644 --- a/integrations/nvidia/CHANGELOG.md +++ b/integrations/nvidia/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [integrations/nvidia-v0.1.1] - 2024-11-14 + +### ๐Ÿ› Bug Fixes + +- Fixes to NvidiaRanker (#1191) + ## [integrations/nvidia-v0.1.0] - 2024-11-13 ### ๐Ÿš€ Features From 0640b240e087dc2b7aad9fe9add1615bc484706e Mon Sep 17 00:00:00 2001 From: rblst Date: Thu, 14 Nov 2024 19:09:24 +0100 Subject: [PATCH 13/16] feat: Add schema support to pgvector document store. (#1095) * Add schema support for the pgvector document store. Using the public schema of a PostgreSQL database is an anti-pattern. This change adds support for using a schema other than the public schema to create tables. * Fix long lines. * Fix long lines. Remove trailing spaces. * Fix trailing spaces. * Fix last trailing space. * Fix ruff issues. * Fix trailing space. * small fixes --------- Co-authored-by: Stefano Fiorucci --- .../pgvector/document_store.py | 70 +++++++++++++------ .../pgvector/tests/test_document_store.py | 3 + .../pgvector/tests/test_retrievers.py | 2 + 3 files changed, 54 insertions(+), 21 deletions(-) diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py index 1b1333f5c..8e9c0f2fc 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) CREATE_TABLE_STATEMENT = """ -CREATE TABLE IF NOT EXISTS {table_name} ( +CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ( id VARCHAR(128) PRIMARY KEY, embedding VECTOR({embedding_dimension}), content TEXT, @@ -36,7 +36,7 @@ """ INSERT_STATEMENT = """ 
-INSERT INTO {table_name} +INSERT INTO {schema_name}.{table_name} (id, embedding, content, dataframe, blob_data, blob_meta, blob_mime_type, meta) VALUES (%(id)s, %(embedding)s, %(content)s, %(dataframe)s, %(blob_data)s, %(blob_meta)s, %(blob_mime_type)s, %(meta)s) """ @@ -54,7 +54,7 @@ KEYWORD_QUERY = """ SELECT {table_name}.*, ts_rank_cd(to_tsvector({language}, content), query) AS score -FROM {table_name}, plainto_tsquery({language}, %s) query +FROM {schema_name}.{table_name}, plainto_tsquery({language}, %s) query WHERE to_tsvector({language}, content) @@ query """ @@ -78,6 +78,7 @@ def __init__( self, *, connection_string: Secret = Secret.from_env_var("PG_CONN_STR"), + schema_name: str = "public", table_name: str = "haystack_documents", language: str = "english", embedding_dimension: int = 768, @@ -101,6 +102,7 @@ def __init__( e.g.: `PG_CONN_STR="host=HOST port=PORT dbname=DBNAME user=USER password=PASSWORD"` See [PostgreSQL Documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING) for more details. + :param schema_name: The name of the schema the table is created in. The schema must already exist. :param table_name: The name of the table to use to store Haystack documents. :param language: The language to be used to parse query and document content in keyword retrieval. 
To see the list of available languages, you can run the following SQL query in your PostgreSQL database: @@ -137,6 +139,7 @@ def __init__( self.connection_string = connection_string self.table_name = table_name + self.schema_name = schema_name self.embedding_dimension = embedding_dimension if vector_function not in VALID_VECTOR_FUNCTIONS: msg = f"vector_function must be one of {VALID_VECTOR_FUNCTIONS}, but got {vector_function}" @@ -207,6 +210,7 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, connection_string=self.connection_string.to_dict(), + schema_name=self.schema_name, table_name=self.table_name, embedding_dimension=self.embedding_dimension, vector_function=self.vector_function, @@ -266,7 +270,9 @@ def _create_table_if_not_exists(self): """ create_sql = SQL(CREATE_TABLE_STATEMENT).format( - table_name=Identifier(self.table_name), embedding_dimension=SQLLiteral(self.embedding_dimension) + schema_name=Identifier(self.schema_name), + table_name=Identifier(self.table_name), + embedding_dimension=SQLLiteral(self.embedding_dimension), ) self._execute_sql(create_sql, error_msg="Could not create table in PgvectorDocumentStore") @@ -274,12 +280,18 @@ def _create_table_if_not_exists(self): def delete_table(self): """ Deletes the table used to store Haystack documents. - The name of the table (`table_name`) is defined when initializing the `PgvectorDocumentStore`. + The name of the schema (`schema_name`) and the name of the table (`table_name`) + are defined when initializing the `PgvectorDocumentStore`. 
""" + delete_sql = SQL("DROP TABLE IF EXISTS {schema_name}.{table_name}").format( + schema_name=Identifier(self.schema_name), + table_name=Identifier(self.table_name), + ) - delete_sql = SQL("DROP TABLE IF EXISTS {table_name}").format(table_name=Identifier(self.table_name)) - - self._execute_sql(delete_sql, error_msg=f"Could not delete table {self.table_name} in PgvectorDocumentStore") + self._execute_sql( + delete_sql, + error_msg=f"Could not delete table {self.schema_name}.{self.table_name} in PgvectorDocumentStore", + ) def _create_keyword_index_if_not_exists(self): """ @@ -287,15 +299,16 @@ def _create_keyword_index_if_not_exists(self): """ index_exists = bool( self._execute_sql( - "SELECT 1 FROM pg_indexes WHERE tablename = %s AND indexname = %s", - (self.table_name, self.keyword_index_name), + "SELECT 1 FROM pg_indexes WHERE schemaname = %s AND tablename = %s AND indexname = %s", + (self.schema_name, self.table_name, self.keyword_index_name), "Could not check if keyword index exists", ).fetchone() ) sql_create_index = SQL( - "CREATE INDEX {index_name} ON {table_name} USING GIN (to_tsvector({language}, content))" + "CREATE INDEX {index_name} ON {schema_name}.{table_name} USING GIN (to_tsvector({language}, content))" ).format( + schema_name=Identifier(self.schema_name), index_name=Identifier(self.keyword_index_name), table_name=Identifier(self.table_name), language=SQLLiteral(self.language), @@ -318,8 +331,8 @@ def _handle_hnsw(self): index_exists = bool( self._execute_sql( - "SELECT 1 FROM pg_indexes WHERE tablename = %s AND indexname = %s", - (self.table_name, self.hnsw_index_name), + "SELECT 1 FROM pg_indexes WHERE schemaname = %s AND tablename = %s AND indexname = %s", + (self.schema_name, self.table_name, self.hnsw_index_name), "Could not check if HNSW index exists", ).fetchone() ) @@ -349,8 +362,13 @@ def _create_hnsw_index(self): if key in HNSW_INDEX_CREATION_VALID_KWARGS } - sql_create_index = SQL("CREATE INDEX {index_name} ON {table_name} USING hnsw 
(embedding {ops}) ").format( - index_name=Identifier(self.hnsw_index_name), table_name=Identifier(self.table_name), ops=SQL(pg_ops) + sql_create_index = SQL( + "CREATE INDEX {index_name} ON {schema_name}.{table_name} USING hnsw (embedding {ops}) " + ).format( + schema_name=Identifier(self.schema_name), + index_name=Identifier(self.hnsw_index_name), + table_name=Identifier(self.table_name), + ops=SQL(pg_ops), ) if actual_hnsw_index_creation_kwargs: @@ -369,7 +387,9 @@ def count_documents(self) -> int: Returns how many documents are present in the document store. """ - sql_count = SQL("SELECT COUNT(*) FROM {table_name}").format(table_name=Identifier(self.table_name)) + sql_count = SQL("SELECT COUNT(*) FROM {schema_name}.{table_name}").format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name) + ) count = self._execute_sql(sql_count, error_msg="Could not count documents in PgvectorDocumentStore").fetchone()[ 0 @@ -395,7 +415,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc msg = "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details." 
raise ValueError(msg) - sql_filter = SQL("SELECT * FROM {table_name}").format(table_name=Identifier(self.table_name)) + sql_filter = SQL("SELECT * FROM {schema_name}.{table_name}").format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name) + ) params = () if filters: @@ -434,7 +456,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D db_documents = self._from_haystack_to_pg_documents(documents) - sql_insert = SQL(INSERT_STATEMENT).format(table_name=Identifier(self.table_name)) + sql_insert = SQL(INSERT_STATEMENT).format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name) + ) if policy == DuplicatePolicy.OVERWRITE: sql_insert += SQL(UPDATE_STATEMENT) @@ -543,8 +567,10 @@ def delete_documents(self, document_ids: List[str]) -> None: document_ids_str = ", ".join(f"'{document_id}'" for document_id in document_ids) - delete_sql = SQL("DELETE FROM {table_name} WHERE id IN ({document_ids_str})").format( - table_name=Identifier(self.table_name), document_ids_str=SQL(document_ids_str) + delete_sql = SQL("DELETE FROM {schema_name}.{table_name} WHERE id IN ({document_ids_str})").format( + schema_name=Identifier(self.schema_name), + table_name=Identifier(self.table_name), + document_ids_str=SQL(document_ids_str), ) self._execute_sql(delete_sql, error_msg="Could not delete documents from PgvectorDocumentStore") @@ -570,6 +596,7 @@ def _keyword_retrieval( raise ValueError(msg) sql_select = SQL(KEYWORD_QUERY).format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name), language=SQLLiteral(self.language), query=SQLLiteral(query), @@ -643,7 +670,8 @@ def _embedding_retrieval( elif vector_function == "l2_distance": score_definition = f"embedding <-> {query_embedding_for_postgres} AS score" - sql_select = SQL("SELECT *, {score} FROM {table_name}").format( + sql_select = SQL("SELECT *, {score} FROM {schema_name}.{table_name}").format( + 
schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name), score=SQL(score_definition), ) diff --git a/integrations/pgvector/tests/test_document_store.py b/integrations/pgvector/tests/test_document_store.py index 93514b71c..4af4fc8de 100644 --- a/integrations/pgvector/tests/test_document_store.py +++ b/integrations/pgvector/tests/test_document_store.py @@ -47,6 +47,7 @@ def test_init(monkeypatch): monkeypatch.setenv("PG_CONN_STR", "some_connection_string") document_store = PgvectorDocumentStore( + schema_name="my_schema", table_name="my_table", embedding_dimension=512, vector_function="l2_distance", @@ -59,6 +60,7 @@ def test_init(monkeypatch): keyword_index_name="my_keyword_index", ) + assert document_store.schema_name == "my_schema" assert document_store.table_name == "my_table" assert document_store.embedding_dimension == 512 assert document_store.vector_function == "l2_distance" @@ -93,6 +95,7 @@ def test_to_dict(monkeypatch): "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, "table_name": "my_table", + "schema_name": "public", "embedding_dimension": 512, "vector_function": "l2_distance", "recreate_table": True, diff --git a/integrations/pgvector/tests/test_retrievers.py b/integrations/pgvector/tests/test_retrievers.py index 290891307..4125c3e3a 100644 --- a/integrations/pgvector/tests/test_retrievers.py +++ b/integrations/pgvector/tests/test_retrievers.py @@ -50,6 +50,7 @@ def test_to_dict(self, mock_store): "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "schema_name": "public", "table_name": "haystack", "embedding_dimension": 768, "vector_function": "cosine_similarity", @@ -175,6 +176,7 @@ def test_to_dict(self, mock_store): "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", 
"init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "schema_name": "public", "table_name": "haystack", "embedding_dimension": 768, "vector_function": "cosine_similarity", From e228e0713fd5dcc1b7e5799853f9a8ab66c3288d Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Fri, 15 Nov 2024 14:03:04 +0100 Subject: [PATCH 14/16] Add AnthropicVertexChatGenerator component (#1192) * Created a model adapter * Create adapter class and add VertexAPI * Add chat generator for Anthropic Vertex * Add tests * Small fix * Improve doc_strings * Make project_id and region mandatory params * Small fix --- .../generators/anthropic/__init__.py | 3 +- .../anthropic/chat/vertex_chat_generator.py | 135 ++++++++++++ .../tests/test_vertex_chat_generator.py | 197 ++++++++++++++++++ 3 files changed, 334 insertions(+), 1 deletion(-) create mode 100644 integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py create mode 100644 integrations/anthropic/tests/test_vertex_chat_generator.py diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py index c2c1ee40d..0bd29898e 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 from .chat.chat_generator import AnthropicChatGenerator +from .chat.vertex_chat_generator import AnthropicVertexChatGenerator from .generator import AnthropicGenerator -__all__ = ["AnthropicGenerator", "AnthropicChatGenerator"] +__all__ = ["AnthropicGenerator", "AnthropicChatGenerator", "AnthropicVertexChatGenerator"] diff --git 
a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py new file mode 100644 index 000000000..4ece944cd --- /dev/null +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py @@ -0,0 +1,135 @@ +import os +from typing import Any, Callable, Dict, Optional + +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.dataclasses import StreamingChunk +from haystack.utils import deserialize_callable, serialize_callable + +from anthropic import AnthropicVertex + +from .chat_generator import AnthropicChatGenerator + +logger = logging.getLogger(__name__) + + +@component +class AnthropicVertexChatGenerator(AnthropicChatGenerator): + """ + + Enables text generation using state-of-the-art Claude 3 LLMs via the Anthropic Vertex AI API. + It supports models such as `Claude 3.5 Sonnet`, `Claude 3 Opus`, `Claude 3 Sonnet`, and `Claude 3 Haiku`, + accessible through the Vertex AI API endpoint. + + To use AnthropicVertexChatGenerator, you must have a GCP project with Vertex AI enabled. + Additionally, ensure that the desired Anthropic model is activated in the Vertex AI Model Garden. + Before making requests, you may need to authenticate with GCP using `gcloud auth login`. + For more details, refer to the [guide](https://docs.anthropic.com/en/api/claude-on-vertex-ai). + + Any valid text generation parameters for the Anthropic messaging API can be passed to + the AnthropicVertex API. Users can provide these parameters directly to the component via + the `generation_kwargs` parameter in `__init__` or the `run` method. + + For more details on the parameters supported by the Anthropic API, refer to the + Anthropic Message API [documentation](https://docs.anthropic.com/en/api/messages).
+ + ```python + from haystack_integrations.components.generators.anthropic import AnthropicVertexChatGenerator + from haystack.dataclasses import ChatMessage + + messages = [ChatMessage.from_user("What's Natural Language Processing?")] + client = AnthropicVertexChatGenerator( + model="claude-3-sonnet@20240229", + project_id="your-project-id", region="your-region" + ) + response = client.run(messages) + print(response) + + >> {'replies': [ChatMessage(content='Natural Language Processing (NLP) is a field of artificial intelligence that + >> focuses on enabling computers to understand, interpret, and generate human language. It involves developing + >> techniques and algorithms to analyze and process text or speech data, allowing machines to comprehend and + >> communicate in natural languages like English, Spanish, or Chinese.', role=<ChatRole.ASSISTANT: 'assistant'>, + >> name=None, meta={'model': 'claude-3-sonnet@20240229', 'index': 0, 'finish_reason': 'end_turn', + >> 'usage': {'input_tokens': 15, 'output_tokens': 64}})]} + ``` + + For more details on supported models and their capabilities, refer to the Anthropic + [documentation](https://docs.anthropic.com/claude/docs/intro-to-claude). + + """ + + def __init__( + self, + region: str, + project_id: str, + model: str = "claude-3-5-sonnet@20240620", + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ignore_tools_thinking_messages: bool = True, + ): + """ + Creates an instance of AnthropicVertexChatGenerator. + + :param region: The region where the Anthropic model is deployed. + :param project_id: The GCP project ID where the Anthropic model is deployed. + :param model: The name of the model to use. + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. + :param generation_kwargs: Other parameters to use for the model.
These parameters are all sent directly to + the AnthropicVertex endpoint. See Anthropic [documentation](https://docs.anthropic.com/claude/reference/messages_post) + for more details. + + Supported generation_kwargs parameters are: + - `system`: The system message to be passed to the model. + - `max_tokens`: The maximum number of tokens to generate. + - `metadata`: A dictionary of metadata to be passed to the model. + - `stop_sequences`: A list of strings that the model should stop generating at. + - `temperature`: The temperature to use for sampling. + - `top_p`: The top_p value to use for nucleus sampling. + - `top_k`: The top_k value to use for top-k sampling. + - `extra_headers`: A dictionary of extra headers to be passed to the model (i.e. for beta features). + :param ignore_tools_thinking_messages: Anthropic's approach to tools (function calling) resolution involves a + "chain of thought" messages before returning the actual function names and parameters in a message. If + `ignore_tools_thinking_messages` is `True`, the generator will drop so-called thinking messages when tool + use is detected. See the Anthropic [tools](https://docs.anthropic.com/en/docs/tool-use#chain-of-thought-tool-use) + for more details. + """ + self.region = region or os.environ.get("REGION") + self.project_id = project_id or os.environ.get("PROJECT_ID") + self.model = model + self.generation_kwargs = generation_kwargs or {} + self.streaming_callback = streaming_callback + self.client = AnthropicVertex(region=self.region, project_id=self.project_id) + self.ignore_tools_thinking_messages = ignore_tools_thinking_messages + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + + :returns: + The serialized component as a dictionary. 
+ """ + callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + return default_to_dict( + self, + region=self.region, + project_id=self.project_id, + model=self.model, + streaming_callback=callback_name, + generation_kwargs=self.generation_kwargs, + ignore_tools_thinking_messages=self.ignore_tools_thinking_messages, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AnthropicVertexChatGenerator": + """ + Deserialize this component from a dictionary. + + :param data: The dictionary representation of this component. + :returns: + The deserialized component instance. + """ + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + return default_from_dict(cls, data) diff --git a/integrations/anthropic/tests/test_vertex_chat_generator.py b/integrations/anthropic/tests/test_vertex_chat_generator.py new file mode 100644 index 000000000..a67e801ad --- /dev/null +++ b/integrations/anthropic/tests/test_vertex_chat_generator.py @@ -0,0 +1,197 @@ +import os + +import anthropic +import pytest +from haystack.components.generators.utils import print_streaming_chunk +from haystack.dataclasses import ChatMessage, ChatRole + +from haystack_integrations.components.generators.anthropic import AnthropicVertexChatGenerator + + +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."), + ChatMessage.from_user("What's the capital of France?"), + ] + + +class TestAnthropicVertexChatGenerator: + def test_init_default(self): + component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id") + assert component.region == "us-central1" + assert component.project_id == "test-project-id" + assert component.model == 
"claude-3-5-sonnet@20240620" + assert component.streaming_callback is None + assert not component.generation_kwargs + assert component.ignore_tools_thinking_messages + + def test_init_with_parameters(self): + component = AnthropicVertexChatGenerator( + region="us-central1", + project_id="test-project-id", + model="claude-3-5-sonnet@20240620", + streaming_callback=print_streaming_chunk, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ignore_tools_thinking_messages=False, + ) + assert component.region == "us-central1" + assert component.project_id == "test-project-id" + assert component.model == "claude-3-5-sonnet@20240620" + assert component.streaming_callback is print_streaming_chunk + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.ignore_tools_thinking_messages is False + + def test_to_dict_default(self): + component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id") + data = component.to_dict() + assert data == { + "type": ( + "haystack_integrations.components.generators." + "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": None, + "generation_kwargs": {}, + "ignore_tools_thinking_messages": True, + }, + } + + def test_to_dict_with_parameters(self): + component = AnthropicVertexChatGenerator( + region="us-central1", + project_id="test-project-id", + streaming_callback=print_streaming_chunk, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + data = component.to_dict() + assert data == { + "type": ( + "haystack_integrations.components.generators." 
+ "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, + }, + } + + def test_to_dict_with_lambda_streaming_callback(self): + component = AnthropicVertexChatGenerator( + region="us-central1", + project_id="test-project-id", + model="claude-3-5-sonnet@20240620", + streaming_callback=lambda x: x, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + data = component.to_dict() + assert data == { + "type": ( + "haystack_integrations.components.generators." + "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": "tests.test_vertex_chat_generator.", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, + }, + } + + def test_from_dict(self): + data = { + "type": ( + "haystack_integrations.components.generators." 
+ "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, + }, + } + component = AnthropicVertexChatGenerator.from_dict(data) + assert component.model == "claude-3-5-sonnet@20240620" + assert component.region == "us-central1" + assert component.project_id == "test-project-id" + assert component.streaming_callback is print_streaming_chunk + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + + def test_run(self, chat_messages, mock_chat_completion): + component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id") + response = component.run(chat_messages) + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + def test_run_with_params(self, chat_messages, mock_chat_completion): + component = AnthropicVertexChatGenerator( + region="us-central1", project_id="test-project-id", generation_kwargs={"max_tokens": 10, "temperature": 0.5} + ) + response = component.run(chat_messages) + + # check that the component calls the Anthropic API with the correct parameters + _, kwargs = mock_chat_completion.call_args + assert kwargs["max_tokens"] == 10 + assert kwargs["temperature"] == 0.5 + + # check that the component returns the correct response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, 
ChatMessage) for reply in response["replies"]] + + @pytest.mark.skipif( + not (os.environ.get("REGION", None) or os.environ.get("PROJECT_ID", None)), + reason="Authenticate with GCP and set env variables REGION and PROJECT_ID to run this test.", + ) + @pytest.mark.integration + def test_live_run_wrong_model(self, chat_messages): + component = AnthropicVertexChatGenerator( + model="something-obviously-wrong", region=os.environ.get("REGION"), project_id=os.environ.get("PROJECT_ID") + ) + with pytest.raises(anthropic.NotFoundError): + component.run(chat_messages) + + @pytest.mark.skipif( + not (os.environ.get("REGION", None) or os.environ.get("PROJECT_ID", None)), + reason="Authenticate with GCP and set env variables REGION and PROJECT_ID to run this test.", + ) + @pytest.mark.integration + def test_default_inference_params(self, chat_messages): + client = AnthropicVertexChatGenerator( + region=os.environ.get("REGION"), project_id=os.environ.get("PROJECT_ID"), model="claude-3-sonnet@20240229" + ) + response = client.run(chat_messages) + + assert "replies" in response, "Response does not contain 'replies' key" + replies = response["replies"] + assert isinstance(replies, list), "Replies is not a list" + assert len(replies) > 0, "No replies received" + + first_reply = replies[0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert first_reply.content, "First reply has no content" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" + assert first_reply.meta, "First reply has no metadata" + + # Anthropic messages API is similar for AnthropicVertex and Anthropic endpoint, + # remaining tests are skipped for AnthropicVertexChatGenerator as they are already tested in AnthropicChatGenerator. 
From 616063963a916f5f41c72ec02eed4b2d3939eb9a Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Fri, 15 Nov 2024 15:23:30 +0100 Subject: [PATCH 15/16] Enable kwargs in SearchIndex and Embedding Retriever (#1185) * Enable kwargs for semantic ranking --- .../azure_ai_search/example/document_store.py | 3 +- .../example/embedding_retrieval.py | 5 +- .../azure_ai_search/embedding_retriever.py | 41 +++++++++------ .../azure_ai_search/__init__.py | 4 +- .../azure_ai_search/document_store.py | 50 +++++++++++-------- .../azure_ai_search/filters.py | 2 +- .../azure_ai_search/tests/conftest.py | 22 ++++++-- 7 files changed, 78 insertions(+), 49 deletions(-) diff --git a/integrations/azure_ai_search/example/document_store.py b/integrations/azure_ai_search/example/document_store.py index 779f28935..92a641717 100644 --- a/integrations/azure_ai_search/example/document_store.py +++ b/integrations/azure_ai_search/example/document_store.py @@ -1,5 +1,4 @@ from haystack import Document -from haystack.document_stores.types import DuplicatePolicy from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore @@ -30,7 +29,7 @@ meta={"version": 2.0, "label": "chapter_three"}, ), ] -document_store.write_documents(documents, policy=DuplicatePolicy.SKIP) +document_store.write_documents(documents) filters = { "operator": "AND", diff --git a/integrations/azure_ai_search/example/embedding_retrieval.py b/integrations/azure_ai_search/example/embedding_retrieval.py index 088b08653..188f8525a 100644 --- a/integrations/azure_ai_search/example/embedding_retrieval.py +++ b/integrations/azure_ai_search/example/embedding_retrieval.py @@ -1,7 +1,6 @@ from haystack import Document, Pipeline from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder from haystack.components.writers import DocumentWriter -from haystack.document_stores.types import DuplicatePolicy from 
haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchEmbeddingRetriever from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore @@ -38,9 +37,7 @@ # Indexing Pipeline indexing_pipeline = Pipeline() indexing_pipeline.add_component(instance=document_embedder, name="doc_embedder") -indexing_pipeline.add_component( - instance=DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP), name="doc_writer" -) +indexing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="doc_writer") indexing_pipeline.connect("doc_embedder", "doc_writer") indexing_pipeline.run({"doc_embedder": {"documents": documents}}) diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py index ab649f874..af48b74fb 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py @@ -5,7 +5,7 @@ from haystack.document_stores.types import FilterPolicy from haystack.document_stores.types.filter_policy import apply_filter_policy -from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters logger = logging.getLogger(__name__) @@ -25,16 +25,23 @@ def __init__( filters: Optional[Dict[str, Any]] = None, top_k: int = 10, filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + **kwargs, ): """ Create the AzureAISearchEmbeddingRetriever component. :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever. 
:param filters: Filters applied when fetching documents from the Document Store. - Filters are applied during the approximate kNN search to ensure the Retriever returns - `top_k` matching documents. :param top_k: Maximum number of documents to return. - :filter_policy: Policy to determine how filters are applied. Possible options: + :param filter_policy: Policy to determine how filters are applied. + :param kwargs: Additional keyword arguments to pass to the Azure AI's search endpoint. + Some of the supported parameters: + - `query_type`: A string indicating the type of query to perform. Possible values are + 'simple','full' and 'semantic'. + - `semantic_configuration_name`: The name of semantic configuration to be used when + processing semantic queries. + For more information on parameters, see the + [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). """ self._filters = filters or {} @@ -43,6 +50,7 @@ def __init__( self._filter_policy = ( filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) ) + self._kwargs = kwargs if not isinstance(document_store, AzureAISearchDocumentStore): message = "document_store must be an instance of AzureAISearchDocumentStore" @@ -61,6 +69,7 @@ def to_dict(self) -> Dict[str, Any]: top_k=self._top_k, document_store=self._document_store.to_dict(), filter_policy=self._filter_policy.value, + **self._kwargs, ) @classmethod @@ -88,29 +97,31 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchEmbeddingRetriever": def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): """Retrieve documents from the AzureAISearchDocumentStore. - :param query_embedding: floats representing the query embedding + :param query_embedding: A list of floats representing the query embedding. :param filters: Filters applied to the retrieved Documents. 
The way runtime filters are applied depends on - the `filter_policy` chosen at retriever initialization. See init method docstring for more - details. - :param top_k: the maximum number of documents to retrieve. - :returns: a dictionary with the following keys: - - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. + the `filter_policy` chosen at retriever initialization. See `__init__` method docstring for more + details. + :param top_k: The maximum number of documents to retrieve. + :returns: Dictionary with the following keys: + - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. """ top_k = top_k or self._top_k if filters is not None: applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) - normalized_filters = normalize_filters(applied_filters) + normalized_filters = _normalize_filters(applied_filters) else: normalized_filters = "" try: docs = self._document_store._embedding_retrieval( - query_embedding=query_embedding, - filters=normalized_filters, - top_k=top_k, + query_embedding=query_embedding, filters=normalized_filters, top_k=top_k, **self._kwargs ) except Exception as e: - raise e + msg = ( + "An error occurred during the embedding retrieval process from the AzureAISearchDocumentStore. " + "Ensure that the query embedding is valid and the document store is correctly configured." 
+ ) + raise RuntimeError(msg) from e return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py index 635878a38..ca0ea7554 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py @@ -2,6 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 from .document_store import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore -from .filters import normalize_filters +from .filters import _normalize_filters -__all__ = ["AzureAISearchDocumentStore", "DEFAULT_VECTOR_SEARCH", "normalize_filters"] +__all__ = ["AzureAISearchDocumentStore", "DEFAULT_VECTOR_SEARCH", "_normalize_filters"] diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 68a2db22e..00682baba 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -31,7 +31,7 @@ from haystack.utils import Secret, deserialize_secrets_inplace from .errors import AzureAISearchDocumentStoreConfigError -from .filters import normalize_filters +from .filters import _normalize_filters type_mapping = { str: "Edm.String", @@ -70,7 +70,7 @@ def __init__( embedding_dimension: int = 768, metadata_fields: Optional[Dict[str, type]] = None, vector_search_configuration: VectorSearch = None, - **kwargs, + **index_creation_kwargs, ): """ A document store using [Azure AI Search](https://azure.microsoft.com/products/ai-services/ai-search/) @@ -87,19 +87,20 @@ def __init__( 
:param vector_search_configuration: Configuration option related to vector search. Default configuration uses the HNSW algorithm with cosine similarity to handle vector searches. - :param kwargs: Optional keyword parameters for Azure AI Search. - Some of the supported parameters: - - `api_version`: The Search API version to use for requests. - - `audience`: sets the Audience to use for authentication with Azure Active Directory (AAD). - The audience is not considered when using a shared key. If audience is not provided, - the public cloud audience will be assumed. + :param index_creation_kwargs: Optional keyword parameters to be passed to `SearchIndex` class + during index creation. Some of the supported parameters: + - `semantic_search`: Defines semantic configuration of the search index. This parameter is needed + to enable semantic search capabilities in index. + - `similarity`: The type of similarity algorithm to be used when scoring and ranking the documents + matching a search query. The similarity algorithm can only be defined at index creation time and + cannot be modified on existing indexes. - For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/) + For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). """ azure_endpoint = azure_endpoint or os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT") or None if not azure_endpoint: - msg = "Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT." + msg = "Please provide an Azure endpoint or set the environment variable AZURE_SEARCH_SERVICE_ENDPOINT." 
raise ValueError(msg) api_key = api_key or os.environ.get("AZURE_SEARCH_API_KEY") or None @@ -114,7 +115,7 @@ def __init__( self._dummy_vector = [-10.0] * self._embedding_dimension self._metadata_fields = metadata_fields self._vector_search_configuration = vector_search_configuration or DEFAULT_VECTOR_SEARCH - self._kwargs = kwargs + self._index_creation_kwargs = index_creation_kwargs @property def client(self) -> SearchClient: @@ -128,7 +129,10 @@ def client(self) -> SearchClient: credential = AzureKeyCredential(resolved_key) if resolved_key else DefaultAzureCredential() try: if not self._index_client: - self._index_client = SearchIndexClient(resolved_endpoint, credential, **self._kwargs) + self._index_client = SearchIndexClient( + resolved_endpoint, + credential, + ) if not self._index_exists(self._index_name): # Create a new index if it does not exist logger.debug( @@ -151,7 +155,7 @@ def client(self) -> SearchClient: return self._client - def _create_index(self, index_name: str, **kwargs) -> None: + def _create_index(self, index_name: str) -> None: """ Creates a new search index. :param index_name: Name of the index to create. If None, the index name from the constructor is used. 
@@ -177,7 +181,10 @@ def _create_index(self, index_name: str, **kwargs) -> None: if self._metadata_fields: default_fields.extend(self._create_metadata_index_fields(self._metadata_fields)) index = SearchIndex( - name=index_name, fields=default_fields, vector_search=self._vector_search_configuration, **kwargs + name=index_name, + fields=default_fields, + vector_search=self._vector_search_configuration, + **self._index_creation_kwargs, ) if self._index_client: self._index_client.create_index(index) @@ -194,13 +201,13 @@ def to_dict(self) -> Dict[str, Any]: """ return default_to_dict( self, - azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint is not None else None, - api_key=self._api_key.to_dict() if self._api_key is not None else None, + azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint else None, + api_key=self._api_key.to_dict() if self._api_key else None, index_name=self._index_name, embedding_dimension=self._embedding_dimension, metadata_fields=self._metadata_fields, vector_search_configuration=self._vector_search_configuration.as_dict(), - **self._kwargs, + **self._index_creation_kwargs, ) @classmethod @@ -298,7 +305,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc :returns: A list of Documents that match the given filters. """ if filters: - normalized_filters = normalize_filters(filters) + normalized_filters = _normalize_filters(filters) result = self.client.search(filter=normalized_filters) return self._convert_search_result_to_documents(result) else: @@ -409,8 +416,8 @@ def _embedding_retrieval( query_embedding: List[float], *, top_k: int = 10, - fields: Optional[List[str]] = None, filters: Optional[Dict[str, Any]] = None, + **kwargs, ) -> List[Document]: """ Retrieves documents that are most similar to the query embedding using a vector similarity metric. 
@@ -422,9 +429,10 @@ def _embedding_retrieval( `AzureAISearchEmbeddingRetriever` uses this method directly and is the public interface for it. :param query_embedding: Embedding of the query. + :param top_k: Maximum number of Documents to return, defaults to 10. :param filters: Filters applied to the retrieved Documents. Defaults to None. Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. - :param top_k: Maximum number of Documents to return, defaults to 10 + :param kwargs: Optional keyword arguments to pass to the Azure AI's search endpoint. :raises ValueError: If `query_embedding` is an empty list :returns: List of Document that are most similar to `query_embedding` @@ -435,7 +443,7 @@ def _embedding_retrieval( raise ValueError(msg) vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding") - result = self.client.search(search_text=None, vector_queries=[vector_query], select=fields, filter=filters) + result = self.client.search(vector_queries=[vector_query], filter=filters, **kwargs) azure_docs = list(result) return self._convert_search_result_to_documents(azure_docs) diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py index 650e3f8be..0f105bc91 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py @@ -7,7 +7,7 @@ LOGICAL_OPERATORS = {"AND": "and", "OR": "or", "NOT": "not"} -def normalize_filters(filters: Dict[str, Any]) -> str: +def _normalize_filters(filters: Dict[str, Any]) -> str: """ Converts Haystack filters in Azure AI Search compatible filters. 
""" diff --git a/integrations/azure_ai_search/tests/conftest.py b/integrations/azure_ai_search/tests/conftest.py index 3017c79c2..89369c87e 100644 --- a/integrations/azure_ai_search/tests/conftest.py +++ b/integrations/azure_ai_search/tests/conftest.py @@ -6,12 +6,14 @@ from azure.core.credentials import AzureKeyCredential from azure.core.exceptions import ResourceNotFoundError from azure.search.documents.indexes import SearchIndexClient +from haystack import logging from haystack.document_stores.types import DuplicatePolicy from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore # This is the approximate time in seconds it takes for the documents to be available in Azure Search index -SLEEP_TIME_IN_SECONDS = 5 +SLEEP_TIME_IN_SECONDS = 10 +MAX_WAIT_TIME_FOR_INDEX_DELETION = 5 @pytest.fixture() @@ -46,23 +48,35 @@ def document_store(request): # Override some methods to wait for the documents to be available original_write_documents = store.write_documents + original_delete_documents = store.delete_documents def write_documents_and_wait(documents, policy=DuplicatePolicy.OVERWRITE): written_docs = original_write_documents(documents, policy) time.sleep(SLEEP_TIME_IN_SECONDS) return written_docs - original_delete_documents = store.delete_documents - def delete_documents_and_wait(filters): original_delete_documents(filters) time.sleep(SLEEP_TIME_IN_SECONDS) + # Helper function to wait for the index to be deleted, needed to cover latency + def wait_for_index_deletion(client, index_name): + start_time = time.time() + while time.time() - start_time < MAX_WAIT_TIME_FOR_INDEX_DELETION: + if index_name not in client.list_index_names(): + return True + time.sleep(1) + return False + store.write_documents = write_documents_and_wait store.delete_documents = delete_documents_and_wait yield store try: client.delete_index(index_name) + if not wait_for_index_deletion(client, index_name): + logging.error(f"Index {index_name} was not properly 
deleted.") except ResourceNotFoundError: - pass + logging.info(f"Index {index_name} was already deleted or not found.") + except Exception as e: + logging.error(f"Unexpected error when deleting index {index_name}: {e}") From 4bfebacce25bf5e214c98af4bb8a84c7eeec0f23 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Fri, 15 Nov 2024 16:45:00 +0100 Subject: [PATCH 16/16] Enable kwargs for semantic ranking in retrievers --- .../azure_ai_search/bm25_retriever.py | 26 ++++++++-- .../azure_ai_search/hybrid_retriever.py | 37 +++++++++----- .../azure_ai_search/document_store.py | 48 ++++++++++++------- .../tests/test_bm25_retriever.py | 7 +-- 4 files changed, 80 insertions(+), 38 deletions(-) diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py index 65e273b73..476144545 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py @@ -5,7 +5,7 @@ from haystack.document_stores.types import FilterPolicy from haystack.document_stores.types.filter_policy import apply_filter_policy -from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters logger = logging.getLogger(__name__) @@ -25,6 +25,7 @@ def __init__( filters: Optional[Dict[str, Any]] = None, top_k: int = 10, filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + **kwargs, ): """ Create the AzureAISearchBM25Retriever component. @@ -34,7 +35,16 @@ def __init__( Filters are applied during the BM25 search to ensure the Retriever returns `top_k` matching documents. 
:param top_k: Maximum number of documents to return. - :filter_policy: Policy to determine how filters are applied. Possible options: + :filter_policy: Policy to determine how filters are applied. + :param kwargs: Additional keyword arguments to pass to the Azure AI's search endpoint. + Some of the supported parameters: + - `query_type`: A string indicating the type of query to perform. Possible values are + 'simple','full' and 'semantic'. + - `semantic_configuration_name`: The name of semantic configuration to be used when + processing semantic queries. + For more information on parameters, see the + [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). + """ self._filters = filters or {} @@ -43,7 +53,7 @@ def __init__( self._filter_policy = ( filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) ) - + self._kwargs = kwargs if not isinstance(document_store, AzureAISearchDocumentStore): message = "document_store must be an instance of AzureAISearchDocumentStore" raise Exception(message) @@ -61,6 +71,7 @@ def to_dict(self) -> Dict[str, Any]: top_k=self._top_k, document_store=self._document_store.to_dict(), filter_policy=self._filter_policy.value, + **self._kwargs, ) @classmethod @@ -100,7 +111,7 @@ def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optio top_k = top_k or self._top_k if filters is not None: applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) - normalized_filters = normalize_filters(applied_filters) + normalized_filters = _normalize_filters(applied_filters) else: normalized_filters = "" @@ -109,8 +120,13 @@ def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optio query=query, filters=normalized_filters, top_k=top_k, + **self._kwargs, ) except Exception as e: - raise e + msg = ( + "An error occurred during the bm25 retrieval process from the AzureAISearchDocumentStore. 
" + "Ensure that the query is valid and the document store is correctly configured." + ) + raise RuntimeError(msg) from e return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py index fbe7752aa..ce0a17e2e 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py @@ -5,7 +5,7 @@ from haystack.document_stores.types import FilterPolicy from haystack.document_stores.types.filter_policy import apply_filter_policy -from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters logger = logging.getLogger(__name__) @@ -25,6 +25,7 @@ def __init__( filters: Optional[Dict[str, Any]] = None, top_k: int = 10, filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + **kwargs, ): """ Create the AzureAISearchHybridRetriever component. @@ -34,7 +35,16 @@ def __init__( Filters are applied during the hybrid search to ensure the Retriever returns `top_k` matching documents. :param top_k: Maximum number of documents to return. - :filter_policy: Policy to determine how filters are applied. Possible options: + :filter_policy: Policy to determine how filters are applied. + :param kwargs: Additional keyword arguments to pass to the Azure AI's search endpoint. + Some of the supported parameters: + - `query_type`: A string indicating the type of query to perform. Possible values are + 'simple','full' and 'semantic'. + - `semantic_configuration_name`: The name of semantic configuration to be used when + processing semantic queries. 
+ For more information on parameters, see the + [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). + """ self._filters = filters or {} @@ -43,6 +53,7 @@ def __init__( self._filter_policy = ( filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) ) + self._kwargs = kwargs if not isinstance(document_store, AzureAISearchDocumentStore): message = "document_store must be an instance of AzureAISearchDocumentStore" @@ -61,6 +72,7 @@ def to_dict(self) -> Dict[str, Any]: top_k=self._top_k, document_store=self._document_store.to_dict(), filter_policy=self._filter_policy.value, + **self._kwargs, ) @classmethod @@ -95,30 +107,31 @@ def run( """Retrieve documents from the AzureAISearchDocumentStore. :param query: Text of the query. - :param query_embedding: floats representing the query embedding + :param query_embedding: A list of floats representing the query embedding :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on - the `filter_policy` chosen at retriever initialization. See init method docstring for more + the `filter_policy` chosen at retriever initialization. See `__init__` method docstring for more details. - :param top_k: the maximum number of documents to retrieve. - :returns: a dictionary with the following keys: + :param top_k: The maximum number of documents to retrieve. + :returns: A dictionary with the following keys: - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. 
""" top_k = top_k or self._top_k if filters is not None: applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) - normalized_filters = normalize_filters(applied_filters) + normalized_filters = _normalize_filters(applied_filters) else: normalized_filters = "" try: docs = self._document_store._hybrid_retrieval( - query=query, - query_embedding=query_embedding, - filters=normalized_filters, - top_k=top_k, + query=query, query_embedding=query_embedding, filters=normalized_filters, top_k=top_k, **self._kwargs ) except Exception as e: - raise e + msg = ( + "An error occurred during the hybrid retrieval process from the AzureAISearchDocumentStore. " + "Ensure that the query and query_embedding are valid and the document store is correctly configured." + ) + raise RuntimeError(msg) from e return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 00682baba..cf0657495 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -421,7 +421,7 @@ def _embedding_retrieval( ) -> List[Document]: """ Retrieves documents that are most similar to the query embedding using a vector similarity metric. - It uses the vector configuration of the document store. By default it uses the HNSW algorithm + It uses the vector configuration specified in the document store. By default, it uses the HNSW algorithm with cosine similarity. This method is not meant to be part of the public interface of @@ -429,13 +429,12 @@ def _embedding_retrieval( `AzureAISearchEmbeddingRetriever` uses this method directly and is the public interface for it. :param query_embedding: Embedding of the query. 
- :param top_k: Maximum number of Documents to return, defaults to 10. - :param filters: Filters applied to the retrieved Documents. Defaults to None. - Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. + :param top_k: Maximum number of Documents to return. + :param filters: Filters applied to the retrieved Documents. :param kwargs: Optional keyword arguments to pass to the Azure AI's search endpoint. - :raises ValueError: If `query_embedding` is an empty list - :returns: List of Document that are most similar to `query_embedding` + :raises ValueError: If `query_embedding` is an empty list. + :returns: List of Document that are most similar to `query_embedding`. """ if not query_embedding: @@ -451,30 +450,31 @@ def _bm25_retrieval( self, query: str, top_k: int = 10, - fields: Optional[List[str]] = None, filters: Optional[Dict[str, Any]] = None, + **kwargs, ) -> List[Document]: """ - Retrieves documents that are most similar to `query`, using the BM25 algorithm + Retrieves documents that are most similar to `query`, using the BM25 algorithm. This method is not meant to be part of the public interface of `AzureAISearchDocumentStore` nor called directly. `AzureAISearchBM25Retriever` uses this method directly and is the public interface for it. :param query: Text of the query. - :param filters: Filters applied to the retrieved Documents. Defaults to None. - Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. - :param top_k: Maximum number of Documents to return, defaults to 10 + :param filters: Filters applied to the retrieved Documents. + :param top_k: Maximum number of Documents to return. + :param kwargs: Optional keyword arguments to pass to the Azure AI's search endpoint. - :raises ValueError: If `query` is an empty string - :returns: List of Document that are most similar to `query` + + :raises ValueError: If `query` is an empty string. 
+ :returns: List of Document that are most similar to `query`. """ if query is None: msg = "query must not be None" raise ValueError(msg) - result = self.client.search(search_text=query, select=fields, filter=filters, top=top_k, query_type="simple") + result = self.client.search(search_text=query, filter=filters, top=top_k, query_type="simple", **kwargs) azure_docs = list(result) return self._convert_search_result_to_documents(azure_docs) @@ -483,9 +483,25 @@ def _hybrid_retrieval( query: str, query_embedding: List[float], top_k: int = 10, - fields: Optional[List[str]] = None, filters: Optional[Dict[str, Any]] = None, + **kwargs, ) -> List[Document]: + """ + Retrieves documents similar to query using the vector configuration in the document store and + the BM25 algorithm. This method combines vector similarity and BM25 for improved retrieval. + + This method is not meant to be part of the public interface of + `AzureAISearchDocumentStore` nor called directly. + `AzureAISearchHybridRetriever` uses this method directly and is the public interface for it. + + :param query: Text of the query. + :param filters: Filters applied to the retrieved Documents. + :param top_k: Maximum number of Documents to return. + :param kwargs: Optional keyword arguments to pass to the Azure AI's search endpoint. + + :raises ValueError: If `query` or `query_embedding` is empty. + :returns: List of Document that are most similar to `query`. 
+ """ if query is None: msg = "query must not be None" @@ -498,10 +514,10 @@ def _hybrid_retrieval( result = self.client.search( search_text=query, vector_queries=[vector_query], - select=fields, filter=filters, top=top_k, query_type="simple", + **kwargs, ) azure_docs = list(result) return self._convert_search_result_to_documents(azure_docs) diff --git a/integrations/azure_ai_search/tests/test_bm25_retriever.py b/integrations/azure_ai_search/tests/test_bm25_retriever.py index d0c6d0da9..e6631a16b 100644 --- a/integrations/azure_ai_search/tests/test_bm25_retriever.py +++ b/integrations/azure_ai_search/tests/test_bm25_retriever.py @@ -2,14 +2,11 @@ # # SPDX-License-Identifier: Apache-2.0 import os -from typing import List from unittest.mock import Mock import pytest -from azure.core.exceptions import HttpResponseError from haystack.dataclasses import Document from haystack.document_stores.types import FilterPolicy -from numpy.random import rand # type: ignore from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchBM25Retriever from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore @@ -35,7 +32,7 @@ def test_to_dict(): retriever = AzureAISearchBM25Retriever(document_store=document_store) res = retriever.to_dict() assert res == { - "type": "haystack_integrations.components.retrievers.azure_ai_search.bm25_retriever.AzureAISearchBM25Retriever", # noqa: E501 + "type": "haystack_integrations.components.retrievers.azure_ai_search.bm25_retriever.AzureAISearchBM25Retriever", "init_parameters": { "filters": {}, "top_k": 10, @@ -73,7 +70,7 @@ def test_to_dict(): def test_from_dict(): data = { - "type": "haystack_integrations.components.retrievers.azure_ai_search.bm25_retriever.AzureAISearchBM25Retriever", # noqa: E501 + "type": "haystack_integrations.components.retrievers.azure_ai_search.bm25_retriever.AzureAISearchBM25Retriever", "init_parameters": { "filters": {}, "top_k": 10,