diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py index c2c1ee40d..0bd29898e 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 from .chat.chat_generator import AnthropicChatGenerator +from .chat.vertex_chat_generator import AnthropicVertexChatGenerator from .generator import AnthropicGenerator -__all__ = ["AnthropicGenerator", "AnthropicChatGenerator"] +__all__ = ["AnthropicGenerator", "AnthropicChatGenerator", "AnthropicVertexChatGenerator"] diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py new file mode 100644 index 000000000..4ece944cd --- /dev/null +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py @@ -0,0 +1,135 @@ +import os +from typing import Any, Callable, Dict, Optional + +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.dataclasses import StreamingChunk +from haystack.utils import deserialize_callable, serialize_callable + +from anthropic import AnthropicVertex + +from .chat_generator import AnthropicChatGenerator + +logger = logging.getLogger(__name__) + + +@component +class AnthropicVertexChatGenerator(AnthropicChatGenerator): + """ + + Enables text generation using state-of-the-art Claude 3 LLMs via the Anthropic Vertex AI API. + It supports models such as `Claude 3.5 Sonnet`, `Claude 3 Opus`, `Claude 3 Sonnet`, and `Claude 3 Haiku`, + accessible through the Vertex AI API endpoint. 
+ + To use AnthropicVertexChatGenerator, you must have a GCP project with Vertex AI enabled. + Additionally, ensure that the desired Anthropic model is activated in the Vertex AI Model Garden. + Before making requests, you may need to authenticate with GCP using `gcloud auth login`. + For more details, refer to the [guide](https://docs.anthropic.com/en/api/claude-on-vertex-ai). + + Any valid text generation parameters for the Anthropic messaging API can be passed to + the AnthropicVertex API. Users can provide these parameters directly to the component via + the `generation_kwargs` parameter in `__init__` or the `run` method. + + For more details on the parameters supported by the Anthropic API, refer to the + Anthropic Message API [documentation](https://docs.anthropic.com/en/api/messages). + + ```python + from haystack_integrations.components.generators.anthropic import AnthropicVertexChatGenerator + from haystack.dataclasses import ChatMessage + + messages = [ChatMessage.from_user("What's Natural Language Processing?")] + client = AnthropicVertexChatGenerator( + model="claude-3-sonnet@20240229", + project_id="your-project-id", region="your-region" + ) + response = client.run(messages) + print(response) + + >> {'replies': [ChatMessage(content='Natural Language Processing (NLP) is a field of artificial intelligence that + >> focuses on enabling computers to understand, interpret, and generate human language. It involves developing + >> techniques and algorithms to analyze and process text or speech data, allowing machines to comprehend and + >> communicate in natural languages like English, Spanish, or Chinese.', role=<ChatRole.ASSISTANT: 'assistant'>, + >> name=None, meta={'model': 'claude-3-sonnet@20240229', 'index': 0, 'finish_reason': 'end_turn', + >> 'usage': {'input_tokens': 15, 'output_tokens': 64}})]} + ``` + + For more details on supported models and their capabilities, refer to the Anthropic + [documentation](https://docs.anthropic.com/claude/docs/intro-to-claude). 
+ + """ + + def __init__( + self, + region: str, + project_id: str, + model: str = "claude-3-5-sonnet@20240620", + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ignore_tools_thinking_messages: bool = True, + ): + """ + Creates an instance of AnthropicVertexChatGenerator. + + :param region: The region where the Anthropic model is deployed. Falls back to the `REGION` env variable. + :param project_id: The GCP project ID where the Anthropic model is deployed. + :param model: The name of the model to use. + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. + :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to + the AnthropicVertex endpoint. See Anthropic [documentation](https://docs.anthropic.com/claude/reference/messages_post) + for more details. + + Supported generation_kwargs parameters are: + - `system`: The system message to be passed to the model. + - `max_tokens`: The maximum number of tokens to generate. + - `metadata`: A dictionary of metadata to be passed to the model. + - `stop_sequences`: A list of strings that the model should stop generating at. + - `temperature`: The temperature to use for sampling. + - `top_p`: The top_p value to use for nucleus sampling. + - `top_k`: The top_k value to use for top-k sampling. + - `extra_headers`: A dictionary of extra headers to be passed to the model (e.g. for beta features). + :param ignore_tools_thinking_messages: Anthropic's approach to tools (function calling) resolution involves + "chain of thought" messages before returning the actual function names and parameters in a message. If + `ignore_tools_thinking_messages` is `True`, the generator will drop so-called thinking messages when tool + use is detected. 
See the Anthropic [tools](https://docs.anthropic.com/en/docs/tool-use#chain-of-thought-tool-use) + for more details. + """ + self.region = region or os.environ.get("REGION") + self.project_id = project_id or os.environ.get("PROJECT_ID") + self.model = model + self.generation_kwargs = generation_kwargs or {} + self.streaming_callback = streaming_callback + self.client = AnthropicVertex(region=self.region, project_id=self.project_id) + self.ignore_tools_thinking_messages = ignore_tools_thinking_messages + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + + :returns: + The serialized component as a dictionary. + """ + callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + return default_to_dict( + self, + region=self.region, + project_id=self.project_id, + model=self.model, + streaming_callback=callback_name, + generation_kwargs=self.generation_kwargs, + ignore_tools_thinking_messages=self.ignore_tools_thinking_messages, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AnthropicVertexChatGenerator": + """ + Deserialize this component from a dictionary. + + :param data: The dictionary representation of this component. + :returns: + The deserialized component instance. 
+ """ + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + return default_from_dict(cls, data) diff --git a/integrations/anthropic/tests/test_vertex_chat_generator.py b/integrations/anthropic/tests/test_vertex_chat_generator.py new file mode 100644 index 000000000..a67e801ad --- /dev/null +++ b/integrations/anthropic/tests/test_vertex_chat_generator.py @@ -0,0 +1,197 @@ +import os + +import anthropic +import pytest +from haystack.components.generators.utils import print_streaming_chunk +from haystack.dataclasses import ChatMessage, ChatRole + +from haystack_integrations.components.generators.anthropic import AnthropicVertexChatGenerator + + +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."), + ChatMessage.from_user("What's the capital of France?"), + ] + + +class TestAnthropicVertexChatGenerator: + def test_init_default(self): + component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id") + assert component.region == "us-central1" + assert component.project_id == "test-project-id" + assert component.model == "claude-3-5-sonnet@20240620" + assert component.streaming_callback is None + assert not component.generation_kwargs + assert component.ignore_tools_thinking_messages + + def test_init_with_parameters(self): + component = AnthropicVertexChatGenerator( + region="us-central1", + project_id="test-project-id", + model="claude-3-5-sonnet@20240620", + streaming_callback=print_streaming_chunk, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ignore_tools_thinking_messages=False, + ) + assert component.region == "us-central1" + assert component.project_id == "test-project-id" + assert component.model == 
"claude-3-5-sonnet@20240620" + assert component.streaming_callback is print_streaming_chunk + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.ignore_tools_thinking_messages is False + + def test_to_dict_default(self): + component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id") + data = component.to_dict() + assert data == { + "type": ( + "haystack_integrations.components.generators." + "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": None, + "generation_kwargs": {}, + "ignore_tools_thinking_messages": True, + }, + } + + def test_to_dict_with_parameters(self): + component = AnthropicVertexChatGenerator( + region="us-central1", + project_id="test-project-id", + streaming_callback=print_streaming_chunk, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + data = component.to_dict() + assert data == { + "type": ( + "haystack_integrations.components.generators." 
+ "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, + }, + } + + def test_to_dict_with_lambda_streaming_callback(self): + component = AnthropicVertexChatGenerator( + region="us-central1", + project_id="test-project-id", + model="claude-3-5-sonnet@20240620", + streaming_callback=lambda x: x, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + data = component.to_dict() + assert data == { + "type": ( + "haystack_integrations.components.generators." + "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": "tests.test_vertex_chat_generator.", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, + }, + } + + def test_from_dict(self): + data = { + "type": ( + "haystack_integrations.components.generators." 
+ "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, + }, + } + component = AnthropicVertexChatGenerator.from_dict(data) + assert component.model == "claude-3-5-sonnet@20240620" + assert component.region == "us-central1" + assert component.project_id == "test-project-id" + assert component.streaming_callback is print_streaming_chunk + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + + def test_run(self, chat_messages, mock_chat_completion): + component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id") + response = component.run(chat_messages) + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + def test_run_with_params(self, chat_messages, mock_chat_completion): + component = AnthropicVertexChatGenerator( + region="us-central1", project_id="test-project-id", generation_kwargs={"max_tokens": 10, "temperature": 0.5} + ) + response = component.run(chat_messages) + + # check that the component calls the Anthropic API with the correct parameters + _, kwargs = mock_chat_completion.call_args + assert kwargs["max_tokens"] == 10 + assert kwargs["temperature"] == 0.5 + + # check that the component returns the correct response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, 
ChatMessage) for reply in response["replies"]] + + @pytest.mark.skipif( + not (os.environ.get("REGION", None) or os.environ.get("PROJECT_ID", None)), + reason="Authenticate with GCP and set env variables REGION and PROJECT_ID to run this test.", + ) + @pytest.mark.integration + def test_live_run_wrong_model(self, chat_messages): + component = AnthropicVertexChatGenerator( + model="something-obviously-wrong", region=os.environ.get("REGION"), project_id=os.environ.get("PROJECT_ID") + ) + with pytest.raises(anthropic.NotFoundError): + component.run(chat_messages) + + @pytest.mark.skipif( + not (os.environ.get("REGION", None) or os.environ.get("PROJECT_ID", None)), + reason="Authenticate with GCP and set env variables REGION and PROJECT_ID to run this test.", + ) + @pytest.mark.integration + def test_default_inference_params(self, chat_messages): + client = AnthropicVertexChatGenerator( + region=os.environ.get("REGION"), project_id=os.environ.get("PROJECT_ID"), model="claude-3-sonnet@20240229" + ) + response = client.run(chat_messages) + + assert "replies" in response, "Response does not contain 'replies' key" + replies = response["replies"] + assert isinstance(replies, list), "Replies is not a list" + assert len(replies) > 0, "No replies received" + + first_reply = replies[0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert first_reply.content, "First reply has no content" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" + assert first_reply.meta, "First reply has no metadata" + + # Anthropic messages API is similar for AnthropicVertex and Anthropic endpoint, + # remaining tests are skipped for AnthropicVertexChatGenerator as they are already tested in AnthropicChatGenerator. 
diff --git a/integrations/azure_ai_search/example/document_store.py b/integrations/azure_ai_search/example/document_store.py index 779f28935..92a641717 100644 --- a/integrations/azure_ai_search/example/document_store.py +++ b/integrations/azure_ai_search/example/document_store.py @@ -1,5 +1,4 @@ from haystack import Document -from haystack.document_stores.types import DuplicatePolicy from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore @@ -30,7 +29,7 @@ meta={"version": 2.0, "label": "chapter_three"}, ), ] -document_store.write_documents(documents, policy=DuplicatePolicy.SKIP) +document_store.write_documents(documents) filters = { "operator": "AND", diff --git a/integrations/azure_ai_search/example/embedding_retrieval.py b/integrations/azure_ai_search/example/embedding_retrieval.py index 088b08653..188f8525a 100644 --- a/integrations/azure_ai_search/example/embedding_retrieval.py +++ b/integrations/azure_ai_search/example/embedding_retrieval.py @@ -1,7 +1,6 @@ from haystack import Document, Pipeline from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder from haystack.components.writers import DocumentWriter -from haystack.document_stores.types import DuplicatePolicy from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchEmbeddingRetriever from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore @@ -38,9 +37,7 @@ # Indexing Pipeline indexing_pipeline = Pipeline() indexing_pipeline.add_component(instance=document_embedder, name="doc_embedder") -indexing_pipeline.add_component( - instance=DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP), name="doc_writer" -) +indexing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="doc_writer") indexing_pipeline.connect("doc_embedder", "doc_writer") indexing_pipeline.run({"doc_embedder": {"documents": 
documents}}) diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py index ab649f874..af48b74fb 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py @@ -5,7 +5,7 @@ from haystack.document_stores.types import FilterPolicy from haystack.document_stores.types.filter_policy import apply_filter_policy -from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters logger = logging.getLogger(__name__) @@ -25,16 +25,23 @@ def __init__( filters: Optional[Dict[str, Any]] = None, top_k: int = 10, filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + **kwargs, ): """ Create the AzureAISearchEmbeddingRetriever component. :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever. :param filters: Filters applied when fetching documents from the Document Store. - Filters are applied during the approximate kNN search to ensure the Retriever returns - `top_k` matching documents. :param top_k: Maximum number of documents to return. - :filter_policy: Policy to determine how filters are applied. Possible options: + :param filter_policy: Policy to determine how filters are applied. + :param kwargs: Additional keyword arguments to pass to the Azure AI's search endpoint. + Some of the supported parameters: + - `query_type`: A string indicating the type of query to perform. Possible values are + 'simple','full' and 'semantic'. 
+ - `semantic_configuration_name`: The name of semantic configuration to be used when + processing semantic queries. + For more information on parameters, see the + [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). """ self._filters = filters or {} @@ -43,6 +50,7 @@ def __init__( self._filter_policy = ( filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) ) + self._kwargs = kwargs if not isinstance(document_store, AzureAISearchDocumentStore): message = "document_store must be an instance of AzureAISearchDocumentStore" @@ -61,6 +69,7 @@ def to_dict(self) -> Dict[str, Any]: top_k=self._top_k, document_store=self._document_store.to_dict(), filter_policy=self._filter_policy.value, + **self._kwargs, ) @classmethod @@ -88,29 +97,31 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchEmbeddingRetriever": def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): """Retrieve documents from the AzureAISearchDocumentStore. - :param query_embedding: floats representing the query embedding + :param query_embedding: A list of floats representing the query embedding. :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on - the `filter_policy` chosen at retriever initialization. See init method docstring for more - details. - :param top_k: the maximum number of documents to retrieve. - :returns: a dictionary with the following keys: - - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. + the `filter_policy` chosen at retriever initialization. See `__init__` method docstring for more + details. + :param top_k: The maximum number of documents to retrieve. + :returns: Dictionary with the following keys: + - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. 
""" top_k = top_k or self._top_k if filters is not None: applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) - normalized_filters = normalize_filters(applied_filters) + normalized_filters = _normalize_filters(applied_filters) else: normalized_filters = "" try: docs = self._document_store._embedding_retrieval( - query_embedding=query_embedding, - filters=normalized_filters, - top_k=top_k, + query_embedding=query_embedding, filters=normalized_filters, top_k=top_k, **self._kwargs ) except Exception as e: - raise e + msg = ( + "An error occurred during the embedding retrieval process from the AzureAISearchDocumentStore. " + "Ensure that the query embedding is valid and the document store is correctly configured." + ) + raise RuntimeError(msg) from e return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py index 635878a38..ca0ea7554 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py @@ -2,6 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 from .document_store import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore -from .filters import normalize_filters +from .filters import _normalize_filters -__all__ = ["AzureAISearchDocumentStore", "DEFAULT_VECTOR_SEARCH", "normalize_filters"] +__all__ = ["AzureAISearchDocumentStore", "DEFAULT_VECTOR_SEARCH", "_normalize_filters"] diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 0b59b6e37..74260b4fa 100644 --- 
a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -31,7 +31,7 @@ from haystack.utils import Secret, deserialize_secrets_inplace from .errors import AzureAISearchDocumentStoreConfigError -from .filters import normalize_filters +from .filters import _normalize_filters type_mapping = { str: "Edm.String", @@ -70,7 +70,7 @@ def __init__( embedding_dimension: int = 768, metadata_fields: Optional[Dict[str, type]] = None, vector_search_configuration: VectorSearch = None, - **kwargs, + **index_creation_kwargs, ): """ A document store using [Azure AI Search](https://azure.microsoft.com/products/ai-services/ai-search/) @@ -87,19 +87,20 @@ def __init__( :param vector_search_configuration: Configuration option related to vector search. Default configuration uses the HNSW algorithm with cosine similarity to handle vector searches. - :param kwargs: Optional keyword parameters for Azure AI Search. - Some of the supported parameters: - - `api_version`: The Search API version to use for requests. - - `audience`: sets the Audience to use for authentication with Azure Active Directory (AAD). - The audience is not considered when using a shared key. If audience is not provided, - the public cloud audience will be assumed. + :param index_creation_kwargs: Optional keyword parameters to be passed to `SearchIndex` class + during index creation. Some of the supported parameters: + - `semantic_search`: Defines semantic configuration of the search index. This parameter is needed + to enable semantic search capabilities in index. + - `similarity`: The type of similarity algorithm to be used when scoring and ranking the documents + matching a search query. The similarity algorithm can only be defined at index creation time and + cannot be modified on existing indexes. 
- For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/) + For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). """ azure_endpoint = azure_endpoint or os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT") or None if not azure_endpoint: - msg = "Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT." + msg = "Please provide an Azure endpoint or set the environment variable AZURE_SEARCH_SERVICE_ENDPOINT." raise ValueError(msg) api_key = api_key or os.environ.get("AZURE_SEARCH_API_KEY") or None @@ -114,7 +115,7 @@ def __init__( self._dummy_vector = [-10.0] * self._embedding_dimension self._metadata_fields = metadata_fields self._vector_search_configuration = vector_search_configuration or DEFAULT_VECTOR_SEARCH - self._kwargs = kwargs + self._index_creation_kwargs = index_creation_kwargs @property def client(self) -> SearchClient: @@ -128,7 +129,10 @@ def client(self) -> SearchClient: credential = AzureKeyCredential(resolved_key) if resolved_key else DefaultAzureCredential() try: if not self._index_client: - self._index_client = SearchIndexClient(resolved_endpoint, credential, **self._kwargs) + self._index_client = SearchIndexClient( + resolved_endpoint, + credential, + ) if not self._index_exists(self._index_name): # Create a new index if it does not exist logger.debug( @@ -151,7 +155,7 @@ def client(self) -> SearchClient: return self._client - def _create_index(self, index_name: str, **kwargs) -> None: + def _create_index(self, index_name: str) -> None: """ Creates a new search index. :param index_name: Name of the index to create. If None, the index name from the constructor is used. 
@@ -177,7 +181,10 @@ def _create_index(self, index_name: str, **kwargs) -> None: if self._metadata_fields: default_fields.extend(self._create_metadata_index_fields(self._metadata_fields)) index = SearchIndex( - name=index_name, fields=default_fields, vector_search=self._vector_search_configuration, **kwargs + name=index_name, + fields=default_fields, + vector_search=self._vector_search_configuration, + **self._index_creation_kwargs, ) if self._index_client: self._index_client.create_index(index) @@ -194,13 +201,13 @@ def to_dict(self) -> Dict[str, Any]: """ return default_to_dict( self, - azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint is not None else None, - api_key=self._api_key.to_dict() if self._api_key is not None else None, + azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint else None, + api_key=self._api_key.to_dict() if self._api_key else None, index_name=self._index_name, embedding_dimension=self._embedding_dimension, metadata_fields=self._metadata_fields, vector_search_configuration=self._vector_search_configuration.as_dict(), - **self._kwargs, + **self._index_creation_kwargs, ) @classmethod @@ -298,7 +305,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc :returns: A list of Documents that match the given filters. """ if filters: - normalized_filters = normalize_filters(filters) + normalized_filters = _normalize_filters(filters) result = self.client.search(filter=normalized_filters) return self._convert_search_result_to_documents(result) else: @@ -409,8 +416,8 @@ def _embedding_retrieval( query_embedding: List[float], *, top_k: int = 10, - fields: Optional[List[str]] = None, filters: Optional[Dict[str, Any]] = None, + **kwargs, ) -> List[Document]: """ Retrieves documents that are most similar to the query embedding using a vector similarity metric. 
@@ -422,9 +429,10 @@ def _embedding_retrieval( `AzureAISearchEmbeddingRetriever` uses this method directly and is the public interface for it. :param query_embedding: Embedding of the query. + :param top_k: Maximum number of Documents to return, defaults to 10. :param filters: Filters applied to the retrieved Documents. Defaults to None. Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. - :param top_k: Maximum number of Documents to return, defaults to 10 + :param kwargs: Optional keyword arguments to pass to the Azure AI's search endpoint. :raises ValueError: If `query_embedding` is an empty list :returns: List of Document that are most similar to `query_embedding` @@ -435,6 +443,6 @@ def _embedding_retrieval( raise ValueError(msg) vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding") - result = self.client.search(search_text=None, vector_queries=[vector_query], select=fields, filter=filters) + result = self.client.search(vector_queries=[vector_query], filter=filters, **kwargs) azure_docs = list(result) return self._convert_search_result_to_documents(azure_docs) diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py index 650e3f8be..0f105bc91 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py @@ -7,7 +7,7 @@ LOGICAL_OPERATORS = {"AND": "and", "OR": "or", "NOT": "not"} -def normalize_filters(filters: Dict[str, Any]) -> str: +def _normalize_filters(filters: Dict[str, Any]) -> str: """ Converts Haystack filters in Azure AI Search compatible filters. 
""" diff --git a/integrations/azure_ai_search/tests/conftest.py b/integrations/azure_ai_search/tests/conftest.py index 3017c79c2..89369c87e 100644 --- a/integrations/azure_ai_search/tests/conftest.py +++ b/integrations/azure_ai_search/tests/conftest.py @@ -6,12 +6,14 @@ from azure.core.credentials import AzureKeyCredential from azure.core.exceptions import ResourceNotFoundError from azure.search.documents.indexes import SearchIndexClient +from haystack import logging from haystack.document_stores.types import DuplicatePolicy from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore # This is the approximate time in seconds it takes for the documents to be available in Azure Search index -SLEEP_TIME_IN_SECONDS = 5 +SLEEP_TIME_IN_SECONDS = 10 +MAX_WAIT_TIME_FOR_INDEX_DELETION = 5 @pytest.fixture() @@ -46,23 +48,35 @@ def document_store(request): # Override some methods to wait for the documents to be available original_write_documents = store.write_documents + original_delete_documents = store.delete_documents def write_documents_and_wait(documents, policy=DuplicatePolicy.OVERWRITE): written_docs = original_write_documents(documents, policy) time.sleep(SLEEP_TIME_IN_SECONDS) return written_docs - original_delete_documents = store.delete_documents - def delete_documents_and_wait(filters): original_delete_documents(filters) time.sleep(SLEEP_TIME_IN_SECONDS) + # Helper function to wait for the index to be deleted, needed to cover latency + def wait_for_index_deletion(client, index_name): + start_time = time.time() + while time.time() - start_time < MAX_WAIT_TIME_FOR_INDEX_DELETION: + if index_name not in client.list_index_names(): + return True + time.sleep(1) + return False + store.write_documents = write_documents_and_wait store.delete_documents = delete_documents_and_wait yield store try: client.delete_index(index_name) + if not wait_for_index_deletion(client, index_name): + logging.error(f"Index {index_name} was not properly 
deleted.") except ResourceNotFoundError: - pass + logging.info(f"Index {index_name} was already deleted or not found.") + except Exception as e: + logging.error(f"Unexpected error when deleting index {index_name}: {e}")