From 67e08d0b7e5a7f51f52bb0d40fe40b0ff2caf43a Mon Sep 17 00:00:00 2001
From: Amna Mubashar
Date: Fri, 15 Nov 2024 15:23:30 +0100
Subject: [PATCH] Enable kwargs in SearchIndex and Embedding Retriever (#1185)

* Enable kwargs for semantic ranking

---
 .../azure_ai_search/example/document_store.py |  3 +-
 .../example/embedding_retrieval.py            |  5 +-
 .../azure_ai_search/embedding_retriever.py    | 41 +++++++++------
 .../azure_ai_search/__init__.py               |  4 +-
 .../azure_ai_search/document_store.py         | 50 +++++++++++--------
 .../azure_ai_search/filters.py                |  2 +-
 .../azure_ai_search/tests/conftest.py         | 22 ++++++--
 7 files changed, 78 insertions(+), 49 deletions(-)

diff --git a/integrations/azure_ai_search/example/document_store.py b/integrations/azure_ai_search/example/document_store.py
index 779f28935..92a641717 100644
--- a/integrations/azure_ai_search/example/document_store.py
+++ b/integrations/azure_ai_search/example/document_store.py
@@ -1,5 +1,4 @@
 from haystack import Document
-from haystack.document_stores.types import DuplicatePolicy

 from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore

@@ -30,7 +29,7 @@
         meta={"version": 2.0, "label": "chapter_three"},
     ),
 ]
-document_store.write_documents(documents, policy=DuplicatePolicy.SKIP)
+document_store.write_documents(documents)

 filters = {
     "operator": "AND",
diff --git a/integrations/azure_ai_search/example/embedding_retrieval.py b/integrations/azure_ai_search/example/embedding_retrieval.py
index 088b08653..188f8525a 100644
--- a/integrations/azure_ai_search/example/embedding_retrieval.py
+++ b/integrations/azure_ai_search/example/embedding_retrieval.py
@@ -1,7 +1,6 @@
 from haystack import Document, Pipeline
 from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
 from haystack.components.writers import DocumentWriter
-from haystack.document_stores.types import DuplicatePolicy

 from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchEmbeddingRetriever
 from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore

@@ -38,9 +37,7 @@
 # Indexing Pipeline
 indexing_pipeline = Pipeline()
 indexing_pipeline.add_component(instance=document_embedder, name="doc_embedder")
-indexing_pipeline.add_component(
-    instance=DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP), name="doc_writer"
-)
+indexing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="doc_writer")
 indexing_pipeline.connect("doc_embedder", "doc_writer")

 indexing_pipeline.run({"doc_embedder": {"documents": documents}})
diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py
index ab649f874..af48b74fb 100644
--- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py
+++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py
@@ -5,7 +5,7 @@
 from haystack.document_stores.types import FilterPolicy
 from haystack.document_stores.types.filter_policy import apply_filter_policy

-from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters
+from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters
 logger = logging.getLogger(__name__)

@@ -25,16 +25,23 @@ def __init__(
         filters: Optional[Dict[str, Any]] = None,
         top_k: int = 10,
         filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
+        **kwargs,
     ):
         """
         Create the AzureAISearchEmbeddingRetriever component.

         :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever.
         :param filters: Filters applied when fetching documents from the Document Store.
-            Filters are applied during the approximate kNN search to ensure the Retriever returns
-            `top_k` matching documents.
         :param top_k: Maximum number of documents to return.
-        :filter_policy: Policy to determine how filters are applied. Possible options:
+        :param filter_policy: Policy to determine how filters are applied.
+        :param kwargs: Additional keyword arguments to pass to the Azure AI Search endpoint.
+            Some of the supported parameters:
+            - `query_type`: A string indicating the type of query to perform. Possible values are
+              'simple', 'full', and 'semantic'.
+            - `semantic_configuration_name`: The name of the semantic configuration to be used when
+              processing semantic queries.
+            For more information on parameters, see the
+            [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/).
         """

         self._filters = filters or {}
@@ -43,6 +50,7 @@ def __init__(
         self._filter_policy = (
             filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
         )
+        self._kwargs = kwargs

         if not isinstance(document_store, AzureAISearchDocumentStore):
             message = "document_store must be an instance of AzureAISearchDocumentStore"
@@ -61,6 +69,7 @@ def to_dict(self) -> Dict[str, Any]:
             top_k=self._top_k,
             document_store=self._document_store.to_dict(),
             filter_policy=self._filter_policy.value,
+            **self._kwargs,
         )

     @classmethod
