
Commit

Enable kwargs in SearchIndex and Embedding Retriever (#1185)
* Enable kwargs for semantic ranking
Amnah199 authored Nov 15, 2024
1 parent e21ce0c commit 67e08d0
Showing 7 changed files with 78 additions and 49 deletions.
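Taken together, the changes let keyword arguments flow from the document store into SearchIndex creation and from the retriever into the search endpoint, which is what enables semantic ranking. Below is a minimal sketch of the intended usage; it assumes a reachable Azure AI Search service (with AZURE_SEARCH_SERVICE_ENDPOINT and AZURE_SEARCH_API_KEY set), and the index and configuration names are hypothetical:

from azure.search.documents.indexes.models import (
    SemanticConfiguration,
    SemanticField,
    SemanticPrioritizedFields,
    SemanticSearch,
)

from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchEmbeddingRetriever
from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore

# Index-creation kwargs (e.g. `semantic_search`) are forwarded to the SearchIndex constructor.
semantic_search = SemanticSearch(
    configurations=[
        SemanticConfiguration(
            name="my-semantic-config",  # hypothetical configuration name
            prioritized_fields=SemanticPrioritizedFields(
                content_fields=[SemanticField(field_name="content")]
            ),
        )
    ]
)
document_store = AzureAISearchDocumentStore(index_name="haystack-docs", semantic_search=semantic_search)

# Retriever kwargs are stored and forwarded to the search endpoint on every query.
retriever = AzureAISearchEmbeddingRetriever(
    document_store=document_store,
    query_type="semantic",
    semantic_configuration_name="my-semantic-config",
)

The split follows from the docstrings below: the semantic configuration can only be defined at index creation time, so `semantic_search` goes to the store, while `query_type` and `semantic_configuration_name` are per-query options on the retriever.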
3 changes: 1 addition & 2 deletions integrations/azure_ai_search/example/document_store.py
@@ -1,5 +1,4 @@
from haystack import Document
- from haystack.document_stores.types import DuplicatePolicy

from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore

@@ -30,7 +29,7 @@
meta={"version": 2.0, "label": "chapter_three"},
),
]
- document_store.write_documents(documents, policy=DuplicatePolicy.SKIP)
+ document_store.write_documents(documents)

filters = {
"operator": "AND",
5 changes: 1 addition & 4 deletions integrations/azure_ai_search/example/embedding_retrieval.py
@@ -1,7 +1,6 @@
from haystack import Document, Pipeline
from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
from haystack.components.writers import DocumentWriter
- from haystack.document_stores.types import DuplicatePolicy

from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchEmbeddingRetriever
from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore
@@ -38,9 +37,7 @@
# Indexing Pipeline
indexing_pipeline = Pipeline()
indexing_pipeline.add_component(instance=document_embedder, name="doc_embedder")
- indexing_pipeline.add_component(
-     instance=DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP), name="doc_writer"
- )
+ indexing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="doc_writer")
indexing_pipeline.connect("doc_embedder", "doc_writer")

indexing_pipeline.run({"doc_embedder": {"documents": documents}})
Embedding retriever component (haystack_integrations.components.retrievers.azure_ai_search)
@@ -5,7 +5,7 @@
from haystack.document_stores.types import FilterPolicy
from haystack.document_stores.types.filter_policy import apply_filter_policy

- from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters
+ from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters

logger = logging.getLogger(__name__)

@@ -25,16 +25,23 @@ def __init__(
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
+ **kwargs,
):
"""
Create the AzureAISearchEmbeddingRetriever component.
:param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever.
:param filters: Filters applied when fetching documents from the Document Store.
Filters are applied during the approximate kNN search to ensure the Retriever returns
`top_k` matching documents.
:param top_k: Maximum number of documents to return.
- :filter_policy: Policy to determine how filters are applied. Possible options:
+ :param filter_policy: Policy to determine how filters are applied.
+ :param kwargs: Additional keyword arguments to pass to the Azure AI Search endpoint.
+     Some of the supported parameters:
+     - `query_type`: A string indicating the type of query to perform. Possible values are
+       'simple', 'full', and 'semantic'.
+     - `semantic_configuration_name`: The name of the semantic configuration to use when
+       processing semantic queries.
For more information on parameters, see the
[official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/).
"""
self._filters = filters or {}
@@ -43,6 +50,7 @@
self._filter_policy = (
filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
)
+ self._kwargs = kwargs

if not isinstance(document_store, AzureAISearchDocumentStore):
message = "document_store must be an instance of AzureAISearchDocumentStore"
@@ -61,6 +69,7 @@ def to_dict(self) -> Dict[str, Any]:
top_k=self._top_k,
document_store=self._document_store.to_dict(),
filter_policy=self._filter_policy.value,
+ **self._kwargs,
)

@classmethod
@@ -88,29 +97,31 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchEmbeddingRetriever":
def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
"""Retrieve documents from the AzureAISearchDocumentStore.
- :param query_embedding: floats representing the query embedding
+ :param query_embedding: A list of floats representing the query embedding.
:param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on
- the `filter_policy` chosen at retriever initialization. See init method docstring for more
- details.
- :param top_k: the maximum number of documents to retrieve.
- :returns: a dictionary with the following keys:
- - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore.
+ the `filter_policy` chosen at retriever initialization. See `__init__` method docstring for more
+ details.
+ :param top_k: The maximum number of documents to retrieve.
+ :returns: Dictionary with the following keys:
+ - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore.
"""

top_k = top_k or self._top_k
if filters is not None:
applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters)
- normalized_filters = normalize_filters(applied_filters)
+ normalized_filters = _normalize_filters(applied_filters)
else:
normalized_filters = ""

try:
docs = self._document_store._embedding_retrieval(
-     query_embedding=query_embedding,
-     filters=normalized_filters,
-     top_k=top_k,
+     query_embedding=query_embedding, filters=normalized_filters, top_k=top_k, **self._kwargs
)
except Exception as e:
- raise e
+ msg = (
+     "An error occurred during the embedding retrieval process from the AzureAISearchDocumentStore. "
+     "Ensure that the query embedding is valid and the document store is correctly configured."
+ )
+ raise RuntimeError(msg) from e

return {"documents": docs}
Package __init__ (haystack_integrations.document_stores.azure_ai_search)
@@ -2,6 +2,6 @@
#
# SPDX-License-Identifier: Apache-2.0
from .document_store import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore
- from .filters import normalize_filters
+ from .filters import _normalize_filters

- __all__ = ["AzureAISearchDocumentStore", "DEFAULT_VECTOR_SEARCH", "normalize_filters"]
+ __all__ = ["AzureAISearchDocumentStore", "DEFAULT_VECTOR_SEARCH", "_normalize_filters"]
Document store (haystack_integrations.document_stores.azure_ai_search.document_store)
@@ -31,7 +31,7 @@
from haystack.utils import Secret, deserialize_secrets_inplace

from .errors import AzureAISearchDocumentStoreConfigError
- from .filters import normalize_filters
+ from .filters import _normalize_filters

type_mapping = {
str: "Edm.String",
@@ -70,7 +70,7 @@ def __init__(
embedding_dimension: int = 768,
metadata_fields: Optional[Dict[str, type]] = None,
vector_search_configuration: VectorSearch = None,
- **kwargs,
+ **index_creation_kwargs,
):
"""
A document store using [Azure AI Search](https://azure.microsoft.com/products/ai-services/ai-search/)
@@ -87,19 +87,20 @@
:param vector_search_configuration: Configuration option related to vector search.
Default configuration uses the HNSW algorithm with cosine similarity to handle vector searches.
- :param kwargs: Optional keyword parameters for Azure AI Search.
-     Some of the supported parameters:
-     - `api_version`: The Search API version to use for requests.
-     - `audience`: sets the Audience to use for authentication with Azure Active Directory (AAD).
-       The audience is not considered when using a shared key. If audience is not provided,
-       the public cloud audience will be assumed.
+ :param index_creation_kwargs: Optional keyword parameters to be passed to the `SearchIndex` class
+     during index creation. Some of the supported parameters:
+     - `semantic_search`: Defines the semantic configuration of the search index. This parameter is needed
+       to enable semantic search capabilities in the index.
+     - `similarity`: The type of similarity algorithm to be used when scoring and ranking the documents
+       matching a search query. The similarity algorithm can only be defined at index creation time and
+       cannot be modified on existing indexes.
- For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/)
+ For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/).
"""

azure_endpoint = azure_endpoint or os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT") or None
if not azure_endpoint:
msg = "Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT."
msg = "Please provide an Azure endpoint or set the environment variable AZURE_SEARCH_SERVICE_ENDPOINT."
raise ValueError(msg)

api_key = api_key or os.environ.get("AZURE_SEARCH_API_KEY") or None
@@ -114,7 +115,7 @@
self._dummy_vector = [-10.0] * self._embedding_dimension
self._metadata_fields = metadata_fields
self._vector_search_configuration = vector_search_configuration or DEFAULT_VECTOR_SEARCH
- self._kwargs = kwargs
+ self._index_creation_kwargs = index_creation_kwargs

@property
def client(self) -> SearchClient:
@@ -128,7 +129,10 @@ def client(self) -> SearchClient:
credential = AzureKeyCredential(resolved_key) if resolved_key else DefaultAzureCredential()
try:
if not self._index_client:
- self._index_client = SearchIndexClient(resolved_endpoint, credential, **self._kwargs)
+ self._index_client = SearchIndexClient(
+     resolved_endpoint,
+     credential,
+ )
if not self._index_exists(self._index_name):
# Create a new index if it does not exist
logger.debug(
@@ -151,7 +155,7 @@ def client(self) -> SearchClient:

return self._client

- def _create_index(self, index_name: str, **kwargs) -> None:
+ def _create_index(self, index_name: str) -> None:
"""
Creates a new search index.
:param index_name: Name of the index to create. If None, the index name from the constructor is used.
@@ -177,7 +181,10 @@ def _create_index(self, index_name: str, **kwargs) -> None:
if self._metadata_fields:
default_fields.extend(self._create_metadata_index_fields(self._metadata_fields))
index = SearchIndex(
-     name=index_name, fields=default_fields, vector_search=self._vector_search_configuration, **kwargs
+     name=index_name,
+     fields=default_fields,
+     vector_search=self._vector_search_configuration,
+     **self._index_creation_kwargs,
)
if self._index_client:
self._index_client.create_index(index)
@@ -194,13 +201,13 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
- azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint is not None else None,
- api_key=self._api_key.to_dict() if self._api_key is not None else None,
+ azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint else None,
+ api_key=self._api_key.to_dict() if self._api_key else None,
index_name=self._index_name,
embedding_dimension=self._embedding_dimension,
metadata_fields=self._metadata_fields,
vector_search_configuration=self._vector_search_configuration.as_dict(),
- **self._kwargs,
+ **self._index_creation_kwargs,
)

@classmethod
@@ -298,7 +305,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
:returns: A list of Documents that match the given filters.
"""
if filters:
- normalized_filters = normalize_filters(filters)
+ normalized_filters = _normalize_filters(filters)
result = self.client.search(filter=normalized_filters)
return self._convert_search_result_to_documents(result)
else:
@@ -409,8 +416,8 @@ def _embedding_retrieval(
query_embedding: List[float],
*,
top_k: int = 10,
- fields: Optional[List[str]] = None,
filters: Optional[Dict[str, Any]] = None,
+ **kwargs,
) -> List[Document]:
"""
Retrieves documents that are most similar to the query embedding using a vector similarity metric.
@@ -422,9 +429,10 @@
`AzureAISearchEmbeddingRetriever` uses this method directly and is the public interface for it.
:param query_embedding: Embedding of the query.
+ :param top_k: Maximum number of Documents to return, defaults to 10.
:param filters: Filters applied to the retrieved Documents. Defaults to None.
Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned.
- :param top_k: Maximum number of Documents to return, defaults to 10
+ :param kwargs: Optional keyword arguments to pass to the Azure AI Search endpoint.
:raises ValueError: If `query_embedding` is an empty list
:returns: List of Document that are most similar to `query_embedding`
@@ -435,6 +443,6 @@
raise ValueError(msg)

vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding")
- result = self.client.search(search_text=None, vector_queries=[vector_query], select=fields, filter=filters)
+ result = self.client.search(vector_queries=[vector_query], filter=filters, **kwargs)
azure_docs = list(result)
return self._convert_search_result_to_documents(azure_docs)
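One consequence of spreading `**self._index_creation_kwargs` in `to_dict()` is that the extra parameters are serialized alongside the regular init parameters, so they survive when the store is saved as part of a pipeline. A minimal sketch, assuming the `document_store` from the example above:

serialized = document_store.to_dict()
# Haystack's default_to_dict() nests everything under "init_parameters",
# including the forwarded index-creation kwargs such as `semantic_search`.
print(serialized["init_parameters"].keys())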
Filters module (haystack_integrations.document_stores.azure_ai_search.filters)
@@ -7,7 +7,7 @@
LOGICAL_OPERATORS = {"AND": "and", "OR": "or", "NOT": "not"}


- def normalize_filters(filters: Dict[str, Any]) -> str:
+ def _normalize_filters(filters: Dict[str, Any]) -> str:
"""
Converts Haystack filters into Azure AI Search compatible filters.
"""
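As a quick illustration of what the renamed helper produces, here is a sketch; the exact OData output string depends on the implementation, and the metadata field values are hypothetical (they mirror the example script's documents):

from haystack_integrations.document_stores.azure_ai_search import _normalize_filters

# Haystack filter syntax in, OData $filter expression out.
odata_filter = _normalize_filters(
    {
        "operator": "AND",
        "conditions": [
            {"field": "meta.version", "operator": "==", "value": 2.0},
            {"field": "meta.label", "operator": "==", "value": "chapter_three"},
        ],
    }
)
print(odata_filter)  # e.g. something like "(version eq 2.0) and (label eq 'chapter_three')"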
22 changes: 18 additions & 4 deletions integrations/azure_ai_search/tests/conftest.py
@@ -6,12 +6,14 @@
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import ResourceNotFoundError
from azure.search.documents.indexes import SearchIndexClient
+ from haystack import logging
from haystack.document_stores.types import DuplicatePolicy

from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore

# This is the approximate time in seconds it takes for the documents to be available in the Azure Search index
- SLEEP_TIME_IN_SECONDS = 5
+ SLEEP_TIME_IN_SECONDS = 10
+ MAX_WAIT_TIME_FOR_INDEX_DELETION = 5


@pytest.fixture()
@@ -46,23 +48,35 @@ def document_store(request):

# Override some methods to wait for the documents to be available
original_write_documents = store.write_documents
+ original_delete_documents = store.delete_documents

def write_documents_and_wait(documents, policy=DuplicatePolicy.OVERWRITE):
written_docs = original_write_documents(documents, policy)
time.sleep(SLEEP_TIME_IN_SECONDS)
return written_docs

- original_delete_documents = store.delete_documents
def delete_documents_and_wait(filters):
original_delete_documents(filters)
time.sleep(SLEEP_TIME_IN_SECONDS)

+ # Helper function to wait for the index to be deleted, needed to cover latency
+ def wait_for_index_deletion(client, index_name):
+     start_time = time.time()
+     while time.time() - start_time < MAX_WAIT_TIME_FOR_INDEX_DELETION:
+         if index_name not in client.list_index_names():
+             return True
+         time.sleep(1)
+     return False

store.write_documents = write_documents_and_wait
+ store.delete_documents = delete_documents_and_wait

yield store
try:
client.delete_index(index_name)
+ if not wait_for_index_deletion(client, index_name):
+     logging.error(f"Index {index_name} was not properly deleted.")
except ResourceNotFoundError:
- pass
+ logging.info(f"Index {index_name} was already deleted or not found.")
+ except Exception as e:
+     logging.error(f"Unexpected error when deleting index {index_name}: {e}")

