Skip to content

Commit

Permalink
Enable kwargs for semantic ranking
Browse files Browse the repository at this point in the history
  • Loading branch information
Amnah199 committed Nov 12, 2024
1 parent 4cfee2d commit 03055cc
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 19 deletions.
3 changes: 1 addition & 2 deletions integrations/azure_ai_search/example/document_store.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from haystack import Document
from haystack.document_stores.types import DuplicatePolicy

from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore

Expand Down Expand Up @@ -30,7 +29,7 @@
meta={"version": 2.0, "label": "chapter_three"},
),
]
document_store.write_documents(documents, policy=DuplicatePolicy.SKIP)
document_store.write_documents(documents)

filters = {
"operator": "AND",
Expand Down
5 changes: 1 addition & 4 deletions integrations/azure_ai_search/example/embedding_retrieval.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from haystack import Document, Pipeline
from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
from haystack.components.writers import DocumentWriter
from haystack.document_stores.types import DuplicatePolicy

from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchEmbeddingRetriever
from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore
Expand Down Expand Up @@ -38,9 +37,7 @@
# Indexing Pipeline
indexing_pipeline = Pipeline()
indexing_pipeline.add_component(instance=document_embedder, name="doc_embedder")
indexing_pipeline.add_component(
instance=DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP), name="doc_writer"
)
indexing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="doc_writer")
indexing_pipeline.connect("doc_embedder", "doc_writer")

indexing_pipeline.run({"doc_embedder": {"documents": documents}})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
**kwargs,
):
"""
Create the AzureAISearchEmbeddingRetriever component.
Expand All @@ -43,6 +44,7 @@ def __init__(
self._filter_policy = (
filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
)
self._kwargs = kwargs

if not isinstance(document_store, AzureAISearchDocumentStore):
message = "document_store must be an instance of AzureAISearchDocumentStore"
Expand Down Expand Up @@ -106,9 +108,7 @@ def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] =

try:
docs = self._document_store._embedding_retrieval(
query_embedding=query_embedding,
filters=normalized_filters,
top_k=top_k,
query_embedding=query_embedding, filters=normalized_filters, top_k=top_k, **self._kwargs
)
except Exception as e:
raise e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,20 @@ def __init__(
:param vector_search_configuration: Configuration option related to vector search.
Default configuration uses the HNSW algorithm with cosine similarity to handle vector searches.
:param kwargs: Optional keyword parameters for Azure AI Search.
:param kwargs: Optional keyword parameters to be passed to SearchIndex during index creation.
Some of the supported parameters:
- `api_version`: The Search API version to use for requests.
- `audience`: sets the Audience to use for authentication with Azure Active Directory (AAD).
The audience is not considered when using a shared key. If audience is not provided,
the public cloud audience will be assumed.
- `semantic_search`: Defines semantic configuration of the search index. This parameter is needed
to enable semantic search capabilities in index.
- `similarity`: The type of similarity algorithm to be used when scoring and ranking the documents
matching a search query. The similarity algorithm can only be defined at index creation time and
cannot be modified on existing indexes.
For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/)
"""

azure_endpoint = azure_endpoint or os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT") or None
if not azure_endpoint:
msg = "Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT."
msg = "Please provide an Azure endpoint or set the environment variable AZURE_SEARCH_SERVICE_ENDPOINT."
raise ValueError(msg)

api_key = api_key or os.environ.get("AZURE_SEARCH_API_KEY") or None
Expand Down Expand Up @@ -128,7 +129,10 @@ def client(self) -> SearchClient:
credential = AzureKeyCredential(resolved_key) if resolved_key else DefaultAzureCredential()
try:
if not self._index_client:
self._index_client = SearchIndexClient(resolved_endpoint, credential, **self._kwargs)
self._index_client = SearchIndexClient(
resolved_endpoint,
credential,
)
if not self._index_exists(self._index_name):
# Create a new index if it does not exist
logger.debug(
Expand All @@ -151,7 +155,7 @@ def client(self) -> SearchClient:

return self._client

def _create_index(self, index_name: str, **kwargs) -> None:
def _create_index(self, index_name: str) -> None:
"""
Creates a new search index.
:param index_name: Name of the index to create. If None, the index name from the constructor is used.
Expand All @@ -177,7 +181,7 @@ def _create_index(self, index_name: str, **kwargs) -> None:
if self._metadata_fields:
default_fields.extend(self._create_metadata_index_fields(self._metadata_fields))
index = SearchIndex(
name=index_name, fields=default_fields, vector_search=self._vector_search_configuration, **kwargs
name=index_name, fields=default_fields, vector_search=self._vector_search_configuration, **self._kwargs
)
if self._index_client:
self._index_client.create_index(index)
Expand Down Expand Up @@ -411,6 +415,7 @@ def _embedding_retrieval(
top_k: int = 10,
fields: Optional[List[str]] = None,
filters: Optional[Dict[str, Any]] = None,
**kwargs,
) -> List[Document]:
"""
Retrieves documents that are most similar to the query embedding using a vector similarity metric.
Expand All @@ -435,6 +440,8 @@ def _embedding_retrieval(
raise ValueError(msg)

vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding")
result = self.client.search(search_text=None, vector_queries=[vector_query], select=fields, filter=filters)
result = self.client.search(
search_text=None, vector_queries=[vector_query], select=fields, filter=filters, **kwargs
)
azure_docs = list(result)
return self._convert_search_result_to_documents(azure_docs)

0 comments on commit 03055cc

Please sign in to comment.