@@ -88,29 +97,31 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchEmbeddingRetriever":
     def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
         """Retrieve documents from the AzureAISearchDocumentStore.

-        :param query_embedding: floats representing the query embedding
+        :param query_embedding: A list of floats representing the query embedding.
         :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on
-                        the `filter_policy` chosen at retriever initialization. See init method docstring for more
-                        details.
-        :param top_k: the maximum number of documents to retrieve.
-        :returns: a dictionary with the following keys:
-            - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore.
+            the `filter_policy` chosen at retriever initialization. See `__init__` method docstring for more
+            details.
+        :param top_k: The maximum number of documents to retrieve.
+        :returns: Dictionary with the following keys:
+            - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore.
""" top_k = top_k or self._top_k if filters is not None: applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) - normalized_filters = normalize_filters(applied_filters) + normalized_filters = _normalize_filters(applied_filters) else: normalized_filters = "" try: docs = self._document_store._embedding_retrieval( - query_embedding=query_embedding, - filters=normalized_filters, - top_k=top_k, + query_embedding=query_embedding, filters=normalized_filters, top_k=top_k, **self._kwargs ) except Exception as e: - raise e + msg = ( + "An error occurred during the embedding retrieval process from the AzureAISearchDocumentStore. " + "Ensure that the query embedding is valid and the document store is correctly configured." + ) + raise RuntimeError(msg) from e return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py index 635878a38..ca0ea7554 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py @@ -2,6 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 from .document_store import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore -from .filters import normalize_filters +from .filters import _normalize_filters -__all__ = ["AzureAISearchDocumentStore", "DEFAULT_VECTOR_SEARCH", "normalize_filters"] +__all__ = ["AzureAISearchDocumentStore", "DEFAULT_VECTOR_SEARCH", "_normalize_filters"] diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 0b59b6e37..74260b4fa 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -31,7 +31,7 @@ from haystack.utils import Secret, deserialize_secrets_inplace from .errors import AzureAISearchDocumentStoreConfigError -from .filters import normalize_filters +from .filters import _normalize_filters type_mapping = { str: "Edm.String", @@ -70,7 +70,7 @@ def __init__( embedding_dimension: int = 768, metadata_fields: Optional[Dict[str, type]] = None, vector_search_configuration: VectorSearch = None, - **kwargs, + **index_creation_kwargs, ): """ A document store using [Azure AI Search](https://azure.microsoft.com/products/ai-services/ai-search/) @@ -87,19 +87,20 @@ def __init__( :param vector_search_configuration: Configuration option related to vector search. Default configuration uses the HNSW algorithm with cosine similarity to handle vector searches. - :param kwargs: Optional keyword parameters for Azure AI Search. - Some of the supported parameters: - - `api_version`: The Search API version to use for requests. - - `audience`: sets the Audience to use for authentication with Azure Active Directory (AAD). - The audience is not considered when using a shared key. If audience is not provided, - the public cloud audience will be assumed. + :param index_creation_kwargs: Optional keyword parameters to be passed to `SearchIndex` class + during index creation. Some of the supported parameters: + - `semantic_search`: Defines semantic configuration of the search index. 
+              to enable semantic search capabilities in the index.
+            - `similarity`: The type of similarity algorithm to be used when scoring and ranking the documents
+              matching a search query. The similarity algorithm can only be defined at index creation time and
+              cannot be modified on existing indexes.

-            For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/)
+            For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/).
         """

         azure_endpoint = azure_endpoint or os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT") or None
         if not azure_endpoint:
-            msg = "Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT."
+            msg = "Please provide an Azure endpoint or set the environment variable AZURE_SEARCH_SERVICE_ENDPOINT."
             raise ValueError(msg)

         api_key = api_key or os.environ.get("AZURE_SEARCH_API_KEY") or None
@@ -114,7 +115,7 @@ def __init__(
         self._dummy_vector = [-10.0] * self._embedding_dimension
         self._metadata_fields = metadata_fields
         self._vector_search_configuration = vector_search_configuration or DEFAULT_VECTOR_SEARCH
-        self._kwargs = kwargs
+        self._index_creation_kwargs = index_creation_kwargs

     @property
     def client(self) -> SearchClient:
@@ -128,7 +129,10 @@
         credential = AzureKeyCredential(resolved_key) if resolved_key else DefaultAzureCredential()
         try:
             if not self._index_client:
-                self._index_client = SearchIndexClient(resolved_endpoint, credential, **self._kwargs)
+                self._index_client = SearchIndexClient(
+                    resolved_endpoint,
+                    credential,
+                )
                 if not self._index_exists(self._index_name):
                     # Create a new index if it does not exist
                     logger.debug(
@@ -151,7 +155,7 @@

         return self._client

-    def _create_index(self, index_name: str, **kwargs) -> None:
+    def _create_index(self, index_name: str) -> None:
         """
         Creates a new search index.
         :param index_name: Name of the index to create. If None, the index name from the constructor is used.
@@ -177,7 +181,10 @@
         if self._metadata_fields:
             default_fields.extend(self._create_metadata_index_fields(self._metadata_fields))
         index = SearchIndex(
-            name=index_name, fields=default_fields, vector_search=self._vector_search_configuration, **kwargs
+            name=index_name,
+            fields=default_fields,
+            vector_search=self._vector_search_configuration,
+            **self._index_creation_kwargs,
         )
         if self._index_client:
             self._index_client.create_index(index)
@@ -194,13 +201,13 @@
         """
         return default_to_dict(
             self,
-            azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint is not None else None,
-            api_key=self._api_key.to_dict() if self._api_key is not None else None,
+            azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint else None,
+            api_key=self._api_key.to_dict() if self._api_key else None,
             index_name=self._index_name,
             embedding_dimension=self._embedding_dimension,
             metadata_fields=self._metadata_fields,
             vector_search_configuration=self._vector_search_configuration.as_dict(),
-            **self._kwargs,
+            **self._index_creation_kwargs,
         )

     @classmethod
@@ -298,7 +305,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
         :returns: A list of Documents that match the given filters.
""" if filters: - normalized_filters = normalize_filters(filters) + normalized_filters = _normalize_filters(filters) result = self.client.search(filter=normalized_filters) return self._convert_search_result_to_documents(result) else: @@ -409,8 +416,8 @@ def _embedding_retrieval( query_embedding: List[float], *, top_k: int = 10, - fields: Optional[List[str]] = None, filters: Optional[Dict[str, Any]] = None, + **kwargs, ) -> List[Document]: """ Retrieves documents that are most similar to the query embedding using a vector similarity metric. @@ -422,9 +429,10 @@ def _embedding_retrieval( `AzureAISearchEmbeddingRetriever` uses this method directly and is the public interface for it. :param query_embedding: Embedding of the query. + :param top_k: Maximum number of Documents to return, defaults to 10. :param filters: Filters applied to the retrieved Documents. Defaults to None. Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. - :param top_k: Maximum number of Documents to return, defaults to 10 + :param kwargs: Optional keyword arguments to pass to the Azure AI's search endpoint. :raises ValueError: If `query_embedding` is an empty list :returns: List of Document that are most similar to `query_embedding` @@ -435,6 +443,6 @@ def _embedding_retrieval( raise ValueError(msg) vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding") - result = self.client.search(search_text=None, vector_queries=[vector_query], select=fields, filter=filters) + result = self.client.search(vector_queries=[vector_query], filter=filters, **kwargs) azure_docs = list(result) return self._convert_search_result_to_documents(azure_docs) diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py index 650e3f8be..0f105bc91 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py @@ -7,7 +7,7 @@ LOGICAL_OPERATORS = {"AND": "and", "OR": "or", "NOT": "not"} -def normalize_filters(filters: Dict[str, Any]) -> str: +def _normalize_filters(filters: Dict[str, Any]) -> str: """ Converts Haystack filters in Azure AI Search compatible filters. 
""" diff --git a/integrations/azure_ai_search/tests/conftest.py b/integrations/azure_ai_search/tests/conftest.py index 3017c79c2..89369c87e 100644 --- a/integrations/azure_ai_search/tests/conftest.py +++ b/integrations/azure_ai_search/tests/conftest.py @@ -6,12 +6,14 @@ from azure.core.credentials import AzureKeyCredential from azure.core.exceptions import ResourceNotFoundError from azure.search.documents.indexes import SearchIndexClient +from haystack import logging from haystack.document_stores.types import DuplicatePolicy from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore # This is the approximate time in seconds it takes for the documents to be available in Azure Search index -SLEEP_TIME_IN_SECONDS = 5 +SLEEP_TIME_IN_SECONDS = 10 +MAX_WAIT_TIME_FOR_INDEX_DELETION = 5 @pytest.fixture() @@ -46,23 +48,35 @@ def document_store(request): # Override some methods to wait for the documents to be available original_write_documents = store.write_documents + original_delete_documents = store.delete_documents def write_documents_and_wait(documents, policy=DuplicatePolicy.OVERWRITE): written_docs = original_write_documents(documents, policy) time.sleep(SLEEP_TIME_IN_SECONDS) return written_docs - original_delete_documents = store.delete_documents - def delete_documents_and_wait(filters): original_delete_documents(filters) time.sleep(SLEEP_TIME_IN_SECONDS) + # Helper function to wait for the index to be deleted, needed to cover latency + def wait_for_index_deletion(client, index_name): + start_time = time.time() + while time.time() - start_time < MAX_WAIT_TIME_FOR_INDEX_DELETION: + if index_name not in client.list_index_names(): + return True + time.sleep(1) + return False + store.write_documents = write_documents_and_wait store.delete_documents = delete_documents_and_wait yield store try: client.delete_index(index_name) + if not wait_for_index_deletion(client, index_name): + logging.error(f"Index {index_name} was not properly deleted.") except ResourceNotFoundError: - pass + logging.info(f"Index {index_name} was already deleted or not found.") + except Exception as e: + logging.error(f"Unexpected error when deleting index {index_name}: {e}")