diff --git a/docs/pydoc/config/document_stores_api.yml b/docs/pydoc/config/document_stores_api.yml new file mode 100644 index 00000000..009add33 --- /dev/null +++ b/docs/pydoc/config/document_stores_api.yml @@ -0,0 +1,32 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../../../] + modules: + [ + "haystack_experimental.document_stores.in_memory.document_store", + "haystack_experimental.document_stores.opensearch.document_store", + "haystack_experimental.document_stores.types.protocol", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmeCoreRenderer + excerpt: Stores your texts and meta data and provides them to the Retriever at query time. + category_slug: experiments-api + title: Document Stores + slug: experimental-document-stores-api + order: 160 + markdown: + descriptive_class_title: false + classdef_code_block: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: document_stores_api.md diff --git a/docs/pydoc/config/retrievers_api.yml b/docs/pydoc/config/retrievers_api.yml index 4f16f6aa..1a449671 100644 --- a/docs/pydoc/config/retrievers_api.yml +++ b/docs/pydoc/config/retrievers_api.yml @@ -4,7 +4,9 @@ loaders: modules: [ "haystack_experimental.components.retrievers.auto_merging_retriever", - "haystack_experimental.components.retrievers.chat_message_retriever" + "haystack_experimental.components.retrievers.chat_message_retriever", + "haystack_experimental.components.retrievers.opensearch.bm25_retriever", + "haystack_experimental.components.retrievers.opensearch.embedding_retriever", ] ignore_when_discovered: ["__init__"] processors: @@ -28,4 +30,4 @@ renderer: descriptive_module_title: true add_method_class_prefix: true add_member_class_prefix: false - filename: experimental_retrievers_api.md \ No newline at end of file + filename: experimental_retrievers_api.md diff --git a/docs/pydoc/config/writers_api.yml b/docs/pydoc/config/writers_api.yml index 23e25d01..fba6846d 100644 --- a/docs/pydoc/config/writers_api.yml +++ b/docs/pydoc/config/writers_api.yml @@ -1,7 +1,11 @@ loaders: - type: haystack_pydoc_tools.loaders.CustomPythonLoader search_path: [../../../] - modules: ["haystack_experimental.components.writers.chat_message_writer"] + modules: + [ + "haystack_experimental.components.writers.chat_message_writer", + "haystack_experimental.components.writers.document_writer", + ] ignore_when_discovered: ["__init__"] processors: - type: filter diff --git a/haystack_experimental/components/generators/chat/openai.py b/haystack_experimental/components/generators/chat/openai.py index 7517d508..19337b46 100644 --- a/haystack_experimental/components/generators/chat/openai.py +++ b/haystack_experimental/components/generators/chat/openai.py @@ -3,19 +3,29 @@ # SPDX-License-Identifier: Apache-2.0 import json -from typing import Any, Callable, Dict, List, Optional, Union +import os +from typing import Any, Dict, List, Optional, Union -from haystack import component, default_from_dict, logging -from haystack.components.generators.chat.openai import OpenAIChatGenerator as OpenAIChatGeneratorBase +from haystack import component, default_from_dict, default_to_dict, logging from haystack.dataclasses import StreamingChunk -from haystack.utils import Secret, deserialize_callable, 
deserialize_secrets_inplace -from openai import Stream +from haystack.utils import ( + Secret, + deserialize_callable, + deserialize_secrets_inplace, + serialize_callable, +) +from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice -from haystack_experimental.dataclasses import ChatMessage, ToolCall -from haystack_experimental.dataclasses.tool import Tool, deserialize_tools_inplace +from haystack_experimental.dataclasses import ChatMessage, Tool, ToolCall +from haystack_experimental.dataclasses.streaming_chunk import ( + AsyncStreamingCallbackT, + StreamingCallbackT, + select_streaming_callback, +) +from haystack_experimental.dataclasses.tool import deserialize_tools_inplace logger = logging.getLogger(__name__) @@ -29,16 +39,22 @@ def _convert_message_to_openai_format(message: ChatMessage) -> Dict[str, Any]: tool_call_results = message.tool_call_results if not text_contents and not tool_calls and not tool_call_results: - raise ValueError("A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`.") + raise ValueError( + "A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`." + ) elif len(text_contents) + len(tool_call_results) > 1: - raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.") + raise ValueError( + "A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`." + ) openai_msg: Dict[str, Any] = {"role": message._role.value} if tool_call_results: result = tool_call_results[0] if result.origin.id is None: - raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with OpenAI.") + raise ValueError( + "`ToolCall` must have a non-null `id` attribute to be used with OpenAI." + ) openai_msg["content"] = result.result openai_msg["tool_call_id"] = result.origin.id # OpenAI does not provide a way to communicate errors in tool invocations, so we ignore the error field @@ -50,13 +66,18 @@ def _convert_message_to_openai_format(message: ChatMessage) -> Dict[str, Any]: openai_tool_calls = [] for tc in tool_calls: if tc.id is None: - raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with OpenAI.") + raise ValueError( + "`ToolCall` must have a non-null `id` attribute to be used with OpenAI." + ) openai_tool_calls.append( { "id": tc.id, "type": "function", # We disable ensure_ascii so special chars like emojis are not converted - "function": {"name": tc.tool_name, "arguments": json.dumps(tc.arguments, ensure_ascii=False)}, + "function": { + "name": tc.tool_name, + "arguments": json.dumps(tc.arguments, ensure_ascii=False), + }, } ) openai_msg["tool_calls"] = openai_tool_calls @@ -64,7 +85,7 @@ def _convert_message_to_openai_format(message: ChatMessage) -> Dict[str, Any]: @component -class OpenAIChatGenerator(OpenAIChatGeneratorBase): +class OpenAIChatGenerator: """ Completes chats using OpenAI's large language models (LLMs). 
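For orientation, here is a minimal sketch of how the refactored experimental `OpenAIChatGenerator` is meant to be used with tool definitions. It is illustrative only: the `get_weather` function, its JSON-schema `parameters`, and the prompt are invented for this example, and an `OPENAI_API_KEY` environment variable is assumed to be set.

```python
from haystack_experimental.components.generators.chat.openai import OpenAIChatGenerator
from haystack_experimental.dataclasses import ChatMessage, Tool

# Hypothetical tool; any callable matching the declared parameters would do.
def get_weather(city: str) -> str:
    return f"Sunny in {city}"

weather_tool = Tool(
    name="weather",
    description="Get the current weather for a city",
    parameters={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
    function=get_weather,
)

generator = OpenAIChatGenerator(model="gpt-4o-mini", tools=[weather_tool])
result = generator.run(messages=[ChatMessage.from_user("What's the weather in Berlin?")])
# Replies are experimental ChatMessage objects; tool_calls is populated when the model prepares a tool call.
print(result["replies"][0].tool_calls)
```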
@@ -109,7 +130,9 @@ def __init__( # noqa: PLR0913 self, api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"), model: str = "gpt-4o-mini", - streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + streaming_callback: Optional[ + Union[StreamingCallbackT, AsyncStreamingCallbackT] + ] = None, api_base_url: Optional[str] = None, organization: Optional[str] = None, generation_kwargs: Optional[Dict[str, Any]] = None, @@ -119,25 +142,23 @@ def __init__( # noqa: PLR0913 tools_strict: bool = False, ): """ - Creates an instance of OpenAIChatGenerator. Unless specified otherwise in `model`, uses OpenAI's GPT-3.5. - - Before initializing the component, you can set the 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES' - environment variables to override the `timeout` and `max_retries` parameters respectively - in the OpenAI client. - - :param api_key: The OpenAI API key. - You can set it with an environment variable `OPENAI_API_KEY`, or pass with this parameter - during initialization. - :param model: The name of the model to use. - :param streaming_callback: A callback function that is called when a new token is received from the stream. + Creates an instance of OpenAIChatGenerator. + + :param api_key: + The OpenAI API key. + :param model: + The name of the model to use. + :param streaming_callback: + A callback function that is called when a new token is received from the stream. The callback function accepts [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk) - as an argument. - :param api_base_url: An optional base URL. - :param organization: Your organization ID, defaults to `None`. See - [production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization). - :param generation_kwargs: Other parameters to use for the model. These parameters are sent directly to - the OpenAI endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for - more details. + as an argument. Must be a coroutine if the component is used in an async pipeline. + :param api_base_url: + An optional base URL. + :param organization: + Your organization ID. See [production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization). + :param generation_kwargs: + Other parameters to use for the model. These parameters are sent directly to the OpenAI endpoint. + See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for more details. Some of the supported parameters: - `max_tokens`: The maximum number of tokens the output text can have. - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks. @@ -166,21 +187,35 @@ def __init__( # noqa: PLR0913 Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly the schema provided in the `parameters` field of the tool definition, but this may increase latency. 
""" - if tools: - tool_names = [tool.name for tool in tools] - duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1} - if duplicate_tool_names: - raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}") + self.api_key = api_key + self.model = model + self.generation_kwargs = generation_kwargs or {} + self.streaming_callback = streaming_callback + self.api_base_url = api_base_url + self.organization = organization + self.timeout = timeout + self.max_retries = max_retries self.tools = tools self.tools_strict = tools_strict - super(OpenAIChatGenerator, self).__init__( - api_key=api_key, - model=model, - streaming_callback=streaming_callback, - api_base_url=api_base_url, + self._validate_tools(tools) + + if timeout is None: + timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0)) + if max_retries is None: + max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5)) + + self.client = OpenAI( + api_key=api_key.resolve_value(), + organization=organization, + base_url=api_base_url, + timeout=timeout, + max_retries=max_retries, + ) + self.async_client = AsyncOpenAI( + api_key=api_key.resolve_value(), organization=organization, - generation_kwargs=generation_kwargs, + base_url=api_base_url, timeout=timeout, max_retries=max_retries, ) @@ -192,17 +227,32 @@ def to_dict(self) -> Dict[str, Any]: :returns: The serialized component as a dictionary. """ - serialized = super(OpenAIChatGenerator, self).to_dict() - serialized["init_parameters"]["tools"] = [tool.to_dict() for tool in self.tools] if self.tools else None - serialized["init_parameters"]["tools_strict"] = self.tools_strict - return serialized + callback_name = ( + serialize_callable(self.streaming_callback) + if self.streaming_callback + else None + ) + return default_to_dict( + self, + model=self.model, + streaming_callback=callback_name, + api_base_url=self.api_base_url, + organization=self.organization, + generation_kwargs=self.generation_kwargs, + api_key=self.api_key.to_dict(), + timeout=self.timeout, + max_retries=self.max_retries, + tools=[tool.to_dict() for tool in self.tools] if self.tools else None, + tools_strict=self.tools_strict, + ) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "OpenAIChatGenerator": """ Deserialize this component from a dictionary. - :param data: The dictionary representation of this component. + :param data: + The dictionary representation of this component. :returns: The deserialized component instance. """ @@ -211,7 +261,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenAIChatGenerator": init_params = data.get("init_parameters", {}) serialized_callback_handler = init_params.get("streaming_callback") if serialized_callback_handler: - data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + data["init_parameters"]["streaming_callback"] = deserialize_callable( + serialized_callback_handler + ) return default_from_dict(cls, data) @@ -219,7 +271,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenAIChatGenerator": def run( # noqa: PLR0913 self, messages: List[ChatMessage], - streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + streaming_callback: Optional[ + Union[StreamingCallbackT, AsyncStreamingCallbackT] + ] = None, generation_kwargs: Optional[Dict[str, Any]] = None, tools: Optional[List[Tool]] = None, tools_strict: Optional[bool] = None, @@ -227,12 +281,15 @@ def run( # noqa: PLR0913 """ Invokes chat completion based on the provided messages and generation parameters. 
- :param messages: A list of ChatMessage instances representing the input messages. - :param streaming_callback: A callback function that is called when a new token is received from the stream. - :param generation_kwargs: Additional keyword arguments for text generation. These parameters will - override the parameters passed during component initialization. - For details on OpenAI API parameters, see - [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat/create). + :param messages: + A list of ChatMessage instances representing the input messages. + :param streaming_callback: + A callback function that is called when a new token is received from the stream. + Cannot be a coroutine. + :param generation_kwargs: + Additional keyword arguments for text generation. These parameters will + override the parameters passed during component initialization. + For details on OpenAI API parameters, see [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat/create). :param tools: A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set during component initialization. @@ -244,55 +301,103 @@ def run( # noqa: PLR0913 :returns: A list containing the generated responses as ChatMessage instances. """ + # validate and select the streaming callback + streaming_callback = select_streaming_callback( + self.streaming_callback, streaming_callback, requires_async=False + ) # type: ignore - # update generation kwargs by merging with the generation kwargs passed to the run method - generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + if len(messages) == 0: + return {"replies": []} - # check if streaming_callback is passed - streaming_callback = streaming_callback or self.streaming_callback + api_args = self._prepare_api_call( + messages, streaming_callback, generation_kwargs, tools, tools_strict + ) + chat_completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = ( + self.client.chat.completions.create(**api_args) + ) - # adapt ChatMessage(s) to the format expected by the OpenAI API - openai_formatted_messages = [_convert_message_to_openai_format(message) for message in messages] + is_streaming = isinstance(chat_completion, Stream) + assert is_streaming or streaming_callback is None - tools = tools or self.tools - if tools: - tool_names = [tool.name for tool in tools] - duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1} - if duplicate_tool_names: - raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}") + if is_streaming: + completions = self._handle_stream_response( + chat_completion, # type: ignore + streaming_callback, # type: ignore + ) + else: + assert isinstance( + chat_completion, ChatCompletion + ), "Unexpected response type for non-streaming request." 
+ completions = [ + self._convert_chat_completion_to_chat_message(chat_completion, choice) + for choice in chat_completion.choices + ] - tools_strict = tools_strict if tools_strict is not None else self.tools_strict + # before returning, do post-processing of the completions + for message in completions: + self._check_finish_reason(message.meta) - openai_tools = None - if tools: - openai_tools = [{"type": "function", "function": {**t.tool_spec, "strict": tools_strict}} for t in tools] + return {"replies": completions} - chat_completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create( - model=self.model, - messages=openai_formatted_messages, # type: ignore[arg-type] # openai expects list of specific message types - stream=streaming_callback is not None, - tools=openai_tools, # type: ignore[arg-type] - **generation_kwargs, + @component.output_types(replies=List[ChatMessage]) + async def run_async( # noqa: PLR0913 + self, + messages: List[ChatMessage], + streaming_callback: Optional[ + Union[StreamingCallbackT, AsyncStreamingCallbackT] + ] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + tools: Optional[List[Tool]] = None, + tools_strict: Optional[bool] = None, + ): + """ + Invokes chat completion based on the provided messages and generation parameters. + + :param messages: + A list of ChatMessage instances representing the input messages. + :param streaming_callback: + A callback function that is called when a new token is received from the stream. + Must be a coroutine. + :param generation_kwargs: + Additional keyword arguments for text generation. These parameters will + override the parameters passed during component initialization. + For details on OpenAI API parameters, see [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat/create). + :param tools: + A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set + during component initialization. + :param tools_strict: + Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly + the schema provided in the `parameters` field of the tool definition, but this may increase latency. + If set, it will override the `tools_strict` parameter set during component initialization. + + :returns: + A list containing the generated responses as ChatMessage instances. 
+ """ + # validate and select the streaming callback + streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True) # type: ignore + + if len(messages) == 0: + return {"replies": []} + + api_args = self._prepare_api_call( + messages, streaming_callback, generation_kwargs, tools, tools_strict + ) + chat_completion: Union[AsyncStream[ChatCompletionChunk], ChatCompletion] = ( + await self.async_client.chat.completions.create(**api_args) ) - completions: List[ChatMessage] = [] - # if streaming is enabled, the completion is a Stream of ChatCompletionChunk - if isinstance(chat_completion, Stream): - num_responses = generation_kwargs.pop("n", 1) - if num_responses > 1: - raise ValueError("Cannot stream multiple responses, please set n=1.") - chunks: List[StreamingChunk] = [] - chunk = None - - # pylint: disable=not-an-iterable - for chunk in chat_completion: - if chunk.choices and streaming_callback: - chunk_delta: StreamingChunk = self._convert_chat_completion_chunk_to_streaming_chunk(chunk) - chunks.append(chunk_delta) - streaming_callback(chunk_delta) # invoke callback with the chunk_delta - completions = [self._convert_streaming_chunks_to_chat_message(chunk, chunks)] - # if streaming is disabled, the completion is a ChatCompletion - elif isinstance(chat_completion, ChatCompletion): + is_streaming = isinstance(chat_completion, AsyncStream) + assert is_streaming or streaming_callback is None + + if is_streaming: + completions = await self._handle_async_stream_response( + chat_completion, # type: ignore + streaming_callback, # type: ignore + ) + else: + assert isinstance( + chat_completion, ChatCompletion + ), "Unexpected response type for non-streaming request." completions = [ self._convert_chat_completion_to_chat_message(chat_completion, choice) for choice in chat_completion.choices @@ -300,11 +405,127 @@ def run( # noqa: PLR0913 # before returning, do post-processing of the completions for message in completions: - self._check_finish_reason(message) + self._check_finish_reason(message.meta) return {"replies": completions} - def _convert_streaming_chunks_to_chat_message(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage: + def _validate_tools(self, tools: Optional[List[Tool]]): + if tools is None: + return + + tool_names = [tool.name for tool in tools] + duplicate_tool_names = { + name for name in tool_names if tool_names.count(name) > 1 + } + if duplicate_tool_names: + raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}") + + def _prepare_api_call( # noqa: PLR0913 + self, + messages: List[ChatMessage], + streaming_callback: Optional[ + Union[StreamingCallbackT, AsyncStreamingCallbackT] + ], + generation_kwargs: Optional[Dict[str, Any]], + tools: Optional[List[Tool]], + tools_strict: Optional[bool], + ) -> Dict[str, Any]: + # update generation kwargs by merging with the generation kwargs passed to the run method + generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + + # adapt ChatMessage(s) to the format expected by the OpenAI API + openai_formatted_messages = [ + _convert_message_to_openai_format(message) for message in messages + ] + + tools = tools or self.tools + tools_strict = tools_strict if tools_strict is not None else self.tools_strict + self._validate_tools(tools) + + openai_tools = None + if tools: + openai_tools = [ + { + "type": "function", + "function": {**t.tool_spec, "strict": tools_strict}, + } + for t in tools + ] + + is_streaming = streaming_callback is not None + 
num_responses = generation_kwargs.pop("n", 1) + if is_streaming and num_responses > 1: + raise ValueError("Cannot stream multiple responses, please set n=1.") + + return { + "model": self.model, + "messages": openai_formatted_messages, # type: ignore[arg-type] # openai expects list of specific message types + "stream": streaming_callback is not None, + "tools": openai_tools, # type: ignore[arg-type] + "n": num_responses, + **generation_kwargs, + } + + def _handle_stream_response( + self, + chat_completion: Stream, + callback: StreamingCallbackT, + ) -> List[ChatMessage]: + chunks: List[StreamingChunk] = [] + chunk = None + + for chunk in chat_completion: # pylint: disable=not-an-iterable + assert ( + len(chunk.choices) == 1 + ), "Streaming responses should have only one choice." + chunk_delta: StreamingChunk = ( + self._convert_chat_completion_chunk_to_streaming_chunk(chunk) + ) + chunks.append(chunk_delta) + + callback(chunk_delta) + + return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)] + + async def _handle_async_stream_response( + self, + chat_completion: AsyncStream, + callback: AsyncStreamingCallbackT, + ) -> List[ChatMessage]: + chunks: List[StreamingChunk] = [] + chunk = None + + async for chunk in chat_completion: # pylint: disable=not-an-iterable + assert ( + len(chunk.choices) == 1 + ), "Streaming responses should have only one choice." + chunk_delta: StreamingChunk = ( + self._convert_chat_completion_chunk_to_streaming_chunk(chunk) + ) + chunks.append(chunk_delta) + + await callback(chunk_delta) + + return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)] + + def _check_finish_reason(self, meta: Dict[str, Any]) -> None: + if meta["finish_reason"] == "length": + logger.warning( + "The completion for index {index} has been truncated before reaching a natural stopping point. " + "Increase the max_tokens parameter to allow for longer completions.", + index=meta["index"], + finish_reason=meta["finish_reason"], + ) + if meta["finish_reason"] == "content_filter": + logger.warning( + "The completion for index {index} has been truncated due to the content filter.", + index=meta["index"], + finish_reason=meta["finish_reason"], + ) + + def _convert_streaming_chunks_to_chat_message( + self, chunk: Any, chunks: List[StreamingChunk] + ) -> ChatMessage: """ Connects the streaming chunks into a single ChatMessage. @@ -334,7 +555,13 @@ def _convert_streaming_chunks_to_chat_message(self, chunk: Any, chunks: List[Str arguments_str = payload["arguments"] try: arguments = json.loads(arguments_str) - tool_calls.append(ToolCall(id=payload["id"], tool_name=payload["name"], arguments=arguments)) + tool_calls.append( + ToolCall( + id=payload["id"], + tool_name=payload["name"], + arguments=arguments, + ) + ) except json.JSONDecodeError: logger.warning( "OpenAI returned a malformed JSON string for tool call arguments. This tool call " @@ -354,7 +581,9 @@ def _convert_streaming_chunks_to_chat_message(self, chunk: Any, chunks: List[Str return ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta) - def _convert_chat_completion_to_chat_message(self, completion: ChatCompletion, choice: Choice) -> ChatMessage: + def _convert_chat_completion_to_chat_message( + self, completion: ChatCompletion, choice: Choice + ) -> ChatMessage: """ Converts the non-streaming response from the OpenAI API to a ChatMessage. 
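To complement the synchronous path, a short sketch of the new async streaming path follows. It assumes the component runs inside an event loop and that the callback is a coroutine, as `run_async` requires; the callback body and prompt are illustrative.

```python
import asyncio

from haystack_experimental.components.generators.chat.openai import OpenAIChatGenerator
from haystack_experimental.dataclasses import ChatMessage

async def print_chunk(chunk) -> None:
    # StreamingChunk.content carries the incremental text of the streamed completion.
    print(chunk.content, end="", flush=True)

async def main() -> None:
    generator = OpenAIChatGenerator(model="gpt-4o-mini")
    result = await generator.run_async(
        messages=[ChatMessage.from_user("Summarize BM25 in one sentence.")],
        streaming_callback=print_chunk,  # must be a coroutine on the async path
    )
    print()
    print(result["replies"][0].text)

asyncio.run(main())
```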
@@ -370,7 +599,13 @@ def _convert_chat_completion_to_chat_message(self, completion: ChatCompletion, c arguments_str = openai_tc.function.arguments try: arguments = json.loads(arguments_str) - tool_calls.append(ToolCall(id=openai_tc.id, tool_name=openai_tc.function.name, arguments=arguments)) + tool_calls.append( + ToolCall( + id=openai_tc.id, + tool_name=openai_tc.function.name, + arguments=arguments, + ) + ) except json.JSONDecodeError: logger.warning( "OpenAI returned a malformed JSON string for tool call arguments. This tool call " @@ -392,7 +627,9 @@ def _convert_chat_completion_to_chat_message(self, completion: ChatCompletion, c ) return chat_message - def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletionChunk) -> StreamingChunk: + def _convert_chat_completion_chunk_to_streaming_chunk( + self, chunk: ChatCompletionChunk + ) -> StreamingChunk: """ Converts the streaming response chunk from the OpenAI API to a StreamingChunk. diff --git a/haystack_experimental/components/retrievers/opensearch/__init__.py b/haystack_experimental/components/retrievers/opensearch/__init__.py new file mode 100644 index 00000000..6a5e135d --- /dev/null +++ b/haystack_experimental/components/retrievers/opensearch/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from .bm25_retriever import OpenSearchBM25Retriever +from .embedding_retriever import OpenSearchEmbeddingRetriever + +__all__ = ["OpenSearchBM25Retriever", "OpenSearchEmbeddingRetriever"] diff --git a/haystack_experimental/components/retrievers/opensearch/bm25_retriever.py b/haystack_experimental/components/retrievers/opensearch/bm25_retriever.py new file mode 100644 index 00000000..5a0c4013 --- /dev/null +++ b/haystack_experimental/components/retrievers/opensearch/bm25_retriever.py @@ -0,0 +1,287 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional, Union + +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_experimental.document_stores.opensearch import OpenSearchDocumentStore + +logger = logging.getLogger(__name__) + + +@component +class OpenSearchBM25Retriever: + """ + OpenSearch BM25 retriever with async support. + """ + + def __init__( + self, + *, + document_store: OpenSearchDocumentStore, + filters: Optional[Dict[str, Any]] = None, + fuzziness: str = "AUTO", + top_k: int = 10, + scale_score: bool = False, + all_terms_must_match: bool = False, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + custom_query: Optional[Dict[str, Any]] = None, + raise_on_failure: bool = True, + ): + """ + Creates the OpenSearchBM25Retriever component. + + :param document_store: An instance of OpenSearchDocumentStore to use with the Retriever. + :param filters: Filters to narrow down the search for documents in the Document Store. + :param fuzziness: Fuzziness parameter for full-text queries to apply approximate string matching. + For more information, see [OpenSearch fuzzy query](https://opensearch.org/docs/latest/query-dsl/term/fuzzy/). + :param top_k: Maximum number of documents to return. + :param scale_score: If `True`, scales the score of retrieved documents to a range between 0 and 1. 
+ This is useful when comparing documents across different indexes. + :param all_terms_must_match: If `True`, all terms in the query string must be present in the + retrieved documents. This is useful when searching for short text where even one term + can make a difference. + :param filter_policy: Policy to determine how filters are applied. Possible options: + - `replace`: Runtime filters replace initialization filters. Use this policy to change the filtering scope + for specific queries. + - `merge`: Runtime filters are merged with initialization filters. + :param custom_query: The query containing a mandatory `$query` and an optional `$filters` placeholder. + + **An example custom_query:** + + ```python + { + "query": { + "bool": { + "should": [{"multi_match": { + "query": "$query", // mandatory query placeholder + "type": "most_fields", + "fields": ["content", "title"]}}], + "filter": "$filters" // optional filter placeholder + } + } + } + ``` + + An example `run()` method for this `custom_query`: + + ```python + retriever.run( + query="Why did the revenue increase?", + filters={ + "operator": "AND", + "conditions": [ + {"field": "meta.years", "operator": "==", "value": "2019"}, + {"field": "meta.quarters", "operator": "in", "value": ["Q1", "Q2"]}, + ], + }, + ) + ``` + :param raise_on_failure: + Whether to raise an exception if the API call fails. Otherwise log a warning and return an empty list. + + :raises ValueError: If `document_store` is not an instance of OpenSearchDocumentStore. + + """ + if not isinstance(document_store, OpenSearchDocumentStore): + msg = "document_store must be an instance of OpenSearchDocumentStore" + raise ValueError(msg) + + self._document_store = document_store + self._filters = filters or {} + self._fuzziness = fuzziness + self._top_k = top_k + self._scale_score = scale_score + self._all_terms_must_match = all_terms_must_match + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + self._custom_query = custom_query + self._raise_on_failure = raise_on_failure + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self._filters, + fuzziness=self._fuzziness, + top_k=self._top_k, + scale_score=self._scale_score, + document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, + custom_query=self._custom_query, + raise_on_failure=self._raise_on_failure, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchBM25Retriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. 
+ """ + data["init_parameters"]["document_store"] = OpenSearchDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + return default_from_dict(cls, data) + + def _prepare_bm25_args( + self, + *, + query: str, + filters: Optional[Dict[str, Any]], + all_terms_must_match: Optional[bool], + top_k: Optional[int], + fuzziness: Optional[str], + scale_score: Optional[bool], + custom_query: Optional[Dict[str, Any]], + ) -> Dict[str, Any]: + filters = apply_filter_policy(self._filter_policy, self._filters, filters) + + if filters is None: + filters = self._filters + if all_terms_must_match is None: + all_terms_must_match = self._all_terms_must_match + if top_k is None: + top_k = self._top_k + if fuzziness is None: + fuzziness = self._fuzziness + if scale_score is None: + scale_score = self._scale_score + if custom_query is None: + custom_query = self._custom_query + + return { + "query": query, + "filters": filters, + "fuzziness": fuzziness, + "top_k": top_k, + "scale_score": scale_score, + "all_terms_must_match": all_terms_must_match, + "custom_query": custom_query, + } + + @component.output_types(documents=List[Document]) + def run( # pylint: disable=too-many-positional-arguments + self, + query: str, + filters: Optional[Dict[str, Any]] = None, + all_terms_must_match: Optional[bool] = None, + top_k: Optional[int] = None, + fuzziness: Optional[str] = None, + scale_score: Optional[bool] = None, + custom_query: Optional[Dict[str, Any]] = None, + ): + """ + Retrieve documents using BM25 retrieval. + + :param query: The query string. + :param filters: Filters applied to the retrieved documents. The way runtime filters are applied depends on + the `filter_policy` specified at Retriever's initialization. + :param all_terms_must_match: If `True`, all terms in the query string must be present in the + retrieved documents. + :param top_k: Maximum number of documents to return. + :param fuzziness: Fuzziness parameter for full-text queries to apply approximate string matching. + For more information, see [OpenSearch fuzzy query](https://opensearch.org/docs/latest/query-dsl/term/fuzzy/). + :param scale_score: If `True`, scales the score of retrieved documents to a range between 0 and 1. + This is useful when comparing documents across different indexes. + :param custom_query: A custom OpenSearch query. It must include a `$query` and may optionally + include a `$filters` placeholder. + + :returns: + A dictionary containing the retrieved documents with the following structure: + - documents: List of retrieved Documents. 
+ + """ + docs: List[Document] = [] + bm25_args = self._prepare_bm25_args( + query=query, + filters=filters, + all_terms_must_match=all_terms_must_match, + top_k=top_k, + fuzziness=fuzziness, + scale_score=scale_score, + custom_query=custom_query, + ) + try: + docs = self._document_store._bm25_retrieval(**bm25_args) + except Exception as e: + if self._raise_on_failure: + raise e + logger.warning( + "An error during BM25 retrieval occurred and will be ignored by returning empty results: {error}", + error=str(e), + exc_info=True, + ) + + return {"documents": docs} + + @component.output_types(documents=List[Document]) + async def run_async( # pylint: disable=too-many-positional-arguments + self, + query: str, + filters: Optional[Dict[str, Any]] = None, + all_terms_must_match: Optional[bool] = None, + top_k: Optional[int] = None, + fuzziness: Optional[str] = None, + scale_score: Optional[bool] = None, + custom_query: Optional[Dict[str, Any]] = None, + ): + """ + Retrieve documents using BM25 retrieval. + + :param query: The query string. + :param filters: Filters applied to the retrieved documents. The way runtime filters are applied depends on + the `filter_policy` specified at Retriever's initialization. + :param all_terms_must_match: If `True`, all terms in the query string must be present in the + retrieved documents. + :param top_k: Maximum number of documents to return. + :param fuzziness: Fuzziness parameter for full-text queries to apply approximate string matching. + For more information, see [OpenSearch fuzzy query](https://opensearch.org/docs/latest/query-dsl/term/fuzzy/). + :param scale_score: If `True`, scales the score of retrieved documents to a range between 0 and 1. + This is useful when comparing documents across different indexes. + :param custom_query: A custom OpenSearch query. It must include a `$query` and may optionally + include a `$filters` placeholder. + + :returns: + A dictionary containing the retrieved documents with the following structure: + - documents: List of retrieved Documents. 
+ + """ + docs: List[Document] = [] + bm25_args = self._prepare_bm25_args( + query=query, + filters=filters, + all_terms_must_match=all_terms_must_match, + top_k=top_k, + fuzziness=fuzziness, + scale_score=scale_score, + custom_query=custom_query, + ) + try: + docs = await self._document_store._bm25_retrieval_async(**bm25_args) + except Exception as e: + if self._raise_on_failure: + raise e + logger.warning( + "An error during BM25 retrieval occurred and will be ignored by returning empty results: {error}", + error=str(e), + exc_info=True, + ) + + return {"documents": docs} diff --git a/haystack_experimental/components/retrievers/opensearch/embedding_retriever.py b/haystack_experimental/components/retrievers/opensearch/embedding_retriever.py new file mode 100644 index 00000000..3f26c329 --- /dev/null +++ b/haystack_experimental/components/retrievers/opensearch/embedding_retriever.py @@ -0,0 +1,202 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional, Union + +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_experimental.document_stores.opensearch import OpenSearchDocumentStore + +logger = logging.getLogger(__name__) + + +@component +class OpenSearchEmbeddingRetriever: + """ + OpenSearch embedding retriever with async support. + """ + + def __init__( + self, + *, + document_store: OpenSearchDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + custom_query: Optional[Dict[str, Any]] = None, + raise_on_failure: bool = True, + ): + """ + Create the OpenSearchEmbeddingRetriever component. + + :param document_store: An instance of OpenSearchDocumentStore to use with the Retriever. + :param filters: Filters applied when fetching documents from the Document Store. + Filters are applied during the approximate kNN search to ensure the Retriever returns + `top_k` matching documents. + :param top_k: Maximum number of documents to return. + :param filter_policy: Policy to determine how filters are applied. Possible options: + - `merge`: Runtime filters are merged with initialization filters. + - `replace`: Runtime filters replace initialization filters. Use this policy to change the filtering scope. + :param custom_query: The custom OpenSearch query containing a mandatory `$query_embedding` and + an optional `$filters` placeholder. + :param raise_on_failure: + If `True`, raises an exception if the API call fails. + If `False`, logs a warning and returns an empty list. + + :raises ValueError: If `document_store` is not an instance of OpenSearchDocumentStore. + """ + if not isinstance(document_store, OpenSearchDocumentStore): + msg = "document_store must be an instance of OpenSearchDocumentStore" + raise ValueError(msg) + + self._document_store = document_store + self._filters = filters or {} + self._top_k = top_k + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + self._custom_query = custom_query + self._raise_on_failure = raise_on_failure + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. 
+ """ + return default_to_dict( + self, + filters=self._filters, + top_k=self._top_k, + document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, + custom_query=self._custom_query, + raise_on_failure=self._raise_on_failure, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchEmbeddingRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. + """ + data["init_parameters"]["document_store"] = OpenSearchDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + custom_query: Optional[Dict[str, Any]] = None, + ): + """ + Retrieve documents using a vector similarity metric. + + :param query_embedding: Embedding of the query. + :param filters: Filters applied when fetching documents from the Document Store. + Filters are applied during the approximate kNN search to ensure the Retriever + returns `top_k` matching documents. + The way runtime filters are applied depends on the `filter_policy` selected when initializing the Retriever. + :param top_k: Maximum number of documents to return. + :param custom_query: A custom OpenSearch query containing a mandatory `$query_embedding` and an + optional `$filters` placeholder. + :returns: + Dictionary with key "documents" containing the retrieved Documents. + - documents: List of Document similar to `query_embedding`. + """ + filters = apply_filter_policy(self._filter_policy, self._filters, filters) + top_k = top_k or self._top_k + if filters is None: + filters = self._filters + if top_k is None: + top_k = self._top_k + if custom_query is None: + custom_query = self._custom_query + + docs: List[Document] = [] + + try: + docs = self._document_store._embedding_retrieval( + query_embedding=query_embedding, + filters=filters, + top_k=top_k, + custom_query=custom_query, + ) + except Exception as e: + if self._raise_on_failure: + raise e + logger.warning( + "An error during embedding retrieval occurred and will be " + "ignored by returning empty results: {error}", + error=str(e), + exc_info=True, + ) + + return {"documents": docs} + + @component.output_types(documents=List[Document]) + async def run_async( + self, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + custom_query: Optional[Dict[str, Any]] = None, + ): + """ + Retrieve documents using a vector similarity metric. + + :param query_embedding: Embedding of the query. + :param filters: Filters applied when fetching documents from the Document Store. + Filters are applied during the approximate kNN search to ensure the Retriever + returns `top_k` matching documents. + The way runtime filters are applied depends on the `filter_policy` selected when initializing the Retriever. + :param top_k: Maximum number of documents to return. + :param custom_query: A custom OpenSearch query containing a mandatory `$query_embedding` and an + optional `$filters` placeholder. + :returns: + Dictionary with key "documents" containing the retrieved Documents. + - documents: List of Document similar to `query_embedding`. 
+ """ + filters = apply_filter_policy(self._filter_policy, self._filters, filters) + top_k = top_k or self._top_k + if filters is None: + filters = self._filters + if top_k is None: + top_k = self._top_k + if custom_query is None: + custom_query = self._custom_query + + docs: List[Document] = [] + + try: + docs = await self._document_store._embedding_retrieval_async( + query_embedding=query_embedding, + filters=filters, + top_k=top_k, + custom_query=custom_query, + ) + except Exception as e: + if self._raise_on_failure: + raise e + logger.warning( + "An error during embedding retrieval occurred and will be " + "ignored by returning empty results: {error}", + error=str(e), + exc_info=True, + ) + + return {"documents": docs} diff --git a/haystack_experimental/components/writers/__init__.py b/haystack_experimental/components/writers/__init__.py index 187d1b91..8602ec74 100644 --- a/haystack_experimental/components/writers/__init__.py +++ b/haystack_experimental/components/writers/__init__.py @@ -2,6 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from haystack_experimental.components.writers.chat_message_writer import ChatMessageWriter +from haystack_experimental.components.writers.chat_message_writer import ( + ChatMessageWriter, +) +from haystack_experimental.components.writers.document_writer import DocumentWriter -_all_ = ["ChatMessageWriter"] +_all_ = ["ChatMessageWriter", "DocumentWriter"] diff --git a/haystack_experimental/components/writers/document_writer.py b/haystack_experimental/components/writers/document_writer.py new file mode 100644 index 00000000..9dcc1fe3 --- /dev/null +++ b/haystack_experimental/components/writers/document_writer.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional + +from haystack import Document, component, logging +from haystack.components.writers import DocumentWriter as DocumentWriterBase +from haystack.document_stores.types import DuplicatePolicy + +from haystack_experimental.document_stores.types import DocumentStore + +logger = logging.getLogger(__name__) + + +@component +class DocumentWriter(DocumentWriterBase): + """ + Writes documents to a DocumentStore. + + ### Usage example + ```python + from haystack import Document + from haystack.components.writers import DocumentWriter + from haystack.document_stores.in_memory import InMemoryDocumentStore + + docs = [ + Document(content="Python is a popular programming language"), + ] + doc_store = InMemoryDocumentStore() + doc_store.write_documents(docs) + ``` + """ + + def __init__( + self, + document_store: DocumentStore, + policy: DuplicatePolicy = DuplicatePolicy.NONE, + ): + """ + Create a DocumentWriter component. + + :param document_store: + The instance of the document store where you want to store your documents. + :param policy: + The policy to apply when a Document with the same ID already exists in the DocumentStore. + - `DuplicatePolicy.NONE`: Default policy, relies on the DocumentStore settings. + - `DuplicatePolicy.SKIP`: Skips documents with the same ID and doesn't write them to the DocumentStore. + - `DuplicatePolicy.OVERWRITE`: Overwrites documents with the same ID. + - `DuplicatePolicy.FAIL`: Raises an error if a Document with the same ID is already in the DocumentStore. 
+ """ + super(DocumentWriter, self).__init__(document_store=document_store, policy=policy) + + @component.output_types(documents_written=int) + async def run_async(self, documents: List[Document], policy: Optional[DuplicatePolicy] = None): + """ + Run the DocumentWriter on the given input data. + + :param documents: + A list of documents to write to the document store. + :param policy: + The policy to use when encountering duplicate documents. + :returns: + Number of documents written to the document store. + + :raises ValueError: + If the specified document store is not found. + """ + if policy is None: + policy = self.policy + + if not hasattr(self.document_store, "write_documents_async"): + raise TypeError(f"Document store {type(self.document_store).__name__} does not provide async support.") + + documents_written = await self.document_store.write_documents_async( # type: ignore + documents=documents, policy=policy + ) + return {"documents_written": documents_written} diff --git a/haystack_experimental/core/__init__.py b/haystack_experimental/core/__init__.py new file mode 100644 index 00000000..efafd8c9 --- /dev/null +++ b/haystack_experimental/core/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from .pipeline import AsyncPipeline, run_async_pipeline + +_all_ = ["AsyncPipeline", "run_async_pipeline"] diff --git a/haystack_experimental/core/pipeline/__init__.py b/haystack_experimental/core/pipeline/__init__.py new file mode 100644 index 00000000..c755a510 --- /dev/null +++ b/haystack_experimental/core/pipeline/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from .async_pipeline import AsyncPipeline, run_async_pipeline + +__all__ = ["AsyncPipeline", "run_async_pipeline"] diff --git a/haystack_experimental/core/pipeline/async_pipeline.py b/haystack_experimental/core/pipeline/async_pipeline.py new file mode 100644 index 00000000..2270181d --- /dev/null +++ b/haystack_experimental/core/pipeline/async_pipeline.py @@ -0,0 +1,584 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +from concurrent.futures import ThreadPoolExecutor +from copy import deepcopy +from typing import Any, AsyncIterator, Dict, List, Mapping, Optional, Set, Tuple +from warnings import warn + +import networkx as nx +from haystack import logging, tracing +from haystack.core.component import Component +from haystack.core.errors import PipelineMaxComponentRuns, PipelineRuntimeError +from haystack.core.pipeline.base import ( + PipelineBase, + _add_missing_input_defaults, + _dequeue_component, + _dequeue_waiting_component, + _enqueue_component, + _enqueue_waiting_component, + _is_lazy_variadic, +) +from haystack.telemetry import pipeline_running + +logger = logging.getLogger(__name__) + + +class AsyncPipeline(PipelineBase): + """ + Asynchronous version of the orchestration engine. + + The primary difference between this and the synchronous version is that this version + will attempt to execute a component's async `run_async` method if it exists. If it doesn't, + the synchronous `run` method is executed as an awaitable task on a thread pool executor. This + version also eagerly yields the output of each component as soon as it is available. 
+ """ + + def __init__( + self, + metadata: Optional[Dict[str, Any]] = None, + max_runs_per_component: int = 100, + async_executor: Optional[ThreadPoolExecutor] = None, + ): + """ + Creates the asynchronous Pipeline. + + :param metadata: + Arbitrary dictionary to store metadata about this `Pipeline`. Make sure all the values contained in + this dictionary can be serialized and deserialized if you wish to save this `Pipeline` to file. + :param max_runs_per_component: + How many times the `Pipeline` can run the same Component. + If this limit is reached a `PipelineMaxComponentRuns` exception is raised. + If not set defaults to 100 runs per Component. + :param async_executor: + Optional ThreadPoolExecutor to use for running synchronous components. If not provided, a single-threaded + executor will initialized and used. + """ + super().__init__(metadata, max_runs_per_component) + + # We only need one thread as we'll immediately block after launching it. + self.executor = ( + ThreadPoolExecutor(thread_name_prefix=f"async-pipeline-executor-{id(self)}", max_workers=1) + if async_executor is None + else async_executor + ) + + async def _run_component( + self, + name: str, + inputs: Dict[str, Any], + parent_span: Optional[tracing.Span] = None, + ) -> Dict[str, Any]: + """ + Runs a Component with the given inputs. + + :param name: Name of the Component as defined in the Pipeline. + :param inputs: Inputs for the Component. + :param parent_span: The parent span to use for the newly created span. + This is to allow tracing to be correctly linked to the pipeline run. + :raises PipelineRuntimeError: If Component doesn't return a dictionary. + :return: The output of the Component. + """ + instance: Component = self.graph.nodes[name]["instance"] + + with tracing.tracer.trace( + "haystack.component.run", + tags={ + "haystack.component.name": name, + "haystack.component.type": instance.__class__.__name__, + "haystack.component.input_types": {k: type(v).__name__ for k, v in inputs.items()}, + "haystack.component.input_spec": { + key: { + "type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)), + "senders": value.senders, + } + for key, value in instance.__haystack_input__._sockets_dict.items() # type: ignore + }, + "haystack.component.output_spec": { + key: { + "type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)), + "receivers": value.receivers, + } + for key, value in instance.__haystack_output__._sockets_dict.items() # type: ignore + }, + }, + parent_span=parent_span, + ) as span: + # We deepcopy the inputs otherwise we might lose that information + # when we delete them in case they're sent to other Components + span.set_content_tag("haystack.component.input", deepcopy(inputs)) + logger.info("Running component {component_name}", component_name=name) + + res: Dict[str, Any] + if instance.__haystack_supports_async__: # type: ignore + logger.info("Running async component {component_name}", component_name=name) + res = await instance.run_async(**inputs) # type: ignore + else: + logger.info( + "Running sync component {component_name} on executor", + component_name=name, + ) + res = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: instance.run(**inputs)) + self.graph.nodes[name]["visits"] += 1 + + # After a Component that has variadic inputs is run, we need to reset the variadic inputs that were consumed + for socket in instance.__haystack_input__._sockets_dict.values(): # type: ignore + if socket.name not in inputs: + continue + if 
socket.is_variadic: + inputs[socket.name] = [] + + if not isinstance(res, Mapping): + raise PipelineRuntimeError( + f"Component '{name}' didn't return a dictionary. " + "Components must always return dictionaries: check the documentation." + ) + span.set_tag("haystack.component.visits", self.graph.nodes[name]["visits"]) + span.set_content_tag("haystack.component.output", res) + + return res + + async def _run_subgraph( # noqa: PLR0915, PLR0912 # pylint: disable=too-many-locals, too-many-branches, too-many-statements + self, + cycle: List[str], + component_name: str, + components_inputs: Dict[str, Dict[str, Any]], + ) -> AsyncIterator[Tuple[Dict[str, Any], bool]]: + """ + Runs a `cycle` in the Pipeline starting from `component_name`. + + This will return once there are no inputs for the Components in `cycle`. + + This is an internal method meant to be used in `Pipeline.run()` only. + + :param cycle: + List of Components that are part of the cycle being run + :param component_name: + Name of the Component that will start execution of the cycle + :param components_inputs: + Components inputs, this might include inputs for Components that are not part + of the cycle but part of the wider Pipeline's graph + :yields: + Yields the individual output of each component after its execution and the final + outputs of all the Components that are not connected to other Components in `cycle`. + :raises PipelineMaxComponentRuns: + If a Component reaches the maximum number of times it can be run in this Pipeline + """ + waiting_queue: List[Tuple[str, Component]] = [] + run_queue: List[Tuple[str, Component]] = [] + + # Create the run queue starting with the component that needs to run first + start_index = cycle.index(component_name) + for node in cycle[start_index:]: + run_queue.append((node, self.graph.nodes[node]["instance"])) + + before_last_waiting_queue: Optional[Set[str]] = None + last_waiting_queue: Optional[Set[str]] = None + + subgraph_outputs = {} + + # This variable is used to keep track if we still need to run the cycle or not. + # When a Component doesn't send outputs to another Component + # that's inside the subgraph, we stop running this subgraph. 
+ cycle_received_inputs = False + + while not cycle_received_inputs: + # Here we run the Components + name, comp = run_queue.pop(0) + if _is_lazy_variadic(comp) and not all(_is_lazy_variadic(comp) for _, comp in run_queue): + # We run Components with lazy variadic inputs only if there only Components with + # lazy variadic inputs left to run + _enqueue_waiting_component((name, comp), waiting_queue) + continue + + # As soon as a Component returns only output that is not part of the cycle, we can stop + if self._component_has_enough_inputs_to_run(name, components_inputs): + if self.graph.nodes[name]["visits"] > self._max_runs_per_component: + msg = f"Maximum run count {self._max_runs_per_component} reached for component '{name}'" + raise PipelineMaxComponentRuns(msg) + + res: Dict[str, Any] = await self._run_component(name, components_inputs[name]) + yield {name: deepcopy(res)}, False + + # Delete the inputs that were consumed by the Component and are not received from + # the user or from Components that are part of this cycle + sockets = list(components_inputs[name].keys()) + for socket_name in sockets: + senders = comp.__haystack_input__._sockets_dict[socket_name].senders # type: ignore + if not senders: + # We keep inputs that came from the user + continue + all_senders_in_cycle = all(sender in cycle for sender in senders) + if all_senders_in_cycle: + # All senders are in the cycle, we can remove the input. + # We'll receive it later at a certain point. + del components_inputs[name][socket_name] + + # Reset the waiting for input previous states, we managed to run a component + before_last_waiting_queue = None + last_waiting_queue = None + + # Check if a component doesn't send any output to components that are part of the cycle + final_output_reached = False + for output_socket in res.keys(): + for receiver in comp.__haystack_output__._sockets_dict[output_socket].receivers: # type: ignore + if receiver in cycle: + final_output_reached = True + break + if final_output_reached: + break + + if not final_output_reached: + # We stop only if the Component we just ran doesn't send any output to sockets that + # are part of the cycle + cycle_received_inputs = True + + # We manage to run this component that was in the waiting list, we can remove it. + # This happens when a component was put in the waiting list but we reached it from another edge. + _dequeue_waiting_component((name, comp), waiting_queue) + for pair in self._find_components_that_will_receive_no_input(name, res, components_inputs): + _dequeue_component(pair, run_queue, waiting_queue) + + receivers = [item for item in self._find_receivers_from(name) if item[0] in cycle] + + res = self._distribute_output(receivers, res, components_inputs, run_queue, waiting_queue) + + # We treat a cycle as a completely independent graph, so we keep track of output + # that is not sent inside the cycle. + # This output is going to get distributed to the wider graph after we finish running + # a cycle. + # All values that are left at this point go outside the cycle. + if len(res) > 0: + subgraph_outputs[name] = res + else: + # This component doesn't have enough inputs so we can't run it yet + _enqueue_waiting_component((name, comp), waiting_queue) + + if len(run_queue) == 0 and len(waiting_queue) > 0: + # Check if we're stuck in a loop. + # It's important to check whether previous waitings are None as it could be that no + # Component has actually been run yet. 
+ if ( + before_last_waiting_queue is not None + and last_waiting_queue is not None + and before_last_waiting_queue == last_waiting_queue + ): + if self._is_stuck_in_a_loop(waiting_queue): + # We're stuck! We can't make any progress. + msg = ( + "Pipeline is stuck running in a loop. Partial outputs will be returned. " + "Check the Pipeline graph for possible issues." + ) + warn(RuntimeWarning(msg)) + break + + (name, comp) = self._find_next_runnable_lazy_variadic_or_default_component(waiting_queue) + _add_missing_input_defaults(name, comp, components_inputs) + _enqueue_component((name, comp), run_queue, waiting_queue) + continue + + before_last_waiting_queue = last_waiting_queue.copy() if last_waiting_queue is not None else None + last_waiting_queue = {item[0] for item in waiting_queue} + + (name, comp) = self._find_next_runnable_component(components_inputs, waiting_queue) + _add_missing_input_defaults(name, comp, components_inputs) + _enqueue_component((name, comp), run_queue, waiting_queue) + + yield subgraph_outputs, True + + async def run( # noqa: PLR0915, PLR0912 # pylint: disable=too-many-locals, too-many-branches, too-many-statements + self, + data: Dict[str, Any], + ) -> AsyncIterator[Dict[str, Any]]: + """ + Runs the pipeline with given input data. + + Since the return value of this function is an asynchronous generator, the + execution will only progress if the generator is consumed. + + :param data: + A dictionary of inputs for the pipeline's components. Each key is a component name + and its value is a dictionary of that component's input parameters: + ``` + data = { + "comp1": {"input1": 1, "input2": 2}, + } + ``` + For convenience, this format is also supported when input names are unique: + ``` + data = { + "input1": 1, + "input2": 2, + } + ``` + + :yields: + A dictionary where each entry corresponds to a component name and its + output. Outputs of each component are yielded as soon as they are available, + and the final output is yielded as the final dictionary. + + :raises PipelineRuntimeError: + If a component fails or returns unexpected output. + + Example - Using named components: + Consider a 'Hello' component that takes a 'word' input and outputs a greeting.
+ + ```python + @component + class Hello: + @component.output_types(output=str) + def run(self, word: str): + return {"output": f"Hello, {word}!"} + ``` + + Create a pipeline with two 'Hello' components connected together: + + ```python + pipeline = Pipeline() + pipeline.add_component("hello", Hello()) + pipeline.add_component("hello2", Hello()) + pipeline.connect("hello.output", "hello2.word") + + async for result in pipeline.run(data={"hello": {"word": "world"}}): + print(result) + ``` + + This will yield the results in the following order: + {"hello": {"output": "Hello, world!"}} # Output of the first component + {"hello2": {"output": "Hello, Hello, world!"}} # Output of the second component + {"hello2": {"output": "Hello, Hello, world!"}} # Final output of the pipeline + """ + + pipeline_running(self) + + # Reset the visits count for each component + self._init_graph() + + # TODO: Remove this warmup once we can check reliably whether a component has been warmed up or not + # As of now it's here to make sure we don't have failing tests that assume warm_up() is called in run() + self.warm_up() + + # normalize `data` + data = self._prepare_component_input_data(data) + + # Raise if input is malformed in some way + self._validate_input(data) + + # Normalize the input data + components_inputs: Dict[str, Dict[str, Any]] = self._normalize_varidiac_input_data(data) + + # These variables are used to detect when we're stuck in a loop. + # Stuck loops can happen when one or more components are waiting for input but + # no other component is going to run. + # This can happen when a whole branch of the graph is skipped for example. + # When we find two consecutive iterations of the loop where the waiting_queue is the same, + # we know we're stuck in a loop and we can't make any progress. + # + # They track the previous two states of the waiting_queue. So if the waiting_queue is at iteration n, + # before_last_waiting_queue would be n-2 and last_waiting_queue would be n-1. + # When we run a component, we reset both. + before_last_waiting_queue: Optional[Set[str]] = None + last_waiting_queue: Optional[Set[str]] = None + + # The waiting_for_input list is used to keep track of components that are waiting for input. + waiting_queue: List[Tuple[str, Component]] = [] + + # This is what we'll return at the end + final_outputs: Dict[Any, Any] = {} + + # Break cycles if there are any; this is a no-op if no cycle is found. + # This will raise if a cycle can't be broken. + graph_without_cycles, components_in_cycles = self._break_supported_cycles_in_graph() + + run_queue: List[Tuple[str, Component]] = [] + for node in nx.topological_sort(graph_without_cycles): + run_queue.append((node, self.graph.nodes[node]["instance"])) + + # Set default inputs for those sockets that receive input neither from the user + # nor from other Components. + # If they have no default nothing is done. + # This is important to ensure correct execution order, otherwise some variadic + # Components that receive input from the user might be run before they should.
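Editor's note: because `run()` above is an async generator, callers must drive it with `async for`. A minimal sketch of consuming it from a plain script follows; the `AsyncPipeline` import path is an assumption for illustration and may differ in the actual package layout.

```python
import asyncio
from typing import Any, Dict

# Hypothetical import path; adjust to wherever AsyncPipeline is exposed.
from haystack_experimental.core import AsyncPipeline


async def consume(pipeline: AsyncPipeline, data: Dict[str, Any]) -> Dict[str, Any]:
    final_output: Dict[str, Any] = {}
    async for result in pipeline.run(data=data):
        # Every yield except the last holds a single component's output;
        # the last yield is the dictionary of final (leaf) outputs.
        print(result)
        final_output = result
    return final_output


# asyncio.run(consume(pipeline, {"hello": {"word": "world"}}))
```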
+ for name, comp in self.graph.nodes(data="instance"): + if name not in components_inputs: + components_inputs[name] = {} + for socket_name, socket in comp.__haystack_input__._sockets_dict.items(): + if socket_name in components_inputs[name]: + continue + if not socket.senders: + value = socket.default_value + if socket.is_variadic: + value = [value] + components_inputs[name][socket_name] = value + + with tracing.tracer.trace( + "haystack.pipeline.run", + tags={ + "haystack.pipeline.input_data": data, + "haystack.pipeline.output_data": final_outputs, + "haystack.pipeline.metadata": self.metadata, + "haystack.pipeline.max_runs_per_component": self._max_runs_per_component, + }, + ) as span: + while len(run_queue) > 0: + name, comp = run_queue.pop(0) + + if _is_lazy_variadic(comp) and not all(_is_lazy_variadic(comp) for _, comp in run_queue): + # We run Components with lazy variadic inputs only if there only Components with + # lazy variadic inputs left to run + _enqueue_waiting_component((name, comp), waiting_queue) + continue + if self._component_has_enough_inputs_to_run(name, components_inputs) and components_in_cycles.get( + name, [] + ): + cycles = components_in_cycles.get(name, []) + + # This component is part of one or more cycles, let's get the first one and run it. + # We can reliably pick any of the cycles if there are multiple ones, the way cycles + # are run doesn't make a different whether we pick the first or any of the others a + # Component is part of. + async for subgraph_output, is_final_output in self._run_subgraph( + cycles[0], name, components_inputs + ): + if not is_final_output: + yield subgraph_output + assert is_final_output is True + + # After a cycle is run the previous run_queue can't be correct anymore cause it's + # not modified when running the subgraph. + # So we reset it given the output returned by the subgraph. + run_queue = [] + + # Reset the waiting for input previous states, we managed to run at least one component + before_last_waiting_queue = None + last_waiting_queue = None + + for component_name, component_output in subgraph_output.items(): + receivers = self._find_receivers_from(component_name) + component_output = self._distribute_output( + receivers, + component_output, + components_inputs, + run_queue, + waiting_queue, + ) + + if len(component_output) > 0: + final_outputs[component_name] = component_output + + elif self._component_has_enough_inputs_to_run(name, components_inputs): + if self.graph.nodes[name]["visits"] > self._max_runs_per_component: + msg = f"Maximum run count {self._max_runs_per_component} reached for component '{name}'" + raise PipelineMaxComponentRuns(msg) + + res: Dict[str, Any] = await self._run_component(name, components_inputs[name], parent_span=span) + yield {name: deepcopy(res)} + + # Delete the inputs that were consumed by the Component and are not received from the user + sockets = list(components_inputs[name].keys()) + for socket_name in sockets: + senders = comp.__haystack_input__._sockets_dict[socket_name].senders + if senders: + # Delete all inputs that are received from other Components + del components_inputs[name][socket_name] + # We keep inputs that came from the user + + # Reset the waiting for input previous states, we managed to run a component + before_last_waiting_queue = None + last_waiting_queue = None + + # We manage to run this component that was in the waiting list, we can remove it. + # This happens when a component was put in the waiting list but we reached it from another edge. 
+ _dequeue_waiting_component((name, comp), waiting_queue) + + for pair in self._find_components_that_will_receive_no_input(name, res, components_inputs): + _dequeue_component(pair, run_queue, waiting_queue) + receivers = self._find_receivers_from(name) + res = self._distribute_output(receivers, res, components_inputs, run_queue, waiting_queue) + + if len(res) > 0: + final_outputs[name] = res + else: + # This component doesn't have enough inputs so we can't run it yet + _enqueue_waiting_component((name, comp), waiting_queue) + + if len(run_queue) == 0 and len(waiting_queue) > 0: + # Check if we're stuck in a loop. + # It's important to check whether previous waitings are None as it could be that no + # Component has actually been run yet. + if ( + before_last_waiting_queue is not None + and last_waiting_queue is not None + and before_last_waiting_queue == last_waiting_queue + ): + if self._is_stuck_in_a_loop(waiting_queue): + # We're stuck! We can't make any progress. + msg = ( + "Pipeline is stuck running in a loop. Partial outputs will be returned. " + "Check the Pipeline graph for possible issues." + ) + warn(RuntimeWarning(msg)) + break + + (name, comp) = self._find_next_runnable_lazy_variadic_or_default_component(waiting_queue) + _add_missing_input_defaults(name, comp, components_inputs) + _enqueue_component((name, comp), run_queue, waiting_queue) + continue + + before_last_waiting_queue = last_waiting_queue.copy() if last_waiting_queue is not None else None + last_waiting_queue = {item[0] for item in waiting_queue} + + (name, comp) = self._find_next_runnable_component(components_inputs, waiting_queue) + _add_missing_input_defaults(name, comp, components_inputs) + _enqueue_component((name, comp), run_queue, waiting_queue) + + yield final_outputs + + +async def run_async_pipeline( + pipeline: AsyncPipeline, + data: Dict[str, Any], + include_outputs_from: Optional[Set[str]] = None, +) -> Dict[str, Any]: + """ + Helper function to run an asynchronous pipeline and return the final outputs and specific intermediate outputs. + + For fine-grained control over the intermediate outputs, use the `AsyncPipeline.run()` + method directly and consume the generator. + + :param pipeline: + Asynchronous pipeline to run. + :param data: + Input data for the pipeline. + :param include_outputs_from: + Set of component names whose individual outputs are to be + included in the pipeline's output. For components that are + invoked multiple times (in a loop), only the last-produced + output is included. + :returns: + A dictionary where each entry corresponds to a component name + and its output. If `include_outputs_from` is `None`, this dictionary + will only contain the outputs of leaf components, i.e., components + without outgoing connections. + """ + outputs = [x async for x in pipeline.run(data)] + + intermediate_outputs = { + k: v for d in outputs[:-1] for k, v in d.items() if include_outputs_from is None or k in include_outputs_from + } + final_output = outputs[-1] + + # Same logic used in the sync pipeline to accumulate extra outputs. 
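Editor's note: a short usage sketch of the `run_async_pipeline` helper documented above, for callers that only want the merged result. The pipeline instance and component names are placeholders, and the helper is assumed to be imported from this module.

```python
import asyncio

# Placeholder pipeline built elsewhere with components "retriever" and "generator".
result = asyncio.run(
    run_async_pipeline(
        pipeline,
        data={"retriever": {"query": "What is Haystack?"}},
        include_outputs_from={"retriever"},
    )
)
# `result` holds the leaf components' outputs plus the retriever's last output.
```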
+ for name, output in intermediate_outputs.items(): + inner = final_output.get(name) + if inner is None: + final_output[name] = output + else: + for k, v in output.items(): + if k not in inner: + inner[k] = v + + return final_output diff --git a/haystack_experimental/dataclasses/__init__.py b/haystack_experimental/dataclasses/__init__.py index e086fcb8..78a97618 100644 --- a/haystack_experimental/dataclasses/__init__.py +++ b/haystack_experimental/dataclasses/__init__.py @@ -10,6 +10,20 @@ ToolCall, ToolCallResult, ) +from haystack_experimental.dataclasses.streaming_chunk import ( + AsyncStreamingCallbackT, + StreamingCallbackT, +) from haystack_experimental.dataclasses.tool import Tool -__all__ = ["ChatMessage", "ChatRole", "ToolCall", "ToolCallResult", "TextContent", "ChatMessageContentT", "Tool"] +__all__ = [ + "AsyncStreamingCallbackT", + "ChatMessage", + "ChatRole", + "StreamingCallbackT", + "ToolCall", + "ToolCallResult", + "TextContent", + "ChatMessageContentT", + "Tool", +] diff --git a/haystack_experimental/dataclasses/streaming_chunk.py b/haystack_experimental/dataclasses/streaming_chunk.py new file mode 100644 index 00000000..b9b193e6 --- /dev/null +++ b/haystack_experimental/dataclasses/streaming_chunk.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Awaitable, Callable, Optional, Union + +from haystack.dataclasses import StreamingChunk + +from haystack_experimental.util import is_callable_async_compatible + +StreamingCallbackT = Callable[[StreamingChunk], None] +AsyncStreamingCallbackT = Callable[[StreamingChunk], Awaitable[None]] + + +def select_streaming_callback( + init_callback: Optional[Union[StreamingCallbackT, AsyncStreamingCallbackT]], + runtime_callback: Optional[Union[StreamingCallbackT, AsyncStreamingCallbackT]], + requires_async: bool, +) -> Optional[Union[StreamingCallbackT, AsyncStreamingCallbackT]]: + """ + Picks the correct streaming callback given an optional initial and runtime callback. + + The runtime callback takes precedence over the initial callback. + + :param init_callback: + The initial callback. + :param runtime_callback: + The runtime callback. + :param requires_async: + Whether the selected callback must be async compatible. + :returns: + The selected callback. 
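Editor's note: a small sketch of the selection logic described in the docstring above. The runtime callback takes precedence over the init callback, and sync/async compatibility is validated against `requires_async`.

```python
from haystack.dataclasses import StreamingChunk

from haystack_experimental.dataclasses.streaming_chunk import select_streaming_callback


def init_cb(chunk: StreamingChunk) -> None:
    print("init:", chunk.content)


def runtime_cb(chunk: StreamingChunk) -> None:
    print("runtime:", chunk.content)


async def async_cb(chunk: StreamingChunk) -> None:
    print("async:", chunk.content)


# The runtime callback wins over the one given at init time.
assert select_streaming_callback(init_cb, runtime_cb, requires_async=False) is runtime_cb
# With no runtime override, the init callback is returned.
assert select_streaming_callback(init_cb, None, requires_async=False) is init_cb
# When the component runs asynchronously, only coroutine callbacks are accepted.
assert select_streaming_callback(None, async_cb, requires_async=True) is async_cb
```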
+ """ + if init_callback is not None: + if requires_async and not is_callable_async_compatible(init_callback): + raise ValueError("The init callback must be async compatible.") + if not requires_async and is_callable_async_compatible(init_callback): + raise ValueError("The init callback cannot be a coroutine.") + + if runtime_callback is not None: + if requires_async and not is_callable_async_compatible(runtime_callback): + raise ValueError("The runtime callback must be async compatible.") + if not requires_async and is_callable_async_compatible(runtime_callback): + raise ValueError("The runtime callback cannot be a coroutine.") + + return runtime_callback or init_callback diff --git a/haystack_experimental/document_stores/__init__.py b/haystack_experimental/document_stores/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/haystack_experimental/document_stores/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/haystack_experimental/document_stores/in_memory/__init__.py b/haystack_experimental/document_stores/in_memory/__init__.py new file mode 100644 index 00000000..2bc0776b --- /dev/null +++ b/haystack_experimental/document_stores/in_memory/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from .document_store import InMemoryDocumentStore + +__all__ = ["InMemoryDocumentStore"] diff --git a/haystack_experimental/document_stores/in_memory/document_store.py b/haystack_experimental/document_stores/in_memory/document_store.py new file mode 100644 index 00000000..35128b56 --- /dev/null +++ b/haystack_experimental/document_stores/in_memory/document_store.py @@ -0,0 +1,156 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, List, Literal, Optional + +from haystack.dataclasses import Document +from haystack.document_stores.in_memory.document_store import ( + InMemoryDocumentStore as InMemoryDocumentStoreBase, +) +from haystack.document_stores.types import DuplicatePolicy + + +class InMemoryDocumentStore(InMemoryDocumentStoreBase): + """ + Asynchronous version of the in-memory document store. + """ + + def __init__( # pylint: disable=too-many-positional-arguments + self, + bm25_tokenization_regex: str = r"(?u)\b\w\w+\b", + bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25L", + bm25_parameters: Optional[Dict] = None, + embedding_similarity_function: Literal["dot_product", "cosine"] = "dot_product", + index: Optional[str] = None, + async_executor: Optional[ThreadPoolExecutor] = None, + ): + """ + Initializes the DocumentStore. + + :param bm25_tokenization_regex: The regular expression used to tokenize the text for BM25 retrieval. + :param bm25_algorithm: The BM25 algorithm to use. One of "BM25Okapi", "BM25L", or "BM25Plus". + :param bm25_parameters: Parameters for BM25 implementation in a dictionary format. + For example: {'k1':1.5, 'b':0.75, 'epsilon':0.25} + You can learn more about these parameters by visiting https://github.com/dorianbrown/rank_bm25. + :param embedding_similarity_function: The similarity function used to compare Documents embeddings. + One of "dot_product" (default) or "cosine". To choose the most appropriate function, look for information + about your embedding model. + :param index: A specific index to store the documents. 
If not specified, a random UUID is used. + Using the same index allows you to store documents across multiple InMemoryDocumentStore instances. + :param async_executor: + Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded + executor will initialized and used. + """ + super().__init__( + bm25_tokenization_regex=bm25_tokenization_regex, + bm25_algorithm=bm25_algorithm, + bm25_parameters=bm25_parameters, + embedding_similarity_function=embedding_similarity_function, + index=index, + ) + + self.executor = ( + ThreadPoolExecutor( + thread_name_prefix=f"async-inmemory-docstore-executor-{id(self)}", + max_workers=1, + ) + if async_executor is None + else async_executor + ) + + async def count_documents_async(self) -> int: + """ + Returns the number of how many documents are present in the DocumentStore. + """ + return len(self.storage.keys()) + + async def filter_documents_async(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: + """ + Returns the documents that match the filters provided. + + For a detailed specification of the filters, refer to the DocumentStore.filter_documents() protocol + documentation. + + :param filters: The filters to apply to the document list. + :returns: A list of Documents that match the given filters. + """ + return await asyncio.get_event_loop().run_in_executor( + self.executor, lambda: self.filter_documents(filters=filters) + ) + + async def write_documents_async( + self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE + ) -> int: + """ + Refer to the DocumentStore.write_documents() protocol documentation. + + If `policy` is set to `DuplicatePolicy.NONE` defaults to `DuplicatePolicy.FAIL`. + """ + return await asyncio.get_event_loop().run_in_executor( + self.executor, + lambda: self.write_documents(documents=documents, policy=policy), + ) + + async def delete_documents_async(self, document_ids: List[str]) -> None: + """ + Deletes all documents with matching document_ids from the DocumentStore. + + :param document_ids: The object_ids to delete. + """ + await asyncio.get_event_loop().run_in_executor( + self.executor, + lambda: self.delete_documents(document_ids=document_ids), + ) + + async def bm25_retrieval_async( + self, + query: str, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + scale_score: bool = False, + ) -> List[Document]: + """ + Retrieves documents that are most relevant to the query using BM25 algorithm. + + :param query: The query string. + :param filters: A dictionary with filters to narrow down the search space. + :param top_k: The number of top documents to retrieve. Default is 10. + :param scale_score: Whether to scale the scores of the retrieved documents. Default is False. + :returns: A list of the top_k documents most relevant to the query. + """ + return await asyncio.get_event_loop().run_in_executor( + self.executor, + lambda: self.bm25_retrieval(query=query, filters=filters, top_k=top_k, scale_score=scale_score), + ) + + async def embedding_retrieval_async( # pylint: disable=too-many-positional-arguments + self, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + scale_score: bool = False, + return_embedding: bool = False, + ) -> List[Document]: + """ + Retrieves documents that are most similar to the query embedding using a vector similarity metric. + + :param query_embedding: Embedding of the query. + :param filters: A dictionary with filters to narrow down the search space. 
+ :param top_k: The number of top documents to retrieve. Default is 10. + :param scale_score: Whether to scale the scores of the retrieved Documents. Default is False. + :param return_embedding: Whether to return the embedding of the retrieved Documents. Default is False. + :returns: A list of the top_k documents most relevant to the query. + """ + return await asyncio.get_event_loop().run_in_executor( + self.executor, + lambda: self.embedding_retrieval( + query_embedding=query_embedding, + filters=filters, + top_k=top_k, + scale_score=scale_score, + return_embedding=return_embedding, + ), + ) diff --git a/haystack_experimental/document_stores/opensearch/__init__.py b/haystack_experimental/document_stores/opensearch/__init__.py new file mode 100644 index 00000000..44d4f9d1 --- /dev/null +++ b/haystack_experimental/document_stores/opensearch/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from .document_store import OpenSearchDocumentStore + +__all__ = ["OpenSearchDocumentStore"] diff --git a/haystack_experimental/document_stores/opensearch/document_store.py b/haystack_experimental/document_stores/opensearch/document_store.py new file mode 100644 index 00000000..b6e49efe --- /dev/null +++ b/haystack_experimental/document_stores/opensearch/document_store.py @@ -0,0 +1,596 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import math +from typing import Any, Dict, List, Optional + +from haystack import default_from_dict, default_to_dict, logging +from haystack.dataclasses import Document +from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError +from haystack.document_stores.types import DuplicatePolicy +from haystack.lazy_imports import LazyImport +from haystack.utils.filters import raise_on_invalid_filter_syntax +from opensearchpy import AsyncOpenSearch, OpenSearch +from opensearchpy.helpers import async_bulk, bulk + +logger = logging.getLogger(__name__) + +with LazyImport("Run 'pip install opensearch-haystack opensearch-py[async]'") as opensearch_import: + # pylint: disable=import-error + from haystack_integrations.document_stores.opensearch.auth import AWSAuth + from haystack_integrations.document_stores.opensearch.document_store import ( + BM25_SCALING_FACTOR, + DEFAULT_MAX_CHUNK_BYTES, + DEFAULT_SETTINGS, + Hosts, + ) + from haystack_integrations.document_stores.opensearch.filters import ( + normalize_filters, + ) + + +class OpenSearchDocumentStore: + def __init__( # pylint: disable=dangerous-default-value + self, + *, + hosts: Optional[Hosts] = None, + index: str = "default", + max_chunk_bytes: int = DEFAULT_MAX_CHUNK_BYTES, + embedding_dim: int = 768, + return_embedding: bool = False, + method: Optional[Dict[str, Any]] = None, + mappings: Optional[Dict[str, Any]] = None, + settings: Optional[Dict[str, Any]] = DEFAULT_SETTINGS, + create_index: bool = True, + http_auth: Any = None, + use_ssl: Optional[bool] = None, + verify_certs: Optional[bool] = None, + timeout: Optional[int] = None, + **kwargs, + ): + """ + Creates a new OpenSearchDocumentStore instance. + + The `embeddings_dim`, `method`, `mappings`, and `settings` arguments are only used if the index does not + exists and needs to be created. If the index already exists, its current configurations will be used. 
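Editor's note: before the OpenSearch parameters below, a minimal sketch of using the asynchronous in-memory store defined above. The documents are illustrative only.

```python
import asyncio

from haystack.dataclasses import Document

from haystack_experimental.document_stores.in_memory import InMemoryDocumentStore


async def main() -> None:
    store = InMemoryDocumentStore()
    await store.write_documents_async(
        [
            Document(content="Async pipelines stream intermediate outputs."),
            Document(content="BM25 retrieval ranks documents by keyword relevance."),
        ]
    )
    print(await store.count_documents_async())
    docs = await store.bm25_retrieval_async(query="bm25 keyword", top_k=1)
    print(docs[0].content)


asyncio.run(main())
```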
+ + For more information on connection parameters, see the [official OpenSearch documentation](https://opensearch.org/docs/latest/clients/python-low-level/#connecting-to-opensearch) + + :param hosts: List of hosts running the OpenSearch client. Defaults to None + :param index: Name of index in OpenSearch, if it doesn't exist it will be created. Defaults to "default" + :param max_chunk_bytes: Maximum size of the requests in bytes. Defaults to 100MB + :param embedding_dim: Dimension of the embeddings. Defaults to 768 + :param return_embedding: + Whether to return the embedding of the retrieved Documents. + :param method: The method definition of the underlying configuration of the approximate k-NN algorithm. Please + see the [official OpenSearch docs](https://opensearch.org/docs/latest/search-plugins/knn/knn-index/#method-definitions) + for more information. Defaults to None + :param mappings: The mapping of how the documents are stored and indexed. Please see the [official OpenSearch docs](https://opensearch.org/docs/latest/field-types/) + for more information. If None, it uses the embedding_dim and method arguments to create default mappings. + Defaults to None + :param settings: The settings of the index to be created. Please see the [official OpenSearch docs](https://opensearch.org/docs/latest/search-plugins/knn/knn-index/#index-settings) + for more information. Defaults to {"index.knn": True} + :param create_index: Whether to create the index if it doesn't exist. Defaults to True + :param http_auth: http_auth param passed to the underying connection class. + For basic authentication with default connection class `Urllib3HttpConnection` this can be + - a tuple of (username, password) + - a list of [username, password] + - a string of "username:password" + For AWS authentication with `Urllib3HttpConnection` pass an instance of `AWSAuth`. + Defaults to None + :param use_ssl: Whether to use SSL. Defaults to None + :param verify_certs: Whether to verify certificates. Defaults to None + :param timeout: Timeout in seconds. Defaults to None + :param **kwargs: Optional arguments that `OpenSearch` takes. For the full list of supported kwargs, + see the [official OpenSearch reference](https://opensearch-project.github.io/opensearch-py/api-ref/clients/opensearch_client.html) + """ + self._hosts = hosts + self._index = index + self._max_chunk_bytes = max_chunk_bytes + self._embedding_dim = embedding_dim + self._return_embedding = return_embedding + self._method = method + self._mappings = mappings or self._get_default_mappings() + self._settings = settings + self._create_index = create_index + self._http_auth = http_auth + self._use_ssl = use_ssl + self._verify_certs = verify_certs + self._timeout = timeout + self._kwargs = kwargs + + # Client is initialized lazily to prevent side effects when + # the document store is instantiated. + self._client = None + self._async_client = None + self._initialized = False + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. 
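Editor's note: a sketch of constructing the store with the parameters documented above and round-tripping it through serialization. The host URL and index name are placeholders.

```python
from haystack_experimental.document_stores.opensearch import OpenSearchDocumentStore

# Placeholder host; point this at a real OpenSearch instance.
store = OpenSearchDocumentStore(
    hosts="http://localhost:9200",
    index="docs",
    embedding_dim=384,
    create_index=True,
    use_ssl=False,
    verify_certs=False,
)

# The client is created lazily, so nothing connects until the store is first used.
config = store.to_dict()
restored = OpenSearchDocumentStore.from_dict(config)
```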
+ """ + return default_to_dict( + self, + hosts=self._hosts, + index=self._index, + max_chunk_bytes=self._max_chunk_bytes, + embedding_dim=self._embedding_dim, + method=self._method, + mappings=self._mappings, + settings=self._settings, + create_index=self._create_index, + return_embedding=self._return_embedding, + http_auth=(self._http_auth.to_dict() if isinstance(self._http_auth, AWSAuth) else self._http_auth), + use_ssl=self._use_ssl, + verify_certs=self._verify_certs, + timeout=self._timeout, + **self._kwargs, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchDocumentStore": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. + """ + if http_auth := data.get("init_parameters", {}).get( # noqa: SIM102 + "http_auth" + ): + if isinstance(http_auth, dict): + data["init_parameters"]["http_auth"] = AWSAuth.from_dict(http_auth) + + return default_from_dict(cls, data) + + def _ensure_initialized(self): + # Ideally, we have a warm-up stage for document stores as well as components. + if not self._initialized: + self._client = OpenSearch( + hosts=self._hosts, + http_auth=self._http_auth, + use_ssl=self._use_ssl, + verify_certs=self._verify_certs, + timeout=self._timeout, + **self._kwargs, + ) + self._async_client = AsyncOpenSearch( + hosts=self._hosts, + http_auth=self._http_auth, + use_ssl=self._use_ssl, + verify_certs=self._verify_certs, + timeout=self._timeout, + **self._kwargs, + ) + + self._initialized = True + + # In a just world, this is something that the document store shouldn't + # be responsible for. However, the current implementation has become a + # dependency of downstream users, so we'll have to preserve this behaviour + # (for now). + self._ensure_index_exists() + + def _ensure_index_exists(self): + assert self._client is not None + + if self._client.indices.exists(index=self._index): + logger.debug( + "The index '{index}' already exists. 
The `embedding_dim`, `method`, `mappings`, and " + "`settings` values will be ignored.", + index=self._index, + ) + elif self._create_index: + # Create the index if it doesn't exist + body = {"mappings": self._mappings, "settings": self._settings} + self._client.indices.create(index=self._index, body=body) # type:ignore + + def _get_default_mappings(self) -> Dict[str, Any]: + default_mappings: Dict[str, Any] = { + "properties": { + "embedding": { + "type": "knn_vector", + "index": True, + "dimension": self._embedding_dim, + }, + "content": {"type": "text"}, + }, + "dynamic_templates": [ + { + "strings": { + "match_mapping_type": "string", + "mapping": {"type": "keyword"}, + } + } + ], + } + if self._method: + default_mappings["properties"]["embedding"]["method"] = self._method + return default_mappings + + def create_index( # noqa: D102 + self, + index: Optional[str] = None, + mappings: Optional[Dict[str, Any]] = None, + settings: Optional[Dict[str, Any]] = None, + ) -> None: + self._ensure_initialized() + assert self._client is not None + + if not index: + index = self._index + if not mappings: + mappings = self._mappings + if not settings: + settings = self._settings + + if not self._client.indices.exists(index=index): + self._client.indices.create(index=index, body={"mappings": mappings, "settings": settings}) + + def count_documents(self) -> int: # noqa: D102 + self._ensure_initialized() + + assert self._client is not None + return self._client.count(index=self._index)["count"] + + async def count_documents_async(self) -> int: # noqa: D102 + self._ensure_initialized() + + assert self._async_client is not None + return (await self._async_client.count(index=self._index))["count"] + + def _deserialize_search_hits(self, hits: List[Dict[str, Any]]) -> List[Document]: + out = [] + for hit in hits: + data = hit["_source"] + if "highlight" in hit: + data["metadata"]["highlighted"] = hit["highlight"] + data["score"] = hit["_score"] + out.append(Document.from_dict(data)) + + return out + + def _prepare_filter_search_request(self, filters: Optional[Dict[str, Any]]) -> Dict[str, Any]: + raise_on_invalid_filter_syntax(filters) + search_kwargs: Dict[str, Any] = {"size": 10_000} + if filters: + search_kwargs["query"] = {"bool": {"filter": normalize_filters(filters)}} + return search_kwargs + + def _search_documents(self, request_body: Dict[str, Any]) -> List[Document]: + assert self._client is not None + search_results = self._client.search(index=self._index, body=request_body) + return self._deserialize_search_hits(search_results["hits"]["hits"]) + + async def _search_documents_async(self, request_body: Dict[str, Any]) -> List[Document]: + assert self._async_client is not None + search_results = await self._async_client.search(index=self._index, body=request_body) + return self._deserialize_search_hits(search_results["hits"]["hits"]) + + def filter_documents( # noqa: D102 + self, filters: Optional[Dict[str, Any]] = None + ) -> List[Document]: + self._ensure_initialized() + return self._search_documents(self._prepare_filter_search_request(filters)) + + async def filter_documents_async( # noqa: D102 + self, filters: Optional[Dict[str, Any]] = None + ) -> List[Document]: + self._ensure_initialized() + return await self._search_documents_async(self._prepare_filter_search_request(filters)) + + def _prepare_bulk_write_request( + self, documents: List[Document], policy: DuplicatePolicy, is_async: bool + ) -> Dict[str, Any]: + if len(documents) > 0 and not isinstance(documents[0], Document): + msg = "param 
'documents' must contain a list of objects of type Document" + raise ValueError(msg) + + if policy == DuplicatePolicy.NONE: + policy = DuplicatePolicy.FAIL + + action = "index" if policy == DuplicatePolicy.OVERWRITE else "create" + return { + "client": self._client if not is_async else self._async_client, + "actions": ( + { + "_op_type": action, + "_id": doc.id, + "_source": doc.to_dict(), + } + for doc in documents + ), + "refresh": "wait_for", + "index": self._index, + "raise_on_error": False, + "max_chunk_bytes": self._max_chunk_bytes, + } + + def _process_bulk_write_errors(self, errors: List[Dict[str, Any]], policy: DuplicatePolicy): + if len(errors) == 0: + return + + duplicate_errors_ids = [] + other_errors = [] + for e in errors: + # OpenSearch might not return a correctly formatted error, in that case we + # treat it as a generic error + if "create" not in e: + other_errors.append(e) + continue + error_type = e["create"]["error"]["type"] + if policy == DuplicatePolicy.FAIL and error_type == "version_conflict_engine_exception": + duplicate_errors_ids.append(e["create"]["_id"]) + elif policy == DuplicatePolicy.SKIP and error_type == "version_conflict_engine_exception": + # when the policy is skip, duplication errors are OK and we should not raise an exception + continue + else: + other_errors.append(e) + + if len(duplicate_errors_ids) > 0: + msg = f"IDs '{', '.join(duplicate_errors_ids)}' already exist in the document store." + raise DuplicateDocumentError(msg) + + if len(other_errors) > 0: + msg = f"Failed to write documents to OpenSearch. Errors:\n{other_errors}" + raise DocumentStoreError(msg) + + def write_documents( # noqa: D102 + self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE + ) -> int: + self._ensure_initialized() + + bulk_params = self._prepare_bulk_write_request(documents, policy, is_async=False) + documents_written, errors = bulk(**bulk_params) + self._process_bulk_write_errors(errors, policy) + return documents_written + + async def write_documents_async( # noqa: D102 + self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE + ) -> int: + self._ensure_initialized() + + bulk_params = self._prepare_bulk_write_request(documents, policy, is_async=True) + documents_written, errors = await async_bulk(**bulk_params) + self._process_bulk_write_errors(errors, policy) # type:ignore + return documents_written + + def _prepare_bulk_delete_request(self, document_ids: List[str], is_async: bool) -> Dict[str, Any]: + return { + "client": self._client if not is_async else self._async_client, + "actions": ({"_op_type": "delete", "_id": id_} for id_ in document_ids), + "refresh": "wait_for", + "index": self._index, + "raise_on_error": False, + "max_chunk_bytes": self._max_chunk_bytes, + } + + def delete_documents(self, document_ids: List[str]) -> None: # noqa: D102 + self._ensure_initialized() + + bulk(**self._prepare_bulk_delete_request(document_ids, is_async=False)) + + async def delete_documents_async( # noqa: D102 + self, document_ids: List[str] + ) -> None: + self._ensure_initialized() + + await async_bulk(**self._prepare_bulk_delete_request(document_ids, is_async=True)) + + def _render_custom_query(self, custom_query: Any, substitutions: Dict[str, Any]) -> Any: + """ + Recursively replaces the placeholders in the custom_query with the actual values. + + :param custom_query: The custom query to replace the placeholders in. + :param substitutions: The dictionary containing the actual values to replace the placeholders with. 
+ :returns: The custom query with the placeholders replaced. + """ + if isinstance(custom_query, dict): + return {key: self._render_custom_query(value, substitutions) for key, value in custom_query.items()} + elif isinstance(custom_query, list): + return [self._render_custom_query(entry, substitutions) for entry in custom_query] + elif isinstance(custom_query, str): + return substitutions.get(custom_query, custom_query) + + return custom_query + + def _prepare_bm25_search_request( + self, + *, + query: str, + filters: Optional[Dict[str, Any]], + fuzziness: str, + top_k: int, + all_terms_must_match: bool, + custom_query: Optional[Dict[str, Any]], + ) -> Dict[str, Any]: + raise_on_invalid_filter_syntax(filters) + + if not query: + body: Dict[str, Any] = {"query": {"bool": {"must": {"match_all": {}}}}} + if filters: + body["query"]["bool"]["filter"] = normalize_filters(filters) + + if isinstance(custom_query, dict): + body = self._render_custom_query( + custom_query, + { + "$query": query, + "$filters": normalize_filters(filters), # type:ignore + }, + ) + + else: + operator = "AND" if all_terms_must_match else "OR" + body = { + "query": { + "bool": { + "must": [ + { + "multi_match": { + "query": query, + "fuzziness": fuzziness, + "type": "most_fields", + "operator": operator, + } + } + ] + } + }, + } + + if filters: + body["query"]["bool"]["filter"] = normalize_filters(filters) + + body["size"] = top_k + + # For some applications not returning the embedding can save a lot of bandwidth + # if you don't need this data not retrieving it can be a good idea + if not self._return_embedding: + body["_source"] = {"excludes": ["embedding"]} + + return body + + def _postprocess_bm25_search_results(self, results: List[Document], scale_score: bool): + if not scale_score: + return + + for doc in results: + assert doc.score is not None + doc.score = float(1 / (1 + math.exp(-(doc.score / float(BM25_SCALING_FACTOR))))) + + def _bm25_retrieval( + self, + query: str, + *, + filters: Optional[Dict[str, Any]] = None, + fuzziness: str = "AUTO", + top_k: int = 10, + scale_score: bool = False, + all_terms_must_match: bool = False, + custom_query: Optional[Dict[str, Any]] = None, + ) -> List[Document]: + self._ensure_initialized() + + search_params = self._prepare_bm25_search_request( + query=query, + filters=filters, + fuzziness=fuzziness, + top_k=top_k, + all_terms_must_match=all_terms_must_match, + custom_query=custom_query, + ) + documents = self._search_documents(search_params) + self._postprocess_bm25_search_results(documents, scale_score) + return documents + + async def _bm25_retrieval_async( + self, + query: str, + *, + filters: Optional[Dict[str, Any]] = None, + fuzziness: str = "AUTO", + top_k: int = 10, + scale_score: bool = False, + all_terms_must_match: bool = False, + custom_query: Optional[Dict[str, Any]] = None, + ) -> List[Document]: + self._ensure_initialized() + + search_params = self._prepare_bm25_search_request( + query=query, + filters=filters, + fuzziness=fuzziness, + top_k=top_k, + all_terms_must_match=all_terms_must_match, + custom_query=custom_query, + ) + documents = await self._search_documents_async(search_params) + self._postprocess_bm25_search_results(documents, scale_score) + return documents + + def _prepare_embedding_search_request( + self, + query_embedding: List[float], + filters: Optional[Dict[str, Any]], + top_k: int, + custom_query: Optional[Dict[str, Any]], + ) -> Dict[str, Any]: + raise_on_invalid_filter_syntax(filters) + + if not query_embedding: + msg = "query_embedding 
must be a non-empty list of floats" + raise ValueError(msg) + + body: Dict[str, Any] + if isinstance(custom_query, dict): + body = self._render_custom_query( + custom_query, + { + "$query_embedding": query_embedding, + "$filters": normalize_filters(filters), # type:ignore + }, + ) + + else: + body = { + "query": { + "bool": { + "must": [ + { + "knn": { + "embedding": { + "vector": query_embedding, + "k": top_k, + } + } + } + ], + } + }, + } + + if filters: + body["query"]["bool"]["filter"] = normalize_filters(filters) + + body["size"] = top_k + + # For some applications not returning the embedding can save a lot of bandwidth + # if you don't need this data not retrieving it can be a good idea + if not self._return_embedding: + body["_source"] = {"excludes": ["embedding"]} + + return body + + def _embedding_retrieval( + self, + query_embedding: List[float], + *, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + custom_query: Optional[Dict[str, Any]] = None, + ) -> List[Document]: + self._ensure_initialized() + + search_params = self._prepare_embedding_search_request(query_embedding, filters, top_k, custom_query) + return self._search_documents(search_params) + + async def _embedding_retrieval_async( + self, + query_embedding: List[float], + *, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + custom_query: Optional[Dict[str, Any]] = None, + ) -> List[Document]: + self._ensure_initialized() + + search_params = self._prepare_embedding_search_request(query_embedding, filters, top_k, custom_query) + return await self._search_documents_async(search_params) diff --git a/haystack_experimental/document_stores/types/__init__.py b/haystack_experimental/document_stores/types/__init__.py new file mode 100644 index 00000000..84e7d388 --- /dev/null +++ b/haystack_experimental/document_stores/types/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from .protocol import DocumentStore + +__all__ = ["DocumentStore"] diff --git a/haystack_experimental/document_stores/types/protocol.py b/haystack_experimental/document_stores/types/protocol.py new file mode 100644 index 00000000..029756e0 --- /dev/null +++ b/haystack_experimental/document_stores/types/protocol.py @@ -0,0 +1,153 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional, Protocol + +from haystack.dataclasses import Document +from haystack.document_stores.types import DuplicatePolicy + + +class DocumentStore(Protocol): + """ + Stores Documents to be used by the components of a Pipeline. + + Classes implementing this protocol often store the documents permanently and allow specialized components to + perform retrieval on them, either by embedding, by keyword, hybrid, and so on, depending on the backend used. + + In order to retrieve documents, consider using a Retriever that supports the DocumentStore implementation that + you're using. + """ + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this store to a dictionary. + """ + pass + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DocumentStore": + """ + Deserializes the store from a dictionary. + """ + pass + + def count_documents(self) -> int: + """ + Returns the number of documents stored. 
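Editor's note: returning to the OpenSearch store's `custom_query` support above, a hedged sketch of a template body. The string values `"$query_embedding"` and `"$filters"` are placeholders that `_render_custom_query` replaces with the runtime embedding and the normalized filters before the search request is sent.

```python
# Illustrative custom_query template for embedding retrieval.
custom_query = {
    "query": {
        "bool": {
            "must": [
                {
                    "knn": {
                        "embedding": {
                            "vector": "$query_embedding",  # replaced with the query embedding
                            "k": 10,
                        }
                    }
                }
            ],
            "filter": "$filters",  # replaced with the normalized filters
        }
    }
}
```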
+ """ + pass + + async def count_documents_async(self) -> int: # noqa: D102 + pass + + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: + """ + Returns the documents that match the filters provided. + + Filters are defined as nested dictionaries that can be of two types: + - Comparison + - Logic + + Comparison dictionaries must contain the keys: + + - `field` + - `operator` + - `value` + + Logic dictionaries must contain the keys: + + - `operator` + - `conditions` + + The `conditions` key must be a list of dictionaries, either of type Comparison or Logic. + + The `operator` value in Comparison dictionaries must be one of: + + - `==` + - `!=` + - `>` + - `>=` + - `<` + - `<=` + - `in` + - `not in` + + The `operator` values in Logic dictionaries must be one of: + + - `NOT` + - `OR` + - `AND` + + + A simple filter: + ```python + filters = {"field": "meta.type", "operator": "==", "value": "article"} + ``` + + A more complex filter: + ```python + filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.type", "operator": "==", "value": "article"}, + {"field": "meta.date", "operator": ">=", "value": 1420066800}, + {"field": "meta.date", "operator": "<", "value": 1609455600}, + {"field": "meta.rating", "operator": ">=", "value": 3}, + { + "operator": "OR", + "conditions": [ + {"field": "meta.genre", "operator": "in", "value": ["economy", "politics"]}, + {"field": "meta.publisher", "operator": "==", "value": "nytimes"}, + ], + }, + ], + } + + :param filters: the filters to apply to the document list. + :returns: a list of Documents that match the given filters. + """ + pass + + async def filter_documents_async( # noqa: D102 + self, filters: Optional[Dict[str, Any]] = None + ) -> List[Document]: + pass + + def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int: + """ + Writes Documents into the DocumentStore. + + :param documents: a list of Document objects. + :param policy: the policy to apply when a Document with the same id already exists in the DocumentStore. + - `DuplicatePolicy.NONE`: Default policy, behaviour depends on the Document Store. + - `DuplicatePolicy.SKIP`: If a Document with the same id already exists, it is skipped and not written. + - `DuplicatePolicy.OVERWRITE`: If a Document with the same id already exists, it is overwritten. + - `DuplicatePolicy.FAIL`: If a Document with the same id already exists, an error is raised. + :raises DuplicateError: If `policy` is set to `DuplicatePolicy.FAIL` and a Document with the same id already + exists. + :returns: The number of Documents written. + If `DuplicatePolicy.OVERWRITE` is used, this number is always equal to the number of documents in input. + If `DuplicatePolicy.SKIP` is used, this number can be lower than the number of documents in the input list. + """ + pass + + async def write_documents_async( # noqa: D102 + self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE + ) -> int: + pass + + def delete_documents(self, document_ids: List[str]) -> None: + """ + Deletes all documents with a matching document_ids from the DocumentStore. + + Fails with `MissingDocumentError` if no document with this id is present in the DocumentStore. 
+ + :param document_ids: the object_ids to delete + """ + pass + + async def delete_documents_async( # noqa: D102 + self, document_ids: List[str] + ) -> None: + pass diff --git a/haystack_experimental/util/__init__.py b/haystack_experimental/util/__init__.py index 032a10bc..179a8c6f 100644 --- a/haystack_experimental/util/__init__.py +++ b/haystack_experimental/util/__init__.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from haystack_experimental.util.asynchronous import is_callable_async_compatible from haystack_experimental.util.auth import serialize_secrets_inplace -__all__ = ["serialize_secrets_inplace"] +__all__ = ["is_callable_async_compatible", "serialize_secrets_inplace"] diff --git a/haystack_experimental/util/asynchronous.py b/haystack_experimental/util/asynchronous.py new file mode 100644 index 00000000..e84ecbd0 --- /dev/null +++ b/haystack_experimental/util/asynchronous.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import inspect +from typing import Callable + + +def is_callable_async_compatible(func: Callable) -> bool: + """ + Returns if the given callable is usable inside a component's `run_async` method. + + :param callable: + The callable to check. + :returns: + True if the callable is compatible, False otherwise. + """ + return inspect.iscoroutinefunction(func) diff --git a/pyproject.toml b/pyproject.toml index 65bd8b56..45324448 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,8 @@ dependencies = [ "pytest", "pytest-rerunfailures", "pytest-cov", + "pytest-bdd", + "pytest-asyncio", # Linting "pylint", "ruff", @@ -63,6 +65,9 @@ extra-dependencies = [ # Tools support "jsonschema", "ollama-haystack>=1.1.0", + # Async + "opensearch-haystack", + "opensearch-py[async]", # LLMMetadataExtractor dependencies "amazon-bedrock-haystack>=1.0.2", "google-vertex-haystack>=2.0.0", @@ -160,8 +165,8 @@ disable = [ "cyclic-import", "import-outside-toplevel", "deprecated-method", - "too-many-arguments", # sometimes we need to pass more than 5 arguments - "too-many-instance-attributes" # sometimes we need to have a class with more than 7 attributes + "too-many-arguments", # sometimes we need to pass more than 5 arguments + "too-many-instance-attributes", # sometimes we need to have a class with more than 7 attributes ] [tool.pytest.ini_options] @@ -169,9 +174,10 @@ minversion = "6.0" addopts = "--strict-markers" markers = [ "integration: integration tests", - "unstable(reason): Mark tests that are unstable or depend on unreliable services." 
+ "unstable(reason): Mark tests that are unstable or depend on unreliable services.", ] log_cli = true +asyncio_mode = "auto" [tool.mypy] warn_return_any = false diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index 7f786ab4..4a835f9f 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -11,7 +11,12 @@ from datetime import datetime from openai import OpenAIError -from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, ChatCompletionMessageToolCall +from openai.types.chat import ( + ChatCompletion, + ChatCompletionChunk, + ChatCompletionMessage, + ChatCompletionMessageToolCall, +) from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_message_tool_call import Function from openai.types.chat import chat_completion_chunk @@ -20,9 +25,17 @@ from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import StreamingChunk from haystack.utils.auth import Secret -from haystack_experimental.dataclasses import ChatMessage, Tool, ToolCall, ChatRole, TextContent -from haystack_experimental.components.generators.chat.openai import OpenAIChatGenerator, _convert_message_to_openai_format, OpenAIChatGeneratorBase - +from haystack_experimental.dataclasses import ( + ChatMessage, + Tool, + ToolCall, + ChatRole, + TextContent, +) +from haystack_experimental.components.generators.chat.openai import ( + OpenAIChatGenerator, + _convert_message_to_openai_format, +) @pytest.fixture @@ -32,6 +45,7 @@ def chat_messages(): ChatMessage.from_user("What's the capital of France"), ] + class MockStream(Stream[ChatCompletionChunk]): def __init__(self, mock_chunk: ChatCompletionChunk, client=None, *args, **kwargs): client = client or MagicMock() @@ -42,34 +56,47 @@ def __stream__(self) -> Iterator[ChatCompletionChunk]: # Yielding only one ChatCompletionChunk object yield self.mock_chunk + @pytest.fixture def mock_chat_completion_chunk(): """ Mock the OpenAI API completion chunk response and reuse it for tests """ - with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: + with patch( + "openai.resources.chat.completions.Completions.create" + ) as mock_chat_completion_create: completion = ChatCompletionChunk( id="foo", model="gpt-4", object="chat.completion.chunk", choices=[ chat_completion_chunk.Choice( - finish_reason="stop", logprobs=None, index=0, delta=chat_completion_chunk.ChoiceDelta(content="Hello", role="assistant") + finish_reason="stop", + logprobs=None, + index=0, + delta=chat_completion_chunk.ChoiceDelta( + content="Hello", role="assistant" + ), ) ], created=int(datetime.now().timestamp()), usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, ) - mock_chat_completion_create.return_value = MockStream(completion, cast_to=None, response=None, client=None) + mock_chat_completion_create.return_value = MockStream( + completion, cast_to=None, response=None, client=None + ) yield mock_chat_completion_create + @pytest.fixture def mock_chat_completion(): """ Mock the OpenAI API completion response and reuse it for tests """ - with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: + with patch( + "openai.resources.chat.completions.Completions.create" + ) as mock_chat_completion_create: completion = ChatCompletion( id="foo", model="gpt-4", @@ -79,7 +106,9 @@ def mock_chat_completion(): 
finish_reason="stop", logprobs=None, index=0, - message=ChatCompletionMessage(content="Hello world!", role="assistant"), + message=ChatCompletionMessage( + content="Hello world!", role="assistant" + ), ) ], created=int(datetime.now().timestamp()), @@ -89,49 +118,66 @@ def mock_chat_completion(): mock_chat_completion_create.return_value = completion yield mock_chat_completion_create + @pytest.fixture def mock_chat_completion_chunk_with_tools(): """ Mock the OpenAI API completion chunk response and reuse it for tests """ - with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: + with patch( + "openai.resources.chat.completions.Completions.create" + ) as mock_chat_completion_create: completion = ChatCompletionChunk( id="foo", model="gpt-4", object="chat.completion.chunk", choices=[ chat_completion_chunk.Choice( - finish_reason="tool_calls", logprobs=None, index=0, delta=chat_completion_chunk.ChoiceDelta( + finish_reason="tool_calls", + logprobs=None, + index=0, + delta=chat_completion_chunk.ChoiceDelta( role="assistant", - tool_calls=[chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id="123", type="function", function=chat_completion_chunk.ChoiceDeltaToolCallFunction(name="weather", arguments='{"city": "Paris"}') - )]) + tool_calls=[ + chat_completion_chunk.ChoiceDeltaToolCall( + index=0, + id="123", + type="function", + function=chat_completion_chunk.ChoiceDeltaToolCallFunction( + name="weather", arguments='{"city": "Paris"}' + ), + ) + ], + ), ) ], created=int(datetime.now().timestamp()), usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, ) - mock_chat_completion_create.return_value = MockStream(completion, cast_to=None, response=None, client=None) + mock_chat_completion_create.return_value = MockStream( + completion, cast_to=None, response=None, client=None + ) yield mock_chat_completion_create + @pytest.fixture def tools(): tool_parameters = { - "type": "object", - "properties": { - "city": {"type": "string"} - }, - "required": ["city"] -} - tool = Tool(name="weather", description="useful to determine the weather in a given location", - parameters=tool_parameters, function=lambda x:x) + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + } + tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters=tool_parameters, + function=lambda x: x, + ) return [tool] - class TestOpenAIChatGenerator: def test_init_default(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") @@ -158,7 +204,12 @@ def test_init_fail_with_duplicate_tool_names(self, monkeypatch, tools): OpenAIChatGenerator(tools=duplicate_tools) def test_init_with_parameters(self, monkeypatch): - tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=lambda x: x) + tool = Tool( + name="name", + description="description", + parameters={"x": {"type": "string"}}, + function=lambda x: x, + ) monkeypatch.setenv("OPENAI_TIMEOUT", "100") monkeypatch.setenv("OPENAI_MAX_RETRIES", "10") @@ -176,7 +227,10 @@ def test_init_with_parameters(self, monkeypatch): assert component.client.api_key == "test-api-key" assert component.model == "gpt-4o-mini" assert component.streaming_callback is print_streaming_chunk - assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.generation_kwargs == { + "max_tokens": 10, + "some_test_param": "test-params", + } assert 
component.client.timeout == 40.0 assert component.client.max_retries == 1 assert component.tools == [tool] @@ -195,7 +249,10 @@ def test_init_with_parameters_and_env_vars(self, monkeypatch): assert component.client.api_key == "test-api-key" assert component.model == "gpt-4o-mini" assert component.streaming_callback is print_streaming_chunk - assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.generation_kwargs == { + "max_tokens": 10, + "some_test_param": "test-params", + } assert component.client.timeout == 100.0 assert component.client.max_retries == 10 @@ -206,7 +263,11 @@ def test_to_dict_default(self, monkeypatch): assert data == { "type": "haystack_experimental.components.generators.chat.openai.OpenAIChatGenerator", "init_parameters": { - "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"}, + "api_key": { + "env_vars": ["OPENAI_API_KEY"], + "strict": True, + "type": "env_var", + }, "model": "gpt-4o-mini", "organization": None, "streaming_callback": None, @@ -214,11 +275,18 @@ def test_to_dict_default(self, monkeypatch): "generation_kwargs": {}, "tools": None, "tools_strict": False, + "max_retries": None, + "timeout": None, }, } def test_to_dict_with_parameters(self, monkeypatch): - tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + tool = Tool( + name="name", + description="description", + parameters={"x": {"type": "string"}}, + function=print, + ) monkeypatch.setenv("ENV_VAR", "test-api-key") component = OpenAIChatGenerator( @@ -227,8 +295,10 @@ def test_to_dict_with_parameters(self, monkeypatch): streaming_callback=print_streaming_chunk, api_base_url="test-base-url", generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, - tools = [tool], + tools=[tool], tools_strict=True, + max_retries=10, + timeout=100.0, ) data = component.to_dict() @@ -239,21 +309,26 @@ def test_to_dict_with_parameters(self, monkeypatch): "model": "gpt-4o-mini", "organization": None, "api_base_url": "test-base-url", + "max_retries": 10, + "timeout": 100.0, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, - 'tools': [ - { - 'description': 'description', - 'function': 'builtins.print', - 'name': 'name', - 'parameters': { - 'x': { - 'type': 'string', - }, - }, - }, - ], - 'tools_strict': True, + "generation_kwargs": { + "max_tokens": 10, + "some_test_param": "test-params", + }, + "tools": [ + { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": { + "x": { + "type": "string", + }, + }, + }, + ], + "tools_strict": True, }, } @@ -269,15 +344,23 @@ def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): assert data == { "type": "haystack_experimental.components.generators.chat.openai.OpenAIChatGenerator", "init_parameters": { - "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"}, + "api_key": { + "env_vars": ["OPENAI_API_KEY"], + "strict": True, + "type": "env_var", + }, "model": "gpt-4o-mini", "organization": None, "api_base_url": "test-base-url", + "max_retries": None, + "timeout": None, "streaming_callback": "test_openai.", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "generation_kwargs": { + "max_tokens": 10, + "some_test_param": "test-params", + }, "tools": None, "tools_strict": False, - }, } @@ -286,24 +369,33 @@ def 
test_from_dict(self, monkeypatch): data = { "type": "haystack_experimental.components.generators.chat.openai.OpenAIChatGenerator", "init_parameters": { - "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"}, + "api_key": { + "env_vars": ["OPENAI_API_KEY"], + "strict": True, + "type": "env_var", + }, "model": "gpt-4o-mini", "api_base_url": "test-base-url", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, - 'tools': [ - { - 'description': 'description', - 'function': 'builtins.print', - 'name': 'name', - 'parameters': { - 'x': { - 'type': 'string', - }, - }, - }, - ], - 'tools_strict': True, + "max_retries": 10, + "timeout": 100.0, + "generation_kwargs": { + "max_tokens": 10, + "some_test_param": "test-params", + }, + "tools": [ + { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": { + "x": { + "type": "string", + }, + }, + }, + ], + "tools_strict": True, }, } component = OpenAIChatGenerator.from_dict(data) @@ -312,22 +404,41 @@ def test_from_dict(self, monkeypatch): assert component.model == "gpt-4o-mini" assert component.streaming_callback is print_streaming_chunk assert component.api_base_url == "test-base-url" - assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.generation_kwargs == { + "max_tokens": 10, + "some_test_param": "test-params", + } assert component.api_key == Secret.from_env_var("OPENAI_API_KEY") - assert component.tools == [Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)] + assert component.tools == [ + Tool( + name="name", + description="description", + parameters={"x": {"type": "string"}}, + function=print, + ) + ] assert component.tools_strict + assert component.client.timeout == 100.0 + assert component.client.max_retries == 10 def test_from_dict_fail_wo_env_var(self, monkeypatch): monkeypatch.delenv("OPENAI_API_KEY", raising=False) data = { "type": "haystack_experimental.components.generators.chat.openai.OpenAIChatGenerator", "init_parameters": { - "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"}, - "model": "gpt-4o-mini", + "api_key": { + "env_vars": ["OPENAI_API_KEY"], + "strict": True, + "type": "env_var", + }, + "model": "gpt-4", "organization": None, "api_base_url": "test-base-url", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "generation_kwargs": { + "max_tokens": 10, + "some_test_param": "test-params", + }, }, } with pytest.raises(ValueError): @@ -346,7 +457,8 @@ def test_run(self, chat_messages, mock_chat_completion): def test_run_with_params(self, chat_messages, mock_chat_completion): component = OpenAIChatGenerator( - api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5} + api_key=Secret.from_token("test-api-key"), + generation_kwargs={"max_tokens": 10, "temperature": 0.5}, ) response = component.run(chat_messages) @@ -370,7 +482,8 @@ def streaming_callback(chunk: StreamingChunk) -> None: streaming_callback_called = True component = OpenAIChatGenerator( - api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback + api_key=Secret.from_token("test-api-key"), + streaming_callback=streaming_callback, ) response = component.run(chat_messages) @@ -385,7 +498,9 @@ 
def streaming_callback(chunk: StreamingChunk) -> None: assert [isinstance(reply, ChatMessage) for reply in response["replies"]] assert "Hello" in response["replies"][0].text # see mock_chat_completion_chunk - def test_run_with_streaming_callback_in_run_method(self, chat_messages, mock_chat_completion_chunk): + def test_run_with_streaming_callback_in_run_method( + self, chat_messages, mock_chat_completion_chunk + ): streaming_callback_called = False def streaming_callback(chunk: StreamingChunk) -> None: @@ -411,13 +526,17 @@ def test_check_abnormal_completions(self, caplog): component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key")) messages = [ ChatMessage.from_assistant( - "", meta={"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i} + "", + meta={ + "finish_reason": "content_filter" if i % 2 == 0 else "length", + "index": i, + }, ) for i, _ in enumerate(range(4)) ] for m in messages: - component._check_finish_reason(m) + component._check_finish_reason(m.meta) # check truncation warning message_template = ( @@ -445,7 +564,23 @@ def test_live_run(self): assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] assert "Paris" in message.text - assert "gpt-4o-mini" in message.meta["model"] + assert "gpt-4o" in message.meta["model"] + assert message.meta["finish_reason"] == "stop" + + @pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY", None), + reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + @pytest.mark.asyncio + async def test_live_run_async(self): + chat_messages = [ChatMessage.from_user("What's the capital of France")] + component = OpenAIChatGenerator(generation_kwargs={"n": 1}) + results = await component.run_async(chat_messages) + assert len(results["replies"]) == 1 + message: ChatMessage = results["replies"][0] + assert "Paris" in message.text + assert "gpt-4o" in message.meta["model"] assert message.meta["finish_reason"] == "stop" @pytest.mark.skipif( @@ -458,6 +593,17 @@ def test_live_run_wrong_model(self, chat_messages): with pytest.raises(OpenAIError): component.run(chat_messages) + @pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY", None), + reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + @pytest.mark.asyncio + async def test_live_run_wrong_model_async(self, chat_messages): + component = OpenAIChatGenerator(model="something-obviously-wrong") + with pytest.raises(OpenAIError): + await component.run_async(chat_messages) + @pytest.mark.skipif( not os.environ.get("OPENAI_API_KEY", None), reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", @@ -475,48 +621,149 @@ def __call__(self, chunk: StreamingChunk) -> None: callback = Callback() component = OpenAIChatGenerator(streaming_callback=callback) - results = component.run([ChatMessage.from_user("What's the capital of France?")]) + results = component.run( + [ChatMessage.from_user("What's the capital of France?")] + ) assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] assert "Paris" in message.text - assert "gpt-4o-mini" in message.meta["model"] + assert "gpt-4o" in message.meta["model"] assert message.meta["finish_reason"] == "stop" assert callback.counter > 1 assert "Paris" in callback.responses + @pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY", None), + reason="Export an env var called OPENAI_API_KEY containing 
the OpenAI API key to run this test.", + ) + @pytest.mark.integration + @pytest.mark.asyncio + async def test_live_run_streaming_async(self): + counter = 0 + responses = "" + async def callback(chunk: StreamingChunk): + nonlocal counter + nonlocal responses + counter += 1 + responses += chunk.content if chunk.content else "" + component = OpenAIChatGenerator(streaming_callback=callback) + results = await component.run_async( + [ChatMessage.from_user("What's the capital of France?")] + ) + + assert len(results["replies"]) == 1 + message: ChatMessage = results["replies"][0] + assert "Paris" in message.text + + assert "gpt-4o" in message.meta["model"] + assert message.meta["finish_reason"] == "stop" + + assert counter > 1 + assert "Paris" in responses + + @pytest.mark.asyncio + async def test_streaming_callback_compatibility(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") + + async def async_callback(chunk: StreamingChunk): + pass + + def sync_callback(chunk: StreamingChunk): + pass + + with pytest.raises(ValueError, match="init callback must be async compatible"): + gen = OpenAIChatGenerator(streaming_callback=sync_callback) + await gen.run_async([]) + + with pytest.raises( + ValueError, match="runtime callback must be async compatible" + ): + gen = OpenAIChatGenerator(streaming_callback=async_callback) + await gen.run_async([], streaming_callback=sync_callback) + + await gen.run_async([]) + + with pytest.raises(ValueError, match="init callback cannot be a coroutine"): + gen = OpenAIChatGenerator(streaming_callback=async_callback) + gen.run([]) + + with pytest.raises(ValueError, match="runtime callback cannot be a coroutine"): + gen = OpenAIChatGenerator(streaming_callback=sync_callback) + gen.run([], streaming_callback=async_callback) + + gen.run([]) def test_convert_message_to_openai_format(self): message = ChatMessage.from_system("You are good assistant") - assert _convert_message_to_openai_format(message) == {"role": "system", "content": "You are good assistant"} + assert _convert_message_to_openai_format(message) == { + "role": "system", + "content": "You are good assistant", + } message = ChatMessage.from_user("I have a question") - assert _convert_message_to_openai_format(message) == {"role": "user", "content": "I have a question"} + assert _convert_message_to_openai_format(message) == { + "role": "user", + "content": "I have a question", + } - message = ChatMessage.from_assistant(text="I have an answer", meta={"finish_reason": "stop"}) - assert _convert_message_to_openai_format(message) == {"role": "assistant", "content": "I have an answer"} + message = ChatMessage.from_assistant( + text="I have an answer", meta={"finish_reason": "stop"} + ) + assert _convert_message_to_openai_format(message) == { + "role": "assistant", + "content": "I have an answer", + } - message = ChatMessage.from_assistant(tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})]) - assert _convert_message_to_openai_format(message) == {"role": "assistant", "tool_calls": [{"id": "123", "type": "function", "function": {"name": "weather", "arguments": '{"city": "Paris"}'}}]} + message = ChatMessage.from_assistant( + tool_calls=[ + ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) + ] + ) + assert _convert_message_to_openai_format(message) == { + "role": "assistant", + "tool_calls": [ + { + "id": "123", + "type": "function", + "function": {"name": "weather", "arguments": '{"city": "Paris"}'}, + } + ], + } - tool_result=json.dumps({"weather": 
"sunny", "temperature": "25"}) - message = ChatMessage.from_tool(tool_result=tool_result, origin=ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})) - assert _convert_message_to_openai_format(message) == {"role": "tool", "content": tool_result, "tool_call_id": "123"} + tool_result = json.dumps({"weather": "sunny", "temperature": "25"}) + message = ChatMessage.from_tool( + tool_result=tool_result, + origin=ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}), + ) + assert _convert_message_to_openai_format(message) == { + "role": "tool", + "content": tool_result, + "tool_call_id": "123", + } def test_convert_message_to_openai_invalid(self): message = ChatMessage(_role=ChatRole.ASSISTANT, _content=[]) with pytest.raises(ValueError): _convert_message_to_openai_format(message) - message = ChatMessage(_role=ChatRole.ASSISTANT, _content=[TextContent(text="I have an answer"), TextContent(text="I have another answer")]) + message = ChatMessage( + _role=ChatRole.ASSISTANT, + _content=[ + TextContent(text="I have an answer"), + TextContent(text="I have another answer"), + ], + ) with pytest.raises(ValueError): _convert_message_to_openai_format(message) - tool_call_null_id = ToolCall(id=None, tool_name="weather", arguments={"city": "Paris"}) + tool_call_null_id = ToolCall( + id=None, tool_name="weather", arguments={"city": "Paris"} + ) message = ChatMessage.from_assistant(tool_calls=[tool_call_null_id]) with pytest.raises(ValueError): _convert_message_to_openai_format(message) @@ -525,10 +772,11 @@ def test_convert_message_to_openai_invalid(self): with pytest.raises(ValueError): _convert_message_to_openai_format(message) - def test_run_with_tools(self, tools): - with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: + with patch( + "openai.resources.chat.completions.Completions.create" + ) as mock_chat_completion_create: completion = ChatCompletion( id="foo", model="gpt-4", @@ -538,20 +786,36 @@ def test_run_with_tools(self, tools): finish_reason="tool_calls", logprobs=None, index=0, - message=ChatCompletionMessage(role="assistant", - tool_calls=[ChatCompletionMessageToolCall( - id="123", type="function", function=Function(name="weather", arguments='{"city": "Paris"}'))]) + message=ChatCompletionMessage( + role="assistant", + tool_calls=[ + ChatCompletionMessageToolCall( + id="123", + type="function", + function=Function( + name="weather", arguments='{"city": "Paris"}' + ), + ) + ], + ), ) ], created=int(datetime.now().timestamp()), - usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + usage={ + "prompt_tokens": 57, + "completion_tokens": 40, + "total_tokens": 97, + }, ) mock_chat_completion_create.return_value = completion - component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"), tools=tools) - response = component.run([ChatMessage.from_user("What's the weather like in Paris?")]) - + component = OpenAIChatGenerator( + api_key=Secret.from_token("test-api-key"), tools=tools + ) + response = component.run( + [ChatMessage.from_user("What's the weather like in Paris?")] + ) assert len(response["replies"]) == 1 message = response["replies"][0] @@ -566,7 +830,9 @@ def test_run_with_tools(self, tools): assert tool_call.arguments == {"city": "Paris"} assert message.meta["finish_reason"] == "tool_calls" - def test_run_with_tools_streaming(self, mock_chat_completion_chunk_with_tools, tools): + def test_run_with_tools_streaming( + self, mock_chat_completion_chunk_with_tools, tools + ): 
streaming_callback_called = False @@ -575,7 +841,8 @@ def streaming_callback(chunk: StreamingChunk) -> None: streaming_callback_called = True component = OpenAIChatGenerator( - api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback + api_key=Secret.from_token("test-api-key"), + streaming_callback=streaming_callback, ) chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")] response = component.run(chat_messages, tools=tools) @@ -602,7 +869,9 @@ def streaming_callback(chunk: StreamingChunk) -> None: def test_invalid_tool_call_json(self, tools, caplog): caplog.set_level(logging.WARNING) - with patch("openai.resources.chat.completions.Completions.create") as mock_create: + with patch( + "openai.resources.chat.completions.Completions.create" + ) as mock_create: mock_create.return_value = ChatCompletion( id="test", model="gpt-4o-mini", @@ -614,22 +883,39 @@ def test_invalid_tool_call_json(self, tools, caplog): message=ChatCompletionMessage( role="assistant", tool_calls=[ - ChatCompletionMessageToolCall(id="1", type="function", function=Function(name="weather", arguments='"invalid": "json"')), - ] - ) + ChatCompletionMessageToolCall( + id="1", + type="function", + function=Function( + name="weather", arguments='"invalid": "json"' + ), + ), + ], + ), ) ], created=1234567890, - usage={"prompt_tokens": 50, "completion_tokens": 30, "total_tokens": 80} + usage={ + "prompt_tokens": 50, + "completion_tokens": 30, + "total_tokens": 80, + }, ) - component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"), tools=tools) - response = component.run([ChatMessage.from_user("What's the weather in Paris?")]) + component = OpenAIChatGenerator( + api_key=Secret.from_token("test-api-key"), tools=tools + ) + response = component.run( + [ChatMessage.from_user("What's the weather in Paris?")] + ) assert len(response["replies"]) == 1 message = response["replies"][0] assert len(message.tool_calls) == 0 - assert "OpenAI returned a malformed JSON string for tool call arguments" in caplog.text + assert ( + "OpenAI returned a malformed JSON string for tool call arguments" + in caplog.text + ) @pytest.mark.skipif( not os.environ.get("OPENAI_API_KEY", None), @@ -652,3 +938,24 @@ def test_live_run_with_tools(self, tools): assert tool_call.tool_name == "weather" assert tool_call.arguments == {"city": "Paris"} assert message.meta["finish_reason"] == "tool_calls" + + @pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY", None), + reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + @pytest.mark.asyncio + async def test_live_run_with_tools_async(self, tools): + + chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + component = OpenAIChatGenerator(tools=tools) + results = await component.run_async(chat_messages) + assert len(results["replies"]) == 1 + message = results["replies"][0] + + assert message.tool_calls + tool_call = message.tool_call + assert isinstance(tool_call, ToolCall) + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + assert message.meta["finish_reason"] == "tool_calls" diff --git a/test/components/retrievers/opensearch/__init__.py b/test/components/retrievers/opensearch/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/test/components/retrievers/opensearch/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 
diff --git a/test/components/retrievers/opensearch/test_bm25_retriever.py b/test/components/retrievers/opensearch/test_bm25_retriever.py new file mode 100644 index 00000000..e041318d --- /dev/null +++ b/test/components/retrievers/opensearch/test_bm25_retriever.py @@ -0,0 +1,292 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from unittest.mock import Mock, patch + +import pytest +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy + +from haystack_experimental.components.retrievers.opensearch import ( + OpenSearchBM25Retriever, +) +from haystack_experimental.document_stores.opensearch import OpenSearchDocumentStore +from haystack_experimental.document_stores.opensearch.document_store import ( + DEFAULT_MAX_CHUNK_BYTES, +) + + +def test_init_default(): + mock_store = Mock(spec=OpenSearchDocumentStore) + retriever = OpenSearchBM25Retriever(document_store=mock_store) + assert retriever._document_store == mock_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert not retriever._scale_score + assert retriever._filter_policy == FilterPolicy.REPLACE + + retriever = OpenSearchBM25Retriever( + document_store=mock_store, filter_policy="replace" + ) + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + OpenSearchBM25Retriever(document_store=mock_store, filter_policy="unknown") + + +@patch("haystack_experimental.document_stores.opensearch.document_store.OpenSearch") +def test_to_dict(_mock_opensearch_client): + document_store = OpenSearchDocumentStore(hosts="some fake host") + retriever = OpenSearchBM25Retriever( + document_store=document_store, custom_query={"some": "custom query"} + ) + res = retriever.to_dict() + assert res == { + "type": "haystack_experimental.components.retrievers.opensearch.bm25_retriever.OpenSearchBM25Retriever", + "init_parameters": { + "document_store": { + "init_parameters": { + "embedding_dim": 768, + "hosts": "some fake host", + "index": "default", + "mappings": { + "dynamic_templates": [ + { + "strings": { + "mapping": {"type": "keyword"}, + "match_mapping_type": "string", + } + } + ], + "properties": { + "content": {"type": "text"}, + "embedding": { + "dimension": 768, + "index": True, + "type": "knn_vector", + }, + }, + }, + "max_chunk_bytes": DEFAULT_MAX_CHUNK_BYTES, + "method": None, + "settings": {"index.knn": True}, + "return_embedding": False, + "create_index": True, + "http_auth": None, + "use_ssl": None, + "verify_certs": None, + "timeout": None, + }, + "type": "haystack_experimental.document_stores.opensearch.document_store.OpenSearchDocumentStore", + }, + "filters": {}, + "fuzziness": "AUTO", + "top_k": 10, + "scale_score": False, + "filter_policy": "replace", + "custom_query": {"some": "custom query"}, + "raise_on_failure": True, + }, + } + + +@patch("haystack_experimental.document_stores.opensearch.document_store.OpenSearch") +def test_from_dict(_mock_opensearch_client): + data = { + "type": "haystack_experimental.components.retrievers.opensearch.bm25_retriever.OpenSearchBM25Retriever", + "init_parameters": { + "document_store": { + "init_parameters": {"hosts": "some fake host", "index": "default"}, + "type": "haystack_experimental.document_stores.opensearch.document_store.OpenSearchDocumentStore", + }, + "filters": {}, + "fuzziness": "AUTO", + "top_k": 10, + "scale_score": True, + "filter_policy": "replace", + "custom_query": {"some": "custom query"}, + "raise_on_failure": False, + }, + } + 
retriever = OpenSearchBM25Retriever.from_dict(data) + assert retriever._document_store + assert retriever._filters == {} + assert retriever._fuzziness == "AUTO" + assert retriever._top_k == 10 + assert retriever._scale_score + assert retriever._filter_policy == FilterPolicy.REPLACE + assert retriever._custom_query == {"some": "custom query"} + assert retriever._raise_on_failure is False + + +def test_run(): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._bm25_retrieval.return_value = [Document(content="Test doc")] + retriever = OpenSearchBM25Retriever(document_store=mock_store) + res = retriever.run(query="some query") + mock_store._bm25_retrieval.assert_called_once_with( + query="some query", + filters={}, + fuzziness="AUTO", + top_k=10, + scale_score=False, + all_terms_must_match=False, + custom_query=None, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + + +@pytest.mark.asyncio +async def test_run_async(): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._bm25_retrieval_async.return_value = [Document(content="Test doc")] + retriever = OpenSearchBM25Retriever(document_store=mock_store) + res = await retriever.run_async(query="some query") + mock_store._bm25_retrieval_async.assert_called_once_with( + query="some query", + filters={}, + fuzziness="AUTO", + top_k=10, + scale_score=False, + all_terms_must_match=False, + custom_query=None, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + + +def test_run_init_params(): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._bm25_retrieval.return_value = [Document(content="Test doc")] + retriever = OpenSearchBM25Retriever( + document_store=mock_store, + filters={"from": "init"}, + all_terms_must_match=True, + scale_score=True, + top_k=11, + fuzziness="1", + custom_query={"some": "custom query"}, + ) + res = retriever.run(query="some query") + mock_store._bm25_retrieval.assert_called_once_with( + query="some query", + filters={"from": "init"}, + fuzziness="1", + top_k=11, + scale_score=True, + all_terms_must_match=True, + custom_query={"some": "custom query"}, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + + +@pytest.mark.asyncio +async def test_run_init_params_async(): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._bm25_retrieval_async.return_value = [Document(content="Test doc")] + retriever = OpenSearchBM25Retriever( + document_store=mock_store, + filters={"from": "init"}, + all_terms_must_match=True, + scale_score=True, + top_k=11, + fuzziness="1", + custom_query={"some": "custom query"}, + ) + res = await retriever.run_async(query="some query") + mock_store._bm25_retrieval_async.assert_called_once_with( + query="some query", + filters={"from": "init"}, + fuzziness="1", + top_k=11, + scale_score=True, + all_terms_must_match=True, + custom_query={"some": "custom query"}, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + + +def test_run_time_params(): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._bm25_retrieval.return_value = [Document(content="Test doc")] + retriever = OpenSearchBM25Retriever( + document_store=mock_store, + filters={"from": "init"}, + all_terms_must_match=True, + scale_score=True, + top_k=11, + fuzziness="1", + ) + res = retriever.run( + query="some query", + filters={"from": "run"}, + 
all_terms_must_match=False, + scale_score=False, + top_k=9, + fuzziness="2", + ) + mock_store._bm25_retrieval.assert_called_once_with( + query="some query", + filters={"from": "run"}, + fuzziness="2", + top_k=9, + scale_score=False, + all_terms_must_match=False, + custom_query=None, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + + +@pytest.mark.asyncio +async def test_run_time_params_async(): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._bm25_retrieval_async.return_value = [Document(content="Test doc")] + retriever = OpenSearchBM25Retriever( + document_store=mock_store, + filters={"from": "init"}, + all_terms_must_match=True, + scale_score=True, + top_k=11, + fuzziness="1", + ) + res = await retriever.run_async( + query="some query", + filters={"from": "run"}, + all_terms_must_match=False, + scale_score=False, + top_k=9, + fuzziness="2", + ) + mock_store._bm25_retrieval_async.assert_called_once_with( + query="some query", + filters={"from": "run"}, + fuzziness="2", + top_k=9, + scale_score=False, + all_terms_must_match=False, + custom_query=None, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + + +def test_run_ignore_errors(caplog): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._bm25_retrieval.side_effect = Exception("Some error") + retriever = OpenSearchBM25Retriever( + document_store=mock_store, raise_on_failure=False + ) + res = retriever.run(query="some query") + assert len(res) == 1 + assert res["documents"] == [] + assert "Some error" in caplog.text diff --git a/test/components/retrievers/opensearch/test_embedding_retriever.py b/test/components/retrievers/opensearch/test_embedding_retriever.py new file mode 100644 index 00000000..c670cb15 --- /dev/null +++ b/test/components/retrievers/opensearch/test_embedding_retriever.py @@ -0,0 +1,282 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from unittest.mock import Mock, patch + +import pytest +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy + +from haystack_experimental.components.retrievers.opensearch import ( + OpenSearchEmbeddingRetriever, +) +from haystack_experimental.document_stores.opensearch import OpenSearchDocumentStore +from haystack_experimental.document_stores.opensearch.document_store import ( + DEFAULT_MAX_CHUNK_BYTES, +) + + +def test_init_default(): + mock_store = Mock(spec=OpenSearchDocumentStore) + retriever = OpenSearchEmbeddingRetriever(document_store=mock_store) + assert retriever._document_store == mock_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + retriever = OpenSearchEmbeddingRetriever( + document_store=mock_store, filter_policy="replace" + ) + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + OpenSearchEmbeddingRetriever(document_store=mock_store, filter_policy="unknown") + + +@patch("haystack_experimental.document_stores.opensearch.document_store.OpenSearch") +def test_to_dict(_mock_opensearch_client): + document_store = OpenSearchDocumentStore(hosts="some fake host") + retriever = OpenSearchEmbeddingRetriever( + document_store=document_store, custom_query={"some": "custom query"} + ) + res = retriever.to_dict() + type_s = 
"haystack_experimental.components.retrievers.opensearch.embedding_retriever.OpenSearchEmbeddingRetriever" + assert res == { + "type": type_s, + "init_parameters": { + "document_store": { + "init_parameters": { + "embedding_dim": 768, + "hosts": "some fake host", + "index": "default", + "mappings": { + "dynamic_templates": [ + { + "strings": { + "mapping": { + "type": "keyword", + }, + "match_mapping_type": "string", + }, + }, + ], + "properties": { + "content": { + "type": "text", + }, + "embedding": { + "dimension": 768, + "index": True, + "type": "knn_vector", + }, + }, + }, + "max_chunk_bytes": DEFAULT_MAX_CHUNK_BYTES, + "method": None, + "settings": { + "index.knn": True, + }, + "return_embedding": False, + "create_index": True, + "http_auth": None, + "use_ssl": None, + "verify_certs": None, + "timeout": None, + }, + "type": "haystack_experimental.document_stores.opensearch.document_store.OpenSearchDocumentStore", + }, + "filters": {}, + "top_k": 10, + "filter_policy": "replace", + "custom_query": {"some": "custom query"}, + "raise_on_failure": True, + }, + } + + +@patch("haystack_experimental.document_stores.opensearch.document_store.OpenSearch") +def test_from_dict(_mock_opensearch_client): + type_s = "haystack_experimental.components.retrievers.opensearch.embedding_retriever.OpenSearchEmbeddingRetriever" + data = { + "type": type_s, + "init_parameters": { + "document_store": { + "init_parameters": {"hosts": "some fake host", "index": "default"}, + "type": "haystack_experimental.document_stores.opensearch.document_store.OpenSearchDocumentStore", + }, + "filters": {}, + "top_k": 10, + "filter_policy": "replace", + "custom_query": {"some": "custom query"}, + "raise_on_failure": False, + }, + } + retriever = OpenSearchEmbeddingRetriever.from_dict(data) + assert retriever._document_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._custom_query == {"some": "custom query"} + assert retriever._raise_on_failure is False + assert retriever._filter_policy == FilterPolicy.REPLACE + + # For backwards compatibility with older versions of the retriever without a filter policy + data = { + "type": type_s, + "init_parameters": { + "document_store": { + "init_parameters": {"hosts": "some fake host", "index": "default"}, + "type": "haystack_experimental.document_stores.opensearch.document_store.OpenSearchDocumentStore", + }, + "filters": {}, + "top_k": 10, + "custom_query": {"some": "custom query"}, + "raise_on_failure": False, + }, + } + retriever = OpenSearchEmbeddingRetriever.from_dict(data) + assert retriever._filter_policy == FilterPolicy.REPLACE + + +def test_run(): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._embedding_retrieval.return_value = [ + Document(content="Test doc", embedding=[0.1, 0.2]) + ] + retriever = OpenSearchEmbeddingRetriever(document_store=mock_store) + res = retriever.run(query_embedding=[0.5, 0.7]) + mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.5, 0.7], + filters={}, + top_k=10, + custom_query=None, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +@pytest.mark.asyncio +async def test_run_async(): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._embedding_retrieval_async.return_value = [ + Document(content="Test doc", embedding=[0.1, 0.2]) + ] + retriever = OpenSearchEmbeddingRetriever(document_store=mock_store) + res = await 
retriever.run_async(query_embedding=[0.5, 0.7]) + mock_store._embedding_retrieval_async.assert_called_once_with( + query_embedding=[0.5, 0.7], + filters={}, + top_k=10, + custom_query=None, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_init_params(): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._embedding_retrieval.return_value = [ + Document(content="Test doc", embedding=[0.1, 0.2]) + ] + retriever = OpenSearchEmbeddingRetriever( + document_store=mock_store, + filters={"from": "init"}, + top_k=11, + custom_query="custom_query", + ) + res = retriever.run(query_embedding=[0.5, 0.7]) + mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.5, 0.7], + filters={"from": "init"}, + top_k=11, + custom_query="custom_query", + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +@pytest.mark.asyncio +async def test_run_async_init_params(): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._embedding_retrieval_async.return_value = [ + Document(content="Test doc", embedding=[0.1, 0.2]) + ] + retriever = OpenSearchEmbeddingRetriever( + document_store=mock_store, + filters={"from": "init"}, + top_k=11, + custom_query="custom_query", + ) + res = await retriever.run_async(query_embedding=[0.5, 0.7]) + mock_store._embedding_retrieval_async.assert_called_once_with( + query_embedding=[0.5, 0.7], + filters={"from": "init"}, + top_k=11, + custom_query="custom_query", + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_time_params(): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._embedding_retrieval.return_value = [ + Document(content="Test doc", embedding=[0.1, 0.2]) + ] + retriever = OpenSearchEmbeddingRetriever( + document_store=mock_store, filters={"from": "init"}, top_k=11 + ) + res = retriever.run(query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9) + mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.5, 0.7], + filters={"from": "run"}, + top_k=9, + custom_query=None, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +@pytest.mark.asyncio +async def test_run_async_time_params(): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._embedding_retrieval_async.return_value = [ + Document(content="Test doc", embedding=[0.1, 0.2]) + ] + retriever = OpenSearchEmbeddingRetriever( + document_store=mock_store, filters={"from": "init"}, top_k=11 + ) + res = await retriever.run_async( + query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9 + ) + mock_store._embedding_retrieval_async.assert_called_once_with( + query_embedding=[0.5, 0.7], + filters={"from": "run"}, + top_k=9, + custom_query=None, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_ignore_errors(caplog): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._embedding_retrieval.side_effect = Exception("Some error") + retriever = OpenSearchEmbeddingRetriever( + document_store=mock_store, 
raise_on_failure=False + ) + res = retriever.run(query_embedding=[0.5, 0.7]) + assert len(res) == 1 + assert res["documents"] == [] + assert "Some error" in caplog.text diff --git a/test/components/splitters/test_hierarchical_doc_splitter.py b/test/components/splitters/test_hierarchical_doc_splitter.py index 85c681ee..29a788bf 100644 --- a/test/components/splitters/test_hierarchical_doc_splitter.py +++ b/test/components/splitters/test_hierarchical_doc_splitter.py @@ -18,23 +18,35 @@ def test_init_with_default_params(self): assert builder.split_by == "word" def test_init_with_custom_params(self): - builder = HierarchicalDocumentSplitter(block_sizes={100, 200, 300}, split_overlap=25, split_by="word") + builder = HierarchicalDocumentSplitter( + block_sizes={100, 200, 300}, split_overlap=25, split_by="word" + ) assert builder.block_sizes == [300, 200, 100] assert builder.split_overlap == 25 assert builder.split_by == "word" def test_to_dict(self): - builder = HierarchicalDocumentSplitter(block_sizes={100, 200, 300}, split_overlap=25, split_by="word") + builder = HierarchicalDocumentSplitter( + block_sizes={100, 200, 300}, split_overlap=25, split_by="word" + ) expected = builder.to_dict() assert expected == { "type": "haystack_experimental.components.splitters.hierarchical_doc_splitter.HierarchicalDocumentSplitter", - "init_parameters": {"block_sizes": [300, 200, 100], "split_overlap": 25, "split_by": "word"}, + "init_parameters": { + "block_sizes": [300, 200, 100], + "split_overlap": 25, + "split_by": "word", + }, } def test_from_dict(self): data = { "type": "haystack_experimental.components.splitters.hierarchical_doc_splitter.HierarchicalDocumentSplitter", - "init_parameters": {"block_sizes": [10, 5, 2], "split_overlap": 0, "split_by": "word"}, + "init_parameters": { + "block_sizes": [10, 5, 2], + "split_overlap": 0, + "split_by": "word", + }, } builder = HierarchicalDocumentSplitter.from_dict(data) @@ -43,7 +55,9 @@ def test_from_dict(self): assert builder.split_by == "word" def test_run(self): - builder = HierarchicalDocumentSplitter(block_sizes={10, 5, 2}, split_overlap=0, split_by="word") + builder = HierarchicalDocumentSplitter( + block_sizes={10, 5, 2}, split_overlap=0, split_by="word" + ) text = "one two three four five six seven eight nine ten" doc = Document(content=text) output = builder.run([doc]) @@ -94,16 +108,30 @@ def test_to_dict_in_pipeline(self): hierarchical_doc_builder = HierarchicalDocumentSplitter(block_sizes={10, 5, 2}) doc_store = InMemoryDocumentStore() doc_writer = DocumentWriter(document_store=doc_store) - pipeline.add_component(name="hierarchical_doc_splitter", instance=hierarchical_doc_builder) + pipeline.add_component( + name="hierarchical_doc_splitter", instance=hierarchical_doc_builder + ) pipeline.add_component(name="doc_writer", instance=doc_writer) pipeline.connect("hierarchical_doc_splitter", "doc_writer") expected = pipeline.to_dict() - assert expected.keys() == {"metadata", "max_runs_per_component", "components", "connections"} - assert expected["components"].keys() == {"hierarchical_doc_splitter", "doc_writer"} + assert expected.keys() == { + "metadata", + "max_runs_per_component", + "components", + "connections", + } + assert expected["components"].keys() == { + "hierarchical_doc_splitter", + "doc_writer", + } assert expected["components"]["hierarchical_doc_splitter"] == { "type": "haystack_experimental.components.splitters.hierarchical_doc_splitter.HierarchicalDocumentSplitter", - "init_parameters": {"block_sizes": [10, 5, 2], "split_overlap": 0, 
"split_by": "word"}, + "init_parameters": { + "block_sizes": [10, 5, 2], + "split_overlap": 0, + "split_by": "word", + }, } def test_from_dict_in_pipeline(self): @@ -113,7 +141,11 @@ def test_from_dict_in_pipeline(self): "components": { "hierarchical_doc_splitter": { "type": "haystack_experimental.components.splitters.hierarchical_doc_splitter.HierarchicalDocumentSplitter", - "init_parameters": {"block_sizes": [10, 5, 2], "split_overlap": 0, "split_by": "word"}, + "init_parameters": { + "block_sizes": [10, 5, 2], + "split_overlap": 0, + "split_by": "word", + }, }, "doc_writer": { "type": "haystack.components.writers.document_writer.DocumentWriter", @@ -132,7 +164,12 @@ def test_from_dict_in_pipeline(self): }, }, }, - "connections": [{"sender": "hierarchical_doc_splitter.documents", "receiver": "doc_writer.documents"}], + "connections": [ + { + "sender": "hierarchical_doc_splitter.documents", + "receiver": "doc_writer.documents", + } + ], } assert Pipeline.from_dict(data) @@ -140,11 +177,15 @@ def test_from_dict_in_pipeline(self): @pytest.mark.integration def test_example_in_pipeline(self): pipeline = Pipeline() - hierarchical_doc_builder = HierarchicalDocumentSplitter(block_sizes={10, 5, 2}, split_overlap=0, split_by="word") + hierarchical_doc_builder = HierarchicalDocumentSplitter( + block_sizes={10, 5, 2}, split_overlap=0, split_by="word" + ) doc_store = InMemoryDocumentStore() doc_writer = DocumentWriter(document_store=doc_store) - pipeline.add_component(name="hierarchical_doc_splitter", instance=hierarchical_doc_builder) + pipeline.add_component( + name="hierarchical_doc_splitter", instance=hierarchical_doc_builder + ) pipeline.add_component(name="doc_writer", instance=doc_writer) pipeline.connect("hierarchical_doc_splitter.documents", "doc_writer") @@ -157,11 +198,15 @@ def test_example_in_pipeline(self): def test_serialization_deserialization_pipeline(self): pipeline = Pipeline() - hierarchical_doc_builder = HierarchicalDocumentSplitter(block_sizes={10, 5, 2}, split_overlap=0, split_by="word") + hierarchical_doc_builder = HierarchicalDocumentSplitter( + block_sizes={10, 5, 2}, split_overlap=0, split_by="word" + ) doc_store = InMemoryDocumentStore() doc_writer = DocumentWriter(document_store=doc_store) - pipeline.add_component(name="hierarchical_doc_splitter", instance=hierarchical_doc_builder) + pipeline.add_component( + name="hierarchical_doc_splitter", instance=hierarchical_doc_builder + ) pipeline.add_component(name="doc_writer", instance=doc_writer) pipeline.connect("hierarchical_doc_splitter.documents", "doc_writer") pipeline_dict = pipeline.to_dict() diff --git a/test/components/tools/test_tool_invoker.py b/test/components/tools/test_tool_invoker.py index 80d44251..d2207cea 100644 --- a/test/components/tools/test_tool_invoker.py +++ b/test/components/tools/test_tool_invoker.py @@ -4,9 +4,18 @@ from haystack import Pipeline -from haystack_experimental.dataclasses import ChatMessage, ToolCall, ToolCallResult, ChatRole +from haystack_experimental.dataclasses import ( + ChatMessage, + ToolCall, + ToolCallResult, + ChatRole, +) from haystack_experimental.dataclasses.tool import Tool, ToolInvocationError -from haystack_experimental.components.tools.tool_invoker import ToolInvoker, ToolNotFoundException, StringConversionError +from haystack_experimental.components.tools.tool_invoker import ( + ToolInvoker, + ToolNotFoundException, + StringConversionError, +) from haystack_experimental.components.generators.chat import OpenAIChatGenerator @@ -16,26 +25,28 @@ def 
weather_function(location): "Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"}, "Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"}, } - return weather_info.get(location, {"weather": "unknown", "temperature": 0, "unit": "celsius"}) + return weather_info.get( + location, {"weather": "unknown", "temperature": 0, "unit": "celsius"} + ) weather_parameters = { "type": "object", - "properties": { - "location": {"type": "string"} - }, - "required": ["location"] + "properties": {"location": {"type": "string"}}, + "required": ["location"], } + @pytest.fixture def weather_tool(): return Tool( name="weather_tool", description="Provides weather information for a given location.", parameters=weather_parameters, - function=weather_function + function=weather_function, ) + @pytest.fixture def faulty_tool(): def faulty_tool_func(location): @@ -44,23 +55,30 @@ def faulty_tool_func(location): faulty_tool_parameters = { "type": "object", "properties": {"location": {"type": "string"}}, - "required": ["location"] + "required": ["location"], } return Tool( name="faulty_tool", description="A tool that always fails when invoked.", parameters=faulty_tool_parameters, - function=faulty_tool_func + function=faulty_tool_func, ) + @pytest.fixture def invoker(weather_tool): - return ToolInvoker(tools=[weather_tool], raise_on_failure=True, convert_result_to_json_string=False) + return ToolInvoker( + tools=[weather_tool], raise_on_failure=True, convert_result_to_json_string=False + ) + @pytest.fixture def faulty_invoker(faulty_tool): - return ToolInvoker(tools=[faulty_tool], raise_on_failure=True, convert_result_to_json_string=False) + return ToolInvoker( + tools=[faulty_tool], raise_on_failure=True, convert_result_to_json_string=False + ) + class TestToolInvoker: @@ -68,7 +86,7 @@ def test_init(self, weather_tool): invoker = ToolInvoker(tools=[weather_tool]) assert invoker.tools == [weather_tool] - assert invoker._tools_with_names == {'weather_tool': weather_tool} + assert invoker._tools_with_names == {"weather_tool": weather_tool} assert invoker.raise_on_failure assert not invoker.convert_result_to_json_string @@ -87,13 +105,8 @@ def test_init_fails_with_duplicate_tool_names(self, weather_tool, faulty_tool): def test_run(self, invoker): - tool_call = ToolCall( - tool_name="weather_tool", - arguments={"location": "Berlin"} - ) - message = ChatMessage.from_assistant( - tool_calls=[tool_call] - ) + tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}) + message = ChatMessage.from_assistant(tool_calls=[tool_call]) result = invoker.run(messages=[message]) assert "tool_messages" in result @@ -107,7 +120,9 @@ def test_run(self, invoker): tool_call_result = tool_message.tool_call_result assert isinstance(tool_call_result, ToolCallResult) - assert tool_call_result.result == str({"weather": "mostly sunny", "temperature": 7, "unit": "celsius"}) + assert tool_call_result.result == str( + {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"} + ) assert tool_call_result.origin == tool_call assert not tool_call_result.error @@ -134,7 +149,7 @@ def test_run_multiple_tool_calls(self, invoker): assert "tool_messages" in result assert len(result["tool_messages"]) == 3 - for i,tool_message in enumerate(result["tool_messages"]): + for i, tool_message in enumerate(result["tool_messages"]): assert isinstance(tool_message, ChatMessage) assert tool_message.is_from(ChatRole.TOOL) @@ -145,16 +160,11 @@ def test_run_multiple_tool_calls(self, invoker): assert not 
tool_call_result.error assert tool_call_result.origin == tool_calls[i] - - def test_tool_not_found_error(self, invoker): tool_call = ToolCall( - tool_name="non_existent_tool", - arguments={"location": "Berlin"} - ) - tool_call_message = ChatMessage.from_assistant( - tool_calls=[tool_call] + tool_name="non_existent_tool", arguments={"location": "Berlin"} ) + tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call]) with pytest.raises(ToolNotFoundException): invoker.run(messages=[tool_call_message]) @@ -163,12 +173,9 @@ def test_tool_not_found_does_not_raise_exception(self, invoker): invoker.raise_on_failure = False tool_call = ToolCall( - tool_name="non_existent_tool", - arguments={"location": "Berlin"} - ) - tool_call_message = ChatMessage.from_assistant( - tool_calls=[tool_call] + tool_name="non_existent_tool", arguments={"location": "Berlin"} ) + tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call]) result = invoker.run(messages=[tool_call_message]) tool_message = result["tool_messages"][0] @@ -177,13 +184,8 @@ def test_tool_not_found_does_not_raise_exception(self, invoker): assert "not found" in tool_message.tool_call_results[0].result def test_tool_invocation_error(self, faulty_invoker): - tool_call = ToolCall( - tool_name="faulty_tool", - arguments={"location": "Berlin"} - ) - tool_call_message = ChatMessage.from_assistant( - tool_calls=[tool_call] - ) + tool_call = ToolCall(tool_name="faulty_tool", arguments={"location": "Berlin"}) + tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call]) with pytest.raises(ToolInvocationError): faulty_invoker.run(messages=[tool_call_message]) @@ -191,13 +193,8 @@ def test_tool_invocation_error(self, faulty_invoker): def test_tool_invocation_error_does_not_raise_exception(self, faulty_invoker): faulty_invoker.raise_on_failure = False - tool_call = ToolCall( - tool_name="faulty_tool", - arguments={"location": "Berlin"} - ) - tool_call_message = ChatMessage.from_assistant( - tool_calls=[tool_call] - ) + tool_call = ToolCall(tool_name="faulty_tool", arguments={"location": "Berlin"}) + tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call]) result = faulty_invoker.run(messages=[tool_call_message]) tool_message = result["tool_messages"][0] @@ -207,31 +204,28 @@ def test_tool_invocation_error_does_not_raise_exception(self, faulty_invoker): def test_string_conversion_error(self, invoker): invoker.convert_result_to_json_string = True - tool_call = ToolCall( - tool_name="weather_tool", - arguments={"location": "Berlin"} - ) + tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}) tool_result = datetime.datetime.now() with pytest.raises(StringConversionError): - invoker._prepare_tool_result_message(result=tool_result, tool_call=tool_call) + invoker._prepare_tool_result_message( + result=tool_result, tool_call=tool_call + ) def test_string_conversion_error_does_not_raise_exception(self, invoker): invoker.convert_result_to_json_string = True invoker.raise_on_failure = False - tool_call = ToolCall( - tool_name="weather_tool", - arguments={"location": "Berlin"} - ) + tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}) tool_result = datetime.datetime.now() - tool_message = invoker._prepare_tool_result_message(result=tool_result, tool_call=tool_call) + tool_message = invoker._prepare_tool_result_message( + result=tool_result, tool_call=tool_call + ) assert tool_message.tool_call_results[0].error assert "Failed to convert" in 
tool_message.tool_call_results[0].result - def test_to_dict(self, invoker, weather_tool): data = invoker.to_dict() assert data == { @@ -239,7 +233,7 @@ def test_to_dict(self, invoker, weather_tool): "init_parameters": { "tools": [weather_tool.to_dict()], "raise_on_failure": True, - "convert_result_to_json_string": False + "convert_result_to_json_string": False, }, } @@ -249,12 +243,12 @@ def test_from_dict(self, weather_tool): "init_parameters": { "tools": [weather_tool.to_dict()], "raise_on_failure": True, - "convert_result_to_json_string": False + "convert_result_to_json_string": False, }, } invoker = ToolInvoker.from_dict(data) assert invoker.tools == [weather_tool] - assert invoker._tools_with_names == {'weather_tool': weather_tool} + assert invoker._tools_with_names == {"weather_tool": weather_tool} assert invoker.raise_on_failure assert not invoker.convert_result_to_json_string @@ -268,57 +262,57 @@ def test_serde_in_pipeline(self, invoker, monkeypatch): pipeline_dict = pipeline.to_dict() assert pipeline_dict == { - 'metadata': {}, - 'max_runs_per_component': 100, - 'components': { - 'invoker': { - 'type': 'haystack_experimental.components.tools.tool_invoker.ToolInvoker', - 'init_parameters': { - 'tools': [ + "metadata": {}, + "max_runs_per_component": 100, + "components": { + "invoker": { + "type": "haystack_experimental.components.tools.tool_invoker.ToolInvoker", + "init_parameters": { + "tools": [ { - 'name': 'weather_tool', - 'description': 'Provides weather information for a given location.', - 'parameters': { - 'type': 'object', - 'properties': { - 'location': {'type': 'string'} - }, - 'required': ['location'] + "name": "weather_tool", + "description": "Provides weather information for a given location.", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], }, - 'function': 'test.components.tools.test_tool_invoker.weather_function' + "function": "test.components.tools.test_tool_invoker.weather_function", } ], - 'raise_on_failure': True, - 'convert_result_to_json_string': False - } + "raise_on_failure": True, + "convert_result_to_json_string": False, + }, }, - 'chatgenerator': { - 'type': 'haystack_experimental.components.generators.chat.openai.OpenAIChatGenerator', - 'init_parameters': { - 'model': 'gpt-4o-mini', - 'streaming_callback': None, - 'api_base_url': None, - 'organization': None, - 'generation_kwargs': {}, - 'api_key': { - 'type': 'env_var', - 'env_vars': ['OPENAI_API_KEY'], - 'strict': True + "chatgenerator": { + "type": "haystack_experimental.components.generators.chat.openai.OpenAIChatGenerator", + "init_parameters": { + "model": "gpt-4o-mini", + "streaming_callback": None, + "api_base_url": None, + "organization": None, + "generation_kwargs": {}, + "max_retries": None, + "timeout": None, + "api_key": { + "type": "env_var", + "env_vars": ["OPENAI_API_KEY"], + "strict": True, }, - 'tools': None, - 'tools_strict': False - } - } + "tools": None, + "tools_strict": False, + }, + }, }, - 'connections': [ + "connections": [ { - 'sender': 'invoker.tool_messages', - 'receiver': 'chatgenerator.messages' + "sender": "invoker.tool_messages", + "receiver": "chatgenerator.messages", } - ] + ], } pipeline_yaml = pipeline.dumps() new_pipeline = Pipeline.loads(pipeline_yaml) - assert new_pipeline==pipeline + assert new_pipeline == pipeline diff --git a/test/components/writers/test_document_writer.py b/test/components/writers/test_document_writer.py new file mode 100644 index 00000000..5e6eac55 --- /dev/null +++ 
b/test/components/writers/test_document_writer.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH + +# +# SPDX-License-Identifier: Apache-2.0 +import pytest + +from haystack import Document +from haystack.testing.factory import document_store_class +from haystack_experimental.components.writers.document_writer import DocumentWriter +from haystack.document_stores.types import DuplicatePolicy +from haystack_experimental.document_stores.in_memory import InMemoryDocumentStore + + +class TestDocumentWriter: + @pytest.mark.asyncio + async def test_run_invalid_docstore(self): + document_store = document_store_class("MockedDocumentStore") + + writer = DocumentWriter(document_store) + documents = [ + Document(content="This is the text of a document."), + Document(content="This is the text of another document."), + ] + + with pytest.raises(TypeError, match="does not provide async support"): + result = await writer.run_async(documents=documents) + + @pytest.mark.asyncio + async def test_run(self): + document_store = InMemoryDocumentStore() + writer = DocumentWriter(document_store) + documents = [ + Document(content="This is the text of a document."), + Document(content="This is the text of another document."), + ] + + result = await writer.run_async(documents=documents) + assert result["documents_written"] == 2 + + @pytest.mark.asyncio + async def test_run_skip_policy(self): + document_store = InMemoryDocumentStore() + writer = DocumentWriter(document_store, policy=DuplicatePolicy.SKIP) + documents = [ + Document(content="This is the text of a document."), + Document(content="This is the text of another document."), + ] + + result = await writer.run_async(documents=documents) + assert result["documents_written"] == 2 + + result = await writer.run_async(documents=documents) + assert result["documents_written"] == 0 diff --git a/test/core/__init__.py b/test/core/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/test/core/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/test/core/pipeline/__init__.py b/test/core/pipeline/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/test/core/pipeline/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/test/core/pipeline/features/README.md b/test/core/pipeline/features/README.md new file mode 100644 index 00000000..7a5abc4d --- /dev/null +++ b/test/core/pipeline/features/README.md @@ -0,0 +1,136 @@ +# `Pipeline.run()` behavioural tests + +This module contains all behavioural tests for `Pipeline.run()`. + +`pipeline_run.feature` contains the definition of the tests using a subset of the [Gherkin language](https://cucumber.io/docs/gherkin/). It's not the full language because we're using `pytest-bdd` and it doesn't implement it in full, but it's good enough for our use case. For more info see the [project `README.md`](https://github.com/pytest-dev/pytest-bdd). + +There are two cases covered by these tests: + +1. `Pipeline.run()` returns some output +2. `Pipeline.run()` raises an exception + +### Correct Pipeline + +In the first case, to add a new test you need to add a new entry in the `Examples` of the `Running a correct Pipeline` scenario outline and create the corresponding step that creates the `Pipeline` you need to test. 
+ +For example, to add a test for a linear `Pipeline`, add a new `that is linear` kind in `pipeline_run.feature`. + +```gherkin + Scenario Outline: Running a correct Pipeline + Given a pipeline + When I run the Pipeline + Then it should return the expected result + + Examples: + | kind | + | that has no components | + | that is linear | +``` + +Then define a new `pipeline_that_is_linear` function in `test_run.py`. +The function must be decorated with `@given` and return a tuple containing the `Pipeline` instance and a list of `PipelineRunData` instances. +`PipelineRunData` is a dataclass that stores all the information necessary to verify the `Pipeline` ran as expected. +The `@given` arguments must be the full step name, `"a pipeline that is linear"` in this case, and `target_fixture` must be set to `"pipeline_data"`. + +```python +@given("a pipeline that is linear", target_fixture="pipeline_data") +def pipeline_that_is_linear(): + pipeline = Pipeline() + pipeline.add_component("first_addition", AddFixedValue(add=2)) + pipeline.add_component("second_addition", AddFixedValue()) + pipeline.add_component("double", Double()) + pipeline.connect("first_addition", "double") + pipeline.connect("double", "second_addition") + + return ( + pipeline, + [ + PipelineRunData( + inputs={"first_addition": {"value": 1}}, + expected_outputs={"second_addition": {"result": 7}}, + expected_run_order=["first_addition", "double", "second_addition"], + ) + ], + ) +``` + +Some kinds of `Pipeline`s require multiple runs to verify they work correctly, for example those with multiple branches. +For this reason we return a list of `PipelineRunData`; the `Pipeline` is run once for each instance. +For example, we could test two different runs of the same pipeline like this: + +```python +@given("a pipeline that is linear", target_fixture="pipeline_data") +def pipeline_that_is_linear(): + pipeline = Pipeline() + pipeline.add_component("first_addition", AddFixedValue(add=2)) + pipeline.add_component("second_addition", AddFixedValue()) + pipeline.add_component("double", Double()) + pipeline.connect("first_addition", "double") + pipeline.connect("double", "second_addition") + + return ( + pipeline, + [ + PipelineRunData( + inputs={"first_addition": {"value": 1}}, + include_outputs_from=set(), + expected_outputs={"second_addition": {"result": 7}}, + expected_run_order=["first_addition", "double", "second_addition"], + ), + PipelineRunData( + inputs={"first_addition": {"value": 100}}, + include_outputs_from=set(), + expected_outputs={"first_addition": {"value": 206}}, + expected_run_order=["first_addition", "double", "second_addition"], + ), + ], + ) +``` + +### Bad Pipeline + +The second case is similar to the first one. +In this case we test that a `Pipeline` with an infinite loop raises `PipelineMaxLoops`. + +```gherkin + Scenario Outline: Running a bad Pipeline + Given a pipeline + When I run the Pipeline + Then it must have raised + + Examples: + | kind | exception | + | that has an infinite loop | PipelineMaxLoops | +``` + +In a similar way to the first case, we need to define a new `pipeline_that_has_an_infinite_loop` function in `test_run.py`, with some small differences. +The only difference from the first case is the last value returned by the function: we just omit the expected outputs and the expected run order. 
+
+```python
+@given("a pipeline that has an infinite loop", target_fixture="pipeline_data")
+def pipeline_that_has_an_infinite_loop():
+    def custom_init(self):
+        component.set_input_type(self, "x", int)
+        component.set_input_type(self, "y", int, 1)
+        component.set_output_types(self, a=int, b=int)
+
+    FakeComponent = component_class("FakeComponent", output={"a": 1, "b": 1}, extra_fields={"__init__": custom_init})
+    pipe = Pipeline(max_runs_per_component=1)
+    pipe.add_component("first", FakeComponent())
+    pipe.add_component("second", FakeComponent())
+    pipe.connect("first.a", "second.x")
+    pipe.connect("second.b", "first.y")
+    return pipe, [PipelineRunData({"first": {"x": 1}})]
+```
+
+## Why?
+
+At the time of writing, tests that invoke `Pipeline.run()` are scattered between different files with very little clarity on what they are intended to test - the only indicators are the name of each test itself and the name of their parent module. This makes it difficult to understand which behaviours are being tested, whether they are tested redundantly, or whether they work correctly.
+
+The introduction of the Gherkin file allows for a single "source of truth" that enumerates (ideally, in an exhaustive manner) all the behaviours of the pipeline execution logic that we wish to test. This intermediate mapping of behaviours to actual test cases is meant to provide an overview of the latter and reduce the cognitive overhead of understanding them. When writing new tests, we now "tag" them with a specific behavioural parameter that's specified in a Gherkin scenario.
+
+This mapping between tags and behavioural parameters is meant to be one-to-one, meaning each "Given" step must map to one and only one function. If multiple functions are marked with `@given("step name")`, the last declaration will override all the previous ones (see the sketch at the end of this README). So it's important to verify that there are no other existing steps with the same name when adding a new one.
+
+While one could functionally do the same with well-defined test names and detailed comments on what is being tested, it would still lack the overview that the above approach provides. It's also extensible in that new scenarios with different behaviours can be introduced easily (e.g. for `async` pipeline execution logic).
+
+Apart from the above, the newly introduced harness ensures that all behavioural pipeline tests return a structured result, which simplifies checking of side-effects.
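+
+As a short sketch of the step-name pitfall mentioned above (the function names here are hypothetical), a second registration of the same step name silently wins over the earlier one:
+
+```python
+from pytest_bdd import given
+
+
+@given("a pipeline that is linear", target_fixture="pipeline_data")
+def first_definition():
+    # Registered first for the step "a pipeline that is linear"...
+    ...
+
+
+@given("a pipeline that is linear", target_fixture="pipeline_data")
+def second_definition():
+    # ...but, as noted above, the last declaration overrides the previous ones,
+    # so pytest-bdd resolves the step to this function and silently shadows `first_definition`.
+    ...
+```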
diff --git a/test/core/pipeline/features/conftest.py b/test/core/pipeline/features/conftest.py new file mode 100644 index 00000000..c1486e28 --- /dev/null +++ b/test/core/pipeline/features/conftest.py @@ -0,0 +1,173 @@ +from dataclasses import dataclass, field +from typing import Generator, Tuple, List, Dict, Any, Set, Union +from pathlib import Path +import re +import asyncio + +import pytest +from pytest_bdd import when, then, parsers + +from haystack_experimental.core import AsyncPipeline, run_async_pipeline +import contextlib +import dataclasses +import uuid +from typing import Dict, Any, Optional, List, Iterator + +from haystack.tracing import Span, Tracer, enable_tracing, disable_tracing + + +PIPELINE_NAME_REGEX = re.compile(r"\[(.*)\]") + + +@dataclasses.dataclass +class SpyingSpan(Span): + operation_name: str + parent_span: Optional[Span] = None + tags: Dict[str, Any] = dataclasses.field(default_factory=dict) + + trace_id: Optional[str] = dataclasses.field( + default_factory=lambda: str(uuid.uuid4()) + ) + span_id: Optional[str] = dataclasses.field( + default_factory=lambda: str(uuid.uuid4()) + ) + + def set_tag(self, key: str, value: Any) -> None: + self.tags[key] = value + + def get_correlation_data_for_logs(self) -> Dict[str, Any]: + return {"trace_id": self.trace_id, "span_id": self.span_id} + + +class SpyingTracer(Tracer): + def current_span(self) -> Optional[Span]: + return self.spans[-1] if self.spans else None + + def __init__(self) -> None: + self.spans: List[SpyingSpan] = [] + + @contextlib.contextmanager + def trace( + self, + operation_name: str, + tags: Optional[Dict[str, Any]] = None, + parent_span: Optional[Span] = None, + ) -> Iterator[Span]: + new_span = SpyingSpan(operation_name, parent_span) + + for key, value in (tags or {}).items(): + new_span.set_tag(key, value) + + self.spans.append(new_span) + + yield new_span + + +@pytest.fixture() +def spying_tracer() -> Generator[SpyingTracer, None, None]: + tracer = SpyingTracer() + enable_tracing(tracer) + + yield tracer + + # Make sure to disable tracing after the test to avoid affecting other tests + disable_tracing() + + +@dataclass +class PipelineRunData: + """ + Holds the inputs and expected outputs for a single Pipeline run. + """ + + inputs: Dict[str, Any] + include_outputs_from: Set[str] = field(default_factory=set) + expected_outputs: Dict[str, Any] = field(default_factory=dict) + expected_run_order: List[str] = field(default_factory=list) + + +@dataclass +class _PipelineResult: + """ + Holds the outputs and the run order of a single Pipeline run. + """ + + outputs: Dict[str, Any] + run_order: List[str] + + +@when("I run the Pipeline", target_fixture="pipeline_result") +def run_pipeline( + pipeline_data: Tuple[AsyncPipeline, List[PipelineRunData]], spying_tracer +) -> Union[List[Tuple[_PipelineResult, PipelineRunData]], Exception]: + """ + Attempts to run a pipeline with the given inputs. + `pipeline_data` is a tuple that must contain: + * A Pipeline instance + * The data to run the pipeline with + + If successful returns a tuple of the run outputs and the expected outputs. + In case an exceptions is raised returns that. 
+    """
+    pipeline, pipeline_run_data = pipeline_data[0], pipeline_data[1]
+
+    results: List[_PipelineResult] = []
+
+    async def run_inner(data, include_outputs_from):
+        return await run_async_pipeline(pipeline, data.inputs, include_outputs_from)
+
+    for data in pipeline_run_data:
+        try:
+            async_loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(async_loop)
+            outputs = async_loop.run_until_complete(
+                run_inner(data, data.include_outputs_from)
+            )
+
+            run_order = [
+                span.tags["haystack.component.name"]
+                for span in spying_tracer.spans
+                if "haystack.component.name" in span.tags
+            ]
+            results.append(_PipelineResult(outputs=outputs, run_order=run_order))
+            spying_tracer.spans.clear()
+        except Exception as e:
+            return e
+        finally:
+            async_loop.close()
+    return [e for e in zip(results, pipeline_run_data)]
+
+
+@then("draw it to file")
+def draw_pipeline(pipeline_data: Tuple[AsyncPipeline, List[PipelineRunData]], request):
+    """
+    Draw the pipeline to a file with the same name as the test.
+    """
+    if m := PIPELINE_NAME_REGEX.search(request.node.name):
+        name = m.group(1).replace(" ", "_")
+        pipeline = pipeline_data[0]
+        graphs_dir = Path(request.config.rootpath) / "test_pipeline_graphs"
+        graphs_dir.mkdir(exist_ok=True)
+        pipeline.draw(graphs_dir / f"{name}.png")
+
+
+@then("it should return the expected result")
+def check_pipeline_result(
+    pipeline_result: List[Tuple[_PipelineResult, PipelineRunData]]
+):
+    for res, data in pipeline_result:
+        assert res.outputs == data.expected_outputs
+
+
+@then("components ran in the expected order")
+def check_pipeline_run_order(
+    pipeline_result: List[Tuple[_PipelineResult, PipelineRunData]]
+):
+    for res, data in pipeline_result:
+        assert res.run_order == data.expected_run_order
+
+
+@pytest.mark.asyncio
+@then(parsers.parse("it must have raised {exception_class_name}"))
+def check_pipeline_raised(pipeline_result: Exception, exception_class_name: str):
+    assert pipeline_result.__class__.__name__ == exception_class_name
diff --git a/test/core/pipeline/features/pipeline_run.feature b/test/core/pipeline/features/pipeline_run.feature
new file mode 100644
index 00000000..db064ea2
--- /dev/null
+++ b/test/core/pipeline/features/pipeline_run.feature
@@ -0,0 +1,57 @@
+Feature: Pipeline running
+
+    Scenario Outline: Running a correct Pipeline
+        Given a pipeline <kind>
+        When I run the Pipeline
+        Then it should return the expected result
+        And components ran in the expected order
+
+        Examples:
+        | kind |
+        | that has no components |
+        | that is linear |
+        | that is really complex with lots of components, forks, and loops |
+        | that has a single component with a default input |
+        | that has two loops of identical lengths |
+        | that has two loops of different lengths |
+        | that has a single loop with two conditional branches |
+        | that has a component with dynamic inputs defined in init |
+        | that has two branches that don't merge |
+        | that has three branches that don't merge |
+        | that has two branches that merge |
+        | that has different combinations of branches that merge and do not merge |
+        | that has two branches, one of which loops back |
+        | that has a component with mutable input |
+        | that has a component with mutable output sent to multiple inputs |
+        | that has a greedy and variadic component after a component with default input |
+        | that has components added in a different order from the order of execution |
+        | that has a component with only default inputs |
+        | that has a component with only default inputs as first to run and receives inputs from a loop |
+        | that has multiple branches that merge into a component with a single variadic input |
+        | that has multiple branches of different lengths that merge into a component with a single variadic input |
+        | that is linear and returns intermediate outputs |
+        | that has a loop and returns intermediate outputs from it |
+        | that is linear and returns intermediate outputs from multiple sockets |
+        | that has a component with default inputs that doesn't receive anything from its sender |
+        | that has a component with default inputs that doesn't receive anything from its sender but receives input from user |
+        | that has a loop and a component with default inputs that doesn't receive anything from its sender but receives input from user |
+        | that has multiple components with only default inputs and are added in a different order from the order of execution |
+        | that is linear with conditional branching and multiple joins |
+        | that is a simple agent |
+        | that has a variadic component that receives partial inputs |
+        | that has an answer joiner variadic component |
+        | that is linear and a component in the middle receives optional input from other components and input from the user |
+        | that has a loop in the middle |
+        | that has variadic component that receives a conditional input |
+        | that has a string variadic component |
+
+    Scenario Outline: Running a bad Pipeline
+        Given a pipeline <kind>
+        When I run the Pipeline
+        Then it must have raised <exception>
+
+        Examples:
+        | kind | exception |
+        | that has an infinite loop | PipelineMaxComponentRuns |
+        | that has a component that doesn't return a dictionary | PipelineRuntimeError |
+        | that has a cycle that would get it stuck | PipelineRuntimeError |
diff --git a/test/core/pipeline/features/test_run.py b/test/core/pipeline/features/test_run.py
new file mode 100644
index 00000000..a813ee33
--- /dev/null
+++ b/test/core/pipeline/features/test_run.py
@@ -0,0 +1,2611 @@
+import json
+from typing import List, Optional, Dict, Any
+import re
+
+from pytest_bdd import scenarios, given
+import pytest
+
+from haystack import Document, component
+from haystack.document_stores.types import DuplicatePolicy
+from haystack.dataclasses import ChatMessage, GeneratedAnswer
+from haystack.components.routers import ConditionalRouter
+from haystack.components.builders import PromptBuilder, AnswerBuilder, ChatPromptBuilder
+from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter
+from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
+from haystack.document_stores.in_memory import InMemoryDocumentStore
+from haystack.components.joiners import (
+    BranchJoiner,
+    DocumentJoiner,
+    AnswerJoiner,
+    StringJoiner,
+)
+from haystack.testing.sample_components import (
+    Accumulate,
+    AddFixedValue,
+    Double,
+    Greet,
+    Parity,
+    Repeat,
+    Subtract,
+    Sum,
+    Threshold,
+    Remainder,
+    FString,
+    Hello,
+    TextSplitter,
+    StringListJoiner,
+)
+from haystack.testing.factory import component_class
+from haystack_experimental.core import AsyncPipeline
+
+from test.core.pipeline.features.conftest import PipelineRunData
+
+pytestmark = pytest.mark.integration
+
+scenarios("pipeline_run.feature")
+
+
+@given("a pipeline that has no components", target_fixture="pipeline_data")
+def pipeline_that_has_no_components():
+    pipeline = AsyncPipeline(max_runs_per_component=1)
+    inputs = {}
+    expected_outputs = {}
+    return pipeline, [PipelineRunData(inputs=inputs, expected_outputs=expected_outputs)]
+
+
+@given("a pipeline that is linear",
target_fixture="pipeline_data") +def pipeline_that_is_linear(): + pipeline = AsyncPipeline(max_runs_per_component=1) + pipeline.add_component("first_addition", AddFixedValue(add=2)) + pipeline.add_component("second_addition", AddFixedValue()) + pipeline.add_component("double", Double()) + pipeline.connect("first_addition", "double") + pipeline.connect("double", "second_addition") + + return ( + pipeline, + [ + PipelineRunData( + inputs={"first_addition": {"value": 1}}, + expected_outputs={"second_addition": {"result": 7}}, + expected_run_order=["first_addition", "double", "second_addition"], + ) + ], + ) + + +@given("a pipeline that has an infinite loop", target_fixture="pipeline_data") +def pipeline_that_has_an_infinite_loop(): + routes = [ + { + "condition": "{{number > 2}}", + "output": "{{number}}", + "output_name": "big_number", + "output_type": int, + }, + { + "condition": "{{number <= 2}}", + "output": "{{number + 2}}", + "output_name": "small_number", + "output_type": int, + }, + ] + + main_input = BranchJoiner(int) + first_router = ConditionalRouter(routes=routes) + second_router = ConditionalRouter(routes=routes) + + pipe = AsyncPipeline(max_runs_per_component=1) + pipe.add_component("main_input", main_input) + pipe.add_component("first_router", first_router) + pipe.add_component("second_router", second_router) + + pipe.connect("main_input", "first_router.number") + pipe.connect("first_router.big_number", "second_router.number") + pipe.connect("second_router.big_number", "main_input") + + return pipe, [PipelineRunData({"main_input": {"value": 3}})] + + +@given( + "a pipeline that is really complex with lots of components, forks, and loops", + target_fixture="pipeline_data", +) +def pipeline_complex(): + pipeline = AsyncPipeline(max_runs_per_component=2) + pipeline.add_component("greet_first", Greet(message="Hello, the value is {value}.")) + pipeline.add_component("accumulate_1", Accumulate()) + pipeline.add_component("add_two", AddFixedValue(add=2)) + pipeline.add_component("parity", Parity()) + pipeline.add_component("add_one", AddFixedValue(add=1)) + pipeline.add_component("accumulate_2", Accumulate()) + + pipeline.add_component("branch_joiner", BranchJoiner(type_=int)) + pipeline.add_component("below_10", Threshold(threshold=10)) + pipeline.add_component("double", Double()) + + pipeline.add_component( + "greet_again", Greet(message="Hello again, now the value is {value}.") + ) + pipeline.add_component("sum", Sum()) + + pipeline.add_component( + "greet_enumerator", + Greet(message="Hello from enumerator, here the value became {value}."), + ) + pipeline.add_component("enumerate", Repeat(outputs=["first", "second"])) + pipeline.add_component("add_three", AddFixedValue(add=3)) + + pipeline.add_component("diff", Subtract()) + pipeline.add_component( + "greet_one_last_time", Greet(message="Bye bye! 
The value here is {value}!") + ) + pipeline.add_component("replicate", Repeat(outputs=["first", "second"])) + pipeline.add_component("add_five", AddFixedValue(add=5)) + pipeline.add_component("add_four", AddFixedValue(add=4)) + pipeline.add_component("accumulate_3", Accumulate()) + + pipeline.connect("greet_first", "accumulate_1") + pipeline.connect("accumulate_1", "add_two") + pipeline.connect("add_two", "parity") + + pipeline.connect("parity.even", "greet_again") + pipeline.connect("greet_again", "sum.values") + pipeline.connect("sum", "diff.first_value") + pipeline.connect("diff", "greet_one_last_time") + pipeline.connect("greet_one_last_time", "replicate") + pipeline.connect("replicate.first", "add_five.value") + pipeline.connect("replicate.second", "add_four.value") + pipeline.connect("add_four", "accumulate_3") + + pipeline.connect("parity.odd", "add_one.value") + pipeline.connect("add_one", "branch_joiner.value") + pipeline.connect("branch_joiner", "below_10") + + pipeline.connect("below_10.below", "double") + pipeline.connect("double", "branch_joiner.value") + + pipeline.connect("below_10.above", "accumulate_2") + pipeline.connect("accumulate_2", "diff.second_value") + + pipeline.connect("greet_enumerator", "enumerate") + pipeline.connect("enumerate.second", "sum.values") + + pipeline.connect("enumerate.first", "add_three.value") + pipeline.connect("add_three", "sum.values") + + return ( + pipeline, + [ + PipelineRunData( + inputs={"greet_first": {"value": 1}, "greet_enumerator": {"value": 1}}, + expected_outputs={ + "accumulate_3": {"value": -7}, + "add_five": {"result": -6}, + }, + expected_run_order=[ + "greet_first", + "greet_enumerator", + "accumulate_1", + "enumerate", + "add_two", + "add_three", + "parity", + "add_one", + "branch_joiner", + "below_10", + "double", + "branch_joiner", + "below_10", + "double", + "branch_joiner", + "below_10", + "accumulate_2", + "sum", + "diff", + "greet_one_last_time", + "replicate", + "add_five", + "add_four", + "accumulate_3", + ], + ) + ], + ) + + +@given( + "a pipeline that has a single component with a default input", + target_fixture="pipeline_data", +) +def pipeline_that_has_a_single_component_with_a_default_input(): + @component + class WithDefault: + @component.output_types(b=int) + def run(self, a: int, b: int = 2): + return {"c": a + b} + + pipeline = AsyncPipeline(max_runs_per_component=1) + pipeline.add_component("with_defaults", WithDefault()) + + return ( + pipeline, + [ + PipelineRunData( + inputs={"with_defaults": {"a": 40, "b": 30}}, + expected_outputs={"with_defaults": {"c": 70}}, + expected_run_order=["with_defaults"], + ), + PipelineRunData( + inputs={"with_defaults": {"a": 40}}, + expected_outputs={"with_defaults": {"c": 42}}, + expected_run_order=["with_defaults"], + ), + ], + ) + + +@given( + "a pipeline that has two loops of identical lengths", target_fixture="pipeline_data" +) +def pipeline_that_has_two_loops_of_identical_lengths(): + pipeline = AsyncPipeline(max_runs_per_component=10) + pipeline.add_component("branch_joiner", BranchJoiner(type_=int)) + pipeline.add_component("remainder", Remainder(divisor=3)) + pipeline.add_component("add_one", AddFixedValue(add=1)) + pipeline.add_component("add_two", AddFixedValue(add=2)) + + pipeline.connect("branch_joiner.value", "remainder.value") + pipeline.connect("remainder.remainder_is_1", "add_two.value") + pipeline.connect("remainder.remainder_is_2", "add_one.value") + pipeline.connect("add_two", "branch_joiner.value") + pipeline.connect("add_one", "branch_joiner.value") + 
return ( + pipeline, + [ + PipelineRunData( + inputs={"branch_joiner": {"value": 0}}, + expected_outputs={"remainder": {"remainder_is_0": 0}}, + expected_run_order=["branch_joiner", "remainder"], + ), + PipelineRunData( + inputs={"branch_joiner": {"value": 3}}, + expected_outputs={"remainder": {"remainder_is_0": 3}}, + expected_run_order=["branch_joiner", "remainder"], + ), + PipelineRunData( + inputs={"branch_joiner": {"value": 4}}, + expected_outputs={"remainder": {"remainder_is_0": 6}}, + expected_run_order=[ + "branch_joiner", + "remainder", + "add_two", + "branch_joiner", + "remainder", + ], + ), + PipelineRunData( + inputs={"branch_joiner": {"value": 5}}, + expected_outputs={"remainder": {"remainder_is_0": 6}}, + expected_run_order=[ + "branch_joiner", + "remainder", + "add_one", + "branch_joiner", + "remainder", + ], + ), + PipelineRunData( + inputs={"branch_joiner": {"value": 6}}, + expected_outputs={"remainder": {"remainder_is_0": 6}}, + expected_run_order=["branch_joiner", "remainder"], + ), + ], + ) + + +@given( + "a pipeline that has two loops of different lengths", target_fixture="pipeline_data" +) +def pipeline_that_has_two_loops_of_different_lengths(): + pipeline = AsyncPipeline(max_runs_per_component=10) + pipeline.add_component("branch_joiner", BranchJoiner(type_=int)) + pipeline.add_component("remainder", Remainder(divisor=3)) + pipeline.add_component("add_one", AddFixedValue(add=1)) + pipeline.add_component("add_two_1", AddFixedValue(add=1)) + pipeline.add_component("add_two_2", AddFixedValue(add=1)) + + pipeline.connect("branch_joiner.value", "remainder.value") + pipeline.connect("remainder.remainder_is_1", "add_two_1.value") + pipeline.connect("add_two_1", "add_two_2.value") + pipeline.connect("add_two_2", "branch_joiner") + pipeline.connect("remainder.remainder_is_2", "add_one.value") + pipeline.connect("add_one", "branch_joiner") + + return ( + pipeline, + [ + PipelineRunData( + inputs={"branch_joiner": {"value": 0}}, + expected_outputs={"remainder": {"remainder_is_0": 0}}, + expected_run_order=["branch_joiner", "remainder"], + ), + PipelineRunData( + inputs={"branch_joiner": {"value": 3}}, + expected_outputs={"remainder": {"remainder_is_0": 3}}, + expected_run_order=["branch_joiner", "remainder"], + ), + PipelineRunData( + inputs={"branch_joiner": {"value": 4}}, + expected_outputs={"remainder": {"remainder_is_0": 6}}, + expected_run_order=[ + "branch_joiner", + "remainder", + "add_two_1", + "add_two_2", + "branch_joiner", + "remainder", + ], + ), + PipelineRunData( + inputs={"branch_joiner": {"value": 5}}, + expected_outputs={"remainder": {"remainder_is_0": 6}}, + expected_run_order=[ + "branch_joiner", + "remainder", + "add_one", + "branch_joiner", + "remainder", + ], + ), + PipelineRunData( + inputs={"branch_joiner": {"value": 6}}, + expected_outputs={"remainder": {"remainder_is_0": 6}}, + expected_run_order=["branch_joiner", "remainder"], + ), + ], + ) + + +@given( + "a pipeline that has a single loop with two conditional branches", + target_fixture="pipeline_data", +) +def pipeline_that_has_a_single_loop_with_two_conditional_branches(): + accumulator = Accumulate() + pipeline = AsyncPipeline(max_runs_per_component=10) + + pipeline.add_component("add_one", AddFixedValue(add=1)) + pipeline.add_component("branch_joiner", BranchJoiner(type_=int)) + pipeline.add_component("below_10", Threshold(threshold=10)) + pipeline.add_component("below_5", Threshold(threshold=5)) + pipeline.add_component("add_three", AddFixedValue(add=3)) + pipeline.add_component("accumulator", 
accumulator) + pipeline.add_component("add_two", AddFixedValue(add=2)) + + pipeline.connect("add_one.result", "branch_joiner") + pipeline.connect("branch_joiner.value", "below_10.value") + pipeline.connect("below_10.below", "accumulator.value") + pipeline.connect("accumulator.value", "below_5.value") + pipeline.connect("below_5.above", "add_three.value") + pipeline.connect("below_5.below", "branch_joiner") + pipeline.connect("add_three.result", "branch_joiner") + pipeline.connect("below_10.above", "add_two.value") + + return ( + pipeline, + [ + PipelineRunData( + inputs={"add_one": {"value": 3}}, + expected_outputs={"add_two": {"result": 13}}, + expected_run_order=[ + "add_one", + "branch_joiner", + "below_10", + "accumulator", + "below_5", + "branch_joiner", + "below_10", + "accumulator", + "below_5", + "add_three", + "branch_joiner", + "below_10", + "add_two", + ], + ) + ], + ) + + +@given( + "a pipeline that has a component with dynamic inputs defined in init", + target_fixture="pipeline_data", +) +def pipeline_that_has_a_component_with_dynamic_inputs_defined_in_init(): + pipeline = AsyncPipeline(max_runs_per_component=1) + pipeline.add_component("hello", Hello()) + pipeline.add_component( + "fstring", + FString(template="This is the greeting: {greeting}!", variables=["greeting"]), + ) + pipeline.add_component("splitter", TextSplitter()) + pipeline.connect("hello.output", "fstring.greeting") + pipeline.connect("fstring.string", "splitter.sentence") + + return ( + pipeline, + [ + PipelineRunData( + inputs={"hello": {"word": "Alice"}}, + expected_outputs={ + "splitter": { + "output": [ + "This", + "is", + "the", + "greeting:", + "Hello,", + "Alice!!", + ] + } + }, + expected_run_order=["hello", "fstring", "splitter"], + ), + PipelineRunData( + inputs={ + "hello": {"word": "Alice"}, + "fstring": {"template": "Received: {greeting}"}, + }, + expected_outputs={ + "splitter": {"output": ["Received:", "Hello,", "Alice!"]} + }, + expected_run_order=["hello", "fstring", "splitter"], + ), + ], + ) + + +@given( + "a pipeline that has two branches that don't merge", target_fixture="pipeline_data" +) +def pipeline_that_has_two_branches_that_dont_merge(): + pipeline = AsyncPipeline(max_runs_per_component=1) + pipeline.add_component("add_one", AddFixedValue(add=1)) + pipeline.add_component("parity", Parity()) + pipeline.add_component("add_ten", AddFixedValue(add=10)) + pipeline.add_component("double", Double()) + pipeline.add_component("add_three", AddFixedValue(add=3)) + + pipeline.connect("add_one.result", "parity.value") + pipeline.connect("parity.even", "add_ten.value") + pipeline.connect("parity.odd", "double.value") + pipeline.connect("add_ten.result", "add_three.value") + + return ( + pipeline, + [ + PipelineRunData( + inputs={"add_one": {"value": 1}}, + expected_outputs={"add_three": {"result": 15}}, + expected_run_order=["add_one", "parity", "add_ten", "add_three"], + ), + PipelineRunData( + inputs={"add_one": {"value": 2}}, + expected_outputs={"double": {"value": 6}}, + expected_run_order=["add_one", "parity", "double"], + ), + ], + ) + + +@given( + "a pipeline that has three branches that don't merge", + target_fixture="pipeline_data", +) +def pipeline_that_has_three_branches_that_dont_merge(): + pipeline = AsyncPipeline(max_runs_per_component=1) + pipeline.add_component("add_one", AddFixedValue(add=1)) + pipeline.add_component("repeat", Repeat(outputs=["first", "second"])) + pipeline.add_component("add_ten", AddFixedValue(add=10)) + pipeline.add_component("double", Double()) + 
pipeline.add_component("add_three", AddFixedValue(add=3)) + pipeline.add_component("add_one_again", AddFixedValue(add=1)) + + pipeline.connect("add_one.result", "repeat.value") + pipeline.connect("repeat.first", "add_ten.value") + pipeline.connect("repeat.second", "double.value") + pipeline.connect("repeat.second", "add_three.value") + pipeline.connect("add_three.result", "add_one_again.value") + + return ( + pipeline, + [ + PipelineRunData( + inputs={"add_one": {"value": 1}}, + expected_outputs={ + "add_one_again": {"result": 6}, + "add_ten": {"result": 12}, + "double": {"value": 4}, + }, + expected_run_order=[ + "add_one", + "repeat", + "add_ten", + "double", + "add_three", + "add_one_again", + ], + ) + ], + ) + + +@given("a pipeline that has two branches that merge", target_fixture="pipeline_data") +def pipeline_that_has_two_branches_that_merge(): + pipeline = AsyncPipeline(max_runs_per_component=1) + pipeline.add_component("first_addition", AddFixedValue(add=2)) + pipeline.add_component("second_addition", AddFixedValue(add=2)) + pipeline.add_component("third_addition", AddFixedValue(add=2)) + pipeline.add_component("diff", Subtract()) + pipeline.add_component("fourth_addition", AddFixedValue(add=1)) + + pipeline.connect("first_addition.result", "second_addition.value") + pipeline.connect("second_addition.result", "diff.first_value") + pipeline.connect("third_addition.result", "diff.second_value") + pipeline.connect("diff", "fourth_addition.value") + return ( + pipeline, + [ + PipelineRunData( + inputs={"first_addition": {"value": 1}, "third_addition": {"value": 1}}, + expected_outputs={"fourth_addition": {"result": 3}}, + expected_run_order=[ + "first_addition", + "third_addition", + "second_addition", + "diff", + "fourth_addition", + ], + ) + ], + ) + + +@given( + "a pipeline that has different combinations of branches that merge and do not merge", + target_fixture="pipeline_data", +) +def pipeline_that_has_different_combinations_of_branches_that_merge_and_do_not_merge(): + pipeline = AsyncPipeline(max_runs_per_component=1) + pipeline.add_component("add_one", AddFixedValue()) + pipeline.add_component("parity", Parity()) + pipeline.add_component("add_ten", AddFixedValue(add=10)) + pipeline.add_component("double", Double()) + pipeline.add_component("add_four", AddFixedValue(add=4)) + pipeline.add_component("add_two", AddFixedValue()) + pipeline.add_component("add_two_as_well", AddFixedValue()) + pipeline.add_component("diff", Subtract()) + + pipeline.connect("add_one.result", "parity.value") + pipeline.connect("parity.even", "add_four.value") + pipeline.connect("parity.odd", "double.value") + pipeline.connect("add_ten.result", "diff.first_value") + pipeline.connect("double.value", "diff.second_value") + pipeline.connect("parity.odd", "add_ten.value") + pipeline.connect("add_four.result", "add_two.value") + pipeline.connect("add_four.result", "add_two_as_well.value") + + return ( + pipeline, + [ + PipelineRunData( + inputs={ + "add_one": {"value": 1}, + "add_two": {"add": 2}, + "add_two_as_well": {"add": 2}, + }, + expected_outputs={ + "add_two": {"result": 8}, + "add_two_as_well": {"result": 8}, + }, + expected_run_order=[ + "add_one", + "parity", + "add_four", + "add_two", + "add_two_as_well", + ], + ), + PipelineRunData( + inputs={ + "add_one": {"value": 2}, + "add_two": {"add": 2}, + "add_two_as_well": {"add": 2}, + }, + expected_outputs={"diff": {"difference": 7}}, + expected_run_order=["add_one", "parity", "double", "add_ten", "diff"], + ), + ], + ) + + +@given( + "a pipeline that 
has two branches, one of which loops back", + target_fixture="pipeline_data", +) +def pipeline_that_has_two_branches_one_of_which_loops_back(): + pipeline = AsyncPipeline(max_runs_per_component=10) + pipeline.add_component("add_zero", AddFixedValue(add=0)) + pipeline.add_component("branch_joiner", BranchJoiner(type_=int)) + pipeline.add_component("sum", Sum()) + pipeline.add_component("below_10", Threshold(threshold=10)) + pipeline.add_component("add_one", AddFixedValue(add=1)) + pipeline.add_component("counter", Accumulate()) + pipeline.add_component("add_two", AddFixedValue(add=2)) + + pipeline.connect("add_zero", "branch_joiner.value") + pipeline.connect("branch_joiner", "below_10.value") + pipeline.connect("below_10.below", "add_one.value") + pipeline.connect("add_one.result", "counter.value") + pipeline.connect("counter.value", "branch_joiner.value") + pipeline.connect("below_10.above", "add_two.value") + pipeline.connect("add_two.result", "sum.values") + + return ( + pipeline, + [ + PipelineRunData( + inputs={"add_zero": {"value": 8}, "sum": {"values": 2}}, + expected_outputs={"sum": {"total": 23}}, + expected_run_order=[ + "add_zero", + "branch_joiner", + "below_10", + "add_one", + "counter", + "branch_joiner", + "below_10", + "add_one", + "counter", + "branch_joiner", + "below_10", + "add_two", + "sum", + ], + ) + ], + ) + + +@given( + "a pipeline that has a component with mutable input", target_fixture="pipeline_data" +) +def pipeline_that_has_a_component_with_mutable_input(): + @component + class InputMangler: + @component.output_types(mangled_list=List[str]) + def run(self, input_list: List[str]): + input_list.append("extra_item") + return {"mangled_list": input_list} + + pipe = AsyncPipeline(max_runs_per_component=1) + pipe.add_component("mangler1", InputMangler()) + pipe.add_component("mangler2", InputMangler()) + pipe.add_component("concat1", StringListJoiner()) + pipe.add_component("concat2", StringListJoiner()) + pipe.connect("mangler1", "concat1") + pipe.connect("mangler2", "concat2") + + input_list = ["foo", "bar"] + + return ( + pipe, + [ + PipelineRunData( + inputs={ + "mangler1": {"input_list": input_list}, + "mangler2": {"input_list": input_list}, + }, + expected_outputs={ + "concat1": {"output": ["foo", "bar", "extra_item"]}, + "concat2": {"output": ["foo", "bar", "extra_item"]}, + }, + expected_run_order=["mangler1", "mangler2", "concat1", "concat2"], + ) + ], + ) + + +@given( + "a pipeline that has a component with mutable output sent to multiple inputs", + target_fixture="pipeline_data", +) +def pipeline_that_has_a_component_with_mutable_output_sent_to_multiple_inputs(): + @component + class PassThroughPromptBuilder: + # This is a pass-through component that returns the same input + @component.output_types(prompt=List[ChatMessage]) + def run(self, prompt_source: List[ChatMessage]): + return {"prompt": prompt_source} + + @component + class MessageMerger: + @component.output_types(merged_message=str) + def run(self, messages: List[ChatMessage], metadata: dict = None): + return {"merged_message": "\n".join(t.content for t in messages)} + + @component + class FakeGenerator: + # This component is a fake generator that always returns the same message + @component.output_types(replies=List[ChatMessage]) + def run(self, messages: List[ChatMessage]): + return {"replies": [ChatMessage.from_assistant("Fake message")]} + + prompt_builder = PassThroughPromptBuilder() + llm = FakeGenerator() + mm1 = MessageMerger() + mm2 = MessageMerger() + + pipe = 
AsyncPipeline(max_runs_per_component=1) + pipe.add_component("prompt_builder", prompt_builder) + pipe.add_component("llm", llm) + pipe.add_component("mm1", mm1) + pipe.add_component("mm2", mm2) + + pipe.connect("prompt_builder.prompt", "llm.messages") + pipe.connect("prompt_builder.prompt", "mm1") + pipe.connect("llm.replies", "mm2") + + messages = [ + ChatMessage.from_system( + "Always respond in English even if some input data is in other languages." + ), + ChatMessage.from_user("Tell me about Berlin"), + ] + params = {"metadata": {"metadata_key": "metadata_value", "meta2": "value2"}} + + return ( + pipe, + [ + PipelineRunData( + inputs={ + "mm1": params, + "mm2": params, + "prompt_builder": {"prompt_source": messages}, + }, + expected_outputs={ + "mm1": { + "merged_message": "Always respond " + "in English even " + "if some input " + "data is in other " + "languages.\n" + "Tell me about " + "Berlin" + }, + "mm2": {"merged_message": "Fake message"}, + }, + expected_run_order=["prompt_builder", "llm", "mm1", "mm2"], + ) + ], + ) + + +@given( + "a pipeline that has a greedy and variadic component after a component with default input", + target_fixture="pipeline_data", +) +def pipeline_that_has_a_greedy_and_variadic_component_after_a_component_with_default_input(): + """ + This test verifies that `Pipeline.run()` executes the components in the correct order when + there's a greedy Component with variadic input right before a Component with at least one default input. + + We use the `spying_tracer` fixture to simplify the code to verify the order of execution. + This creates some coupling between this test and how we trace the Pipeline execution. + A worthy tradeoff in my opinion, we will notice right away if we change either the run logic or + the tracing logic. 
+ """ + document_store = InMemoryDocumentStore() + document_store.write_documents([Document(content="This is a simple document")]) + + pipeline = AsyncPipeline(max_runs_per_component=1) + template = "Given this documents: {{ documents|join(', ', attribute='content') }} Answer this question: {{ query }}" + pipeline.add_component( + "retriever", InMemoryBM25Retriever(document_store=document_store) + ) + pipeline.add_component("prompt_builder", PromptBuilder(template=template)) + pipeline.add_component("branch_joiner", BranchJoiner(List[Document])) + + pipeline.connect("retriever", "branch_joiner") + pipeline.connect("branch_joiner", "prompt_builder.documents") + return ( + pipeline, + [ + PipelineRunData( + inputs={"query": "This is my question"}, + expected_outputs={ + "prompt_builder": { + "prompt": "Given this " + "documents: " + "This is a " + "simple " + "document " + "Answer this " + "question: " + "This is my " + "question" + } + }, + expected_run_order=["retriever", "branch_joiner", "prompt_builder"], + ) + ], + ) + + +@given( + "a pipeline that has a component that doesn't return a dictionary", + target_fixture="pipeline_data", +) +def pipeline_that_has_a_component_that_doesnt_return_a_dictionary(): + BrokenComponent = component_class( + "BrokenComponent", + input_types={"a": int}, + output_types={"b": int}, + output=1, # type:ignore + ) + + pipe = AsyncPipeline(max_runs_per_component=10) + pipe.add_component("comp", BrokenComponent()) + return pipe, [PipelineRunData({"comp": {"a": 1}})] + + +@given( + "a pipeline that has components added in a different order from the order of execution", + target_fixture="pipeline_data", +) +def pipeline_that_has_components_added_in_a_different_order_from_the_order_of_execution(): + """ + We enqueue the Components in internal `to_run` data structure at the start of `Pipeline.run()` using the order + they are added in the Pipeline with `Pipeline.add_component()`. + If a Component A with defaults is added before a Component B that has no defaults, but in the Pipeline + logic A must be executed after B it could run instead before. + + This test verifies that the order of execution is correct. + """ + docs = [ + Document(content="Rome is the capital of Italy"), + Document(content="Paris is the capital of France"), + ] + doc_store = InMemoryDocumentStore() + doc_store.write_documents(docs) + template = ( + "Given the following information, answer the question.\n" + "Context:\n" + "{% for document in documents %}" + " {{ document.content }}\n" + "{% endfor %}" + "Question: {{ query }}" + ) + + pipe = AsyncPipeline(max_runs_per_component=1) + + # The order of this addition is important for the test + # Do not edit them. + pipe.add_component("prompt_builder", PromptBuilder(template=template)) + pipe.add_component("retriever", InMemoryBM25Retriever(document_store=doc_store)) + pipe.connect("retriever", "prompt_builder.documents") + + query = "What is the capital of France?" + return ( + pipe, + [ + PipelineRunData( + inputs={ + "prompt_builder": {"query": query}, + "retriever": {"query": query}, + }, + expected_outputs={ + "prompt_builder": { + "prompt": "Given the " + "following " + "information, " + "answer the " + "question.\n" + "Context:\n" + " Paris is " + "the capital " + "of France\n" + " Rome is " + "the capital " + "of Italy\n" + "Question: " + "What is the " + "capital of " + "France?" 
+ } + }, + expected_run_order=["retriever", "prompt_builder"], + ) + ], + ) + + +@given( + "a pipeline that has a component with only default inputs", + target_fixture="pipeline_data", +) +def pipeline_that_has_a_component_with_only_default_inputs(): + FakeGenerator = component_class( + "FakeGenerator", + input_types={"prompt": str}, + output_types={"replies": List[str]}, + output={"replies": ["Paris"]}, + ) + docs = [ + Document(content="Rome is the capital of Italy"), + Document(content="Paris is the capital of France"), + ] + doc_store = InMemoryDocumentStore() + doc_store.write_documents(docs) + template = ( + "Given the following information, answer the question.\n" + "Context:\n" + "{% for document in documents %}" + " {{ document.content }}\n" + "{% endfor %}" + "Question: {{ query }}" + ) + + pipe = AsyncPipeline(max_runs_per_component=1) + + pipe.add_component("retriever", InMemoryBM25Retriever(document_store=doc_store)) + pipe.add_component("prompt_builder", PromptBuilder(template=template)) + pipe.add_component("generator", FakeGenerator()) + pipe.add_component("answer_builder", AnswerBuilder()) + + pipe.connect("retriever", "prompt_builder.documents") + pipe.connect("prompt_builder.prompt", "generator.prompt") + pipe.connect("generator.replies", "answer_builder.replies") + pipe.connect("retriever.documents", "answer_builder.documents") + + return ( + pipe, + [ + PipelineRunData( + inputs={"query": "What is the capital of France?"}, + expected_outputs={ + "answer_builder": { + "answers": [ + GeneratedAnswer( + data="Paris", + query="What " "is " "the " "capital " "of " "France?", + documents=[ + Document( + id="413dccdf51a54cca75b7ed2eddac04e6e58560bd2f0caf4106a3efc023fe3651", + content="Paris is the capital of France", + score=1.600237583702734, + ), + Document( + id="a4a874fc2ef75015da7924d709fbdd2430e46a8e94add6e0f26cd32c1c03435d", + content="Rome is the capital of Italy", + score=1.2536639934227616, + ), + ], + meta={}, + ) + ] + } + }, + expected_run_order=[ + "retriever", + "prompt_builder", + "generator", + "answer_builder", + ], + ) + ], + ) + + +@given( + "a pipeline that has a component with only default inputs as first to run and receives inputs from a loop", + target_fixture="pipeline_data", +) +def pipeline_that_has_a_component_with_only_default_inputs_as_first_to_run_and_receives_inputs_from_a_loop(): + """ + This tests verifies that a Pipeline doesn't get stuck running in a loop if + it has all the following characterics: + - The first Component has all defaults for its inputs + - The first Component receives one input from the user + - The first Component receives one input from a loop in the Pipeline + - The second Component has at least one default input + """ + + def fake_generator_run( + self, generation_kwargs: Optional[Dict[str, Any]] = None, **kwargs + ): + # Simple hack to simulate a model returning a different reply after the + # the first time it's called + if getattr(fake_generator_run, "called", False): + return {"replies": ["Rome"]} + fake_generator_run.called = True + return {"replies": ["Paris"]} + + FakeGenerator = component_class( + "FakeGenerator", + input_types={"prompt": str}, + output_types={"replies": List[str]}, + extra_fields={"run": fake_generator_run}, + ) + template = ( + "Answer the following question.\n" + "{% if previous_replies %}\n" + "Previously you replied incorrectly this:\n" + "{% for reply in previous_replies %}\n" + " - {{ reply }}\n" + "{% endfor %}\n" + "{% endif %}\n" + "Question: {{ query }}" + ) + router = 
ConditionalRouter( + routes=[ + { + "condition": "{{ replies == ['Rome'] }}", + "output": "{{ replies }}", + "output_name": "correct_replies", + "output_type": List[int], + }, + { + "condition": "{{ replies == ['Paris'] }}", + "output": "{{ replies }}", + "output_name": "incorrect_replies", + "output_type": List[int], + }, + ] + ) + + pipe = AsyncPipeline(max_runs_per_component=1) + + pipe.add_component("prompt_builder", PromptBuilder(template=template)) + pipe.add_component("generator", FakeGenerator()) + pipe.add_component("router", router) + + pipe.connect("prompt_builder.prompt", "generator.prompt") + pipe.connect("generator.replies", "router.replies") + pipe.connect("router.incorrect_replies", "prompt_builder.previous_replies") + + return ( + pipe, + [ + PipelineRunData( + inputs={ + "prompt_builder": {"query": "What is the capital of " "Italy?"} + }, + expected_outputs={"router": {"correct_replies": ["Rome"]}}, + expected_run_order=[ + "prompt_builder", + "generator", + "router", + "prompt_builder", + "generator", + "router", + ], + ) + ], + ) + + +@given( + "a pipeline that has multiple branches that merge into a component with a single variadic input", + target_fixture="pipeline_data", +) +def pipeline_that_has_multiple_branches_that_merge_into_a_component_with_a_single_variadic_input(): + pipeline = AsyncPipeline(max_runs_per_component=1) + pipeline.add_component("add_one", AddFixedValue()) + pipeline.add_component("parity", Remainder(divisor=2)) + pipeline.add_component("add_ten", AddFixedValue(add=10)) + pipeline.add_component("double", Double()) + pipeline.add_component("add_four", AddFixedValue(add=4)) + pipeline.add_component("add_one_again", AddFixedValue()) + pipeline.add_component("sum", Sum()) + + pipeline.connect("add_one.result", "parity.value") + pipeline.connect("parity.remainder_is_0", "add_ten.value") + pipeline.connect("parity.remainder_is_1", "double.value") + pipeline.connect("add_one.result", "sum.values") + pipeline.connect("add_ten.result", "sum.values") + pipeline.connect("double.value", "sum.values") + pipeline.connect("parity.remainder_is_1", "add_four.value") + pipeline.connect("add_four.result", "add_one_again.value") + pipeline.connect("add_one_again.result", "sum.values") + + return ( + pipeline, + [ + PipelineRunData( + inputs={"add_one": {"value": 1}}, + expected_outputs={"sum": {"total": 14}}, + expected_run_order=["add_one", "parity", "add_ten", "sum"], + ), + PipelineRunData( + inputs={"add_one": {"value": 2}}, + expected_outputs={"sum": {"total": 17}}, + expected_run_order=[ + "add_one", + "parity", + "double", + "add_four", + "add_one_again", + "sum", + ], + ), + ], + ) + + +@given( + "a pipeline that has multiple branches of different lengths that merge into a component with a single variadic input", + target_fixture="pipeline_data", +) +def pipeline_that_has_multiple_branches_of_different_lengths_that_merge_into_a_component_with_a_single_variadic_input(): + pipeline = AsyncPipeline(max_runs_per_component=1) + pipeline.add_component("first_addition", AddFixedValue(add=2)) + pipeline.add_component("second_addition", AddFixedValue(add=2)) + pipeline.add_component("third_addition", AddFixedValue(add=2)) + pipeline.add_component("sum", Sum()) + pipeline.add_component("fourth_addition", AddFixedValue(add=1)) + + pipeline.connect("first_addition.result", "second_addition.value") + pipeline.connect("first_addition.result", "sum.values") + pipeline.connect("second_addition.result", "sum.values") + pipeline.connect("third_addition.result", "sum.values") 
+ pipeline.connect("sum.total", "fourth_addition.value") + + return ( + pipeline, + [ + PipelineRunData( + inputs={"first_addition": {"value": 1}, "third_addition": {"value": 1}}, + expected_outputs={"fourth_addition": {"result": 12}}, + expected_run_order=[ + "first_addition", + "third_addition", + "second_addition", + "sum", + "fourth_addition", + ], + ) + ], + ) + + +@given( + "a pipeline that is linear and returns intermediate outputs", + target_fixture="pipeline_data", +) +def pipeline_that_is_linear_and_returns_intermediate_outputs(): + pipeline = AsyncPipeline(max_runs_per_component=1) + pipeline.add_component("first_addition", AddFixedValue(add=2)) + pipeline.add_component("second_addition", AddFixedValue()) + pipeline.add_component("double", Double()) + pipeline.connect("first_addition", "double") + pipeline.connect("double", "second_addition") + + return ( + pipeline, + [ + PipelineRunData( + inputs={"first_addition": {"value": 1}}, + include_outputs_from={"second_addition", "double", "first_addition"}, + expected_outputs={ + "double": {"value": 6}, + "first_addition": {"result": 3}, + "second_addition": {"result": 7}, + }, + expected_run_order=["first_addition", "double", "second_addition"], + ), + PipelineRunData( + inputs={"first_addition": {"value": 1}}, + include_outputs_from={"double"}, + expected_outputs={ + "double": {"value": 6}, + "second_addition": {"result": 7}, + }, + expected_run_order=["first_addition", "double", "second_addition"], + ), + ], + ) + + +@given( + "a pipeline that has a loop and returns intermediate outputs from it", + target_fixture="pipeline_data", +) +def pipeline_that_has_a_loop_and_returns_intermediate_outputs_from_it(): + pipeline = AsyncPipeline(max_runs_per_component=10) + pipeline.add_component("add_one", AddFixedValue(add=1)) + pipeline.add_component("branch_joiner", BranchJoiner(type_=int)) + pipeline.add_component("below_10", Threshold(threshold=10)) + pipeline.add_component("below_5", Threshold(threshold=5)) + pipeline.add_component("add_three", AddFixedValue(add=3)) + pipeline.add_component("accumulator", Accumulate()) + pipeline.add_component("add_two", AddFixedValue(add=2)) + + pipeline.connect("add_one.result", "branch_joiner") + pipeline.connect("branch_joiner.value", "below_10.value") + pipeline.connect("below_10.below", "accumulator.value") + pipeline.connect("accumulator.value", "below_5.value") + pipeline.connect("below_5.above", "add_three.value") + pipeline.connect("below_5.below", "branch_joiner") + pipeline.connect("add_three.result", "branch_joiner") + pipeline.connect("below_10.above", "add_two.value") + + return ( + pipeline, + [ + PipelineRunData( + inputs={"add_one": {"value": 3}}, + include_outputs_from={ + "add_two", + "add_one", + "branch_joiner", + "below_10", + "accumulator", + "below_5", + "add_three", + }, + expected_outputs={ + "add_two": {"result": 13}, + "add_one": {"result": 4}, + "branch_joiner": {"value": 11}, + "below_10": {"above": 11}, + "accumulator": {"value": 8}, + "below_5": {"above": 8}, + "add_three": {"result": 11}, + }, + expected_run_order=[ + "add_one", + "branch_joiner", + "below_10", + "accumulator", + "below_5", + "branch_joiner", + "below_10", + "accumulator", + "below_5", + "add_three", + "branch_joiner", + "below_10", + "add_two", + ], + ) + ], + ) + + +@given( + "a pipeline that is linear and returns intermediate outputs from multiple sockets", + target_fixture="pipeline_data", +) +def pipeline_that_is_linear_and_returns_intermediate_outputs_from_multiple_sockets(): + @component + class 
DoubleWithOriginal: + """ + Doubles the input value and returns the original value as well. + """ + + @component.output_types(value=int, original=int) + def run(self, value: int): + return {"value": value * 2, "original": value} + + pipeline = AsyncPipeline(max_runs_per_component=1) + pipeline.add_component("first_addition", AddFixedValue(add=2)) + pipeline.add_component("second_addition", AddFixedValue()) + pipeline.add_component("double", DoubleWithOriginal()) + pipeline.connect("first_addition", "double") + pipeline.connect("double.value", "second_addition") + + return ( + pipeline, + [ + PipelineRunData( + inputs={"first_addition": {"value": 1}}, + include_outputs_from={"second_addition", "double", "first_addition"}, + expected_outputs={ + "double": {"original": 3, "value": 6}, + "first_addition": {"result": 3}, + "second_addition": {"result": 7}, + }, + expected_run_order=["first_addition", "double", "second_addition"], + ), + PipelineRunData( + inputs={"first_addition": {"value": 1}}, + include_outputs_from={"double"}, + expected_outputs={ + "double": {"original": 3, "value": 6}, + "second_addition": {"result": 7}, + }, + expected_run_order=["first_addition", "double", "second_addition"], + ), + ], + ) + + +@given( + "a pipeline that has a component with default inputs that doesn't receive anything from its sender", + target_fixture="pipeline_data", +) +def pipeline_that_has_a_component_with_default_inputs_that_doesnt_receive_anything_from_its_sender(): + routes = [ + { + "condition": "{{'reisen' in sentence}}", + "output": "German", + "output_name": "language_1", + "output_type": str, + }, + { + "condition": "{{'viajar' in sentence}}", + "output": "Spanish", + "output_name": "language_2", + "output_type": str, + }, + ] + router = ConditionalRouter(routes) + + pipeline = AsyncPipeline(max_runs_per_component=1) + pipeline.add_component("router", router) + pipeline.add_component( + "pb", PromptBuilder(template="Ok, I know, that's {{language}}") + ) + pipeline.connect("router.language_2", "pb.language") + + return ( + pipeline, + [ + PipelineRunData( + inputs={"router": {"sentence": "Wir mussen reisen"}}, + expected_outputs={"router": {"language_1": "German"}}, + expected_run_order=["router"], + ), + PipelineRunData( + inputs={"router": {"sentence": "Yo tengo que viajar"}}, + expected_outputs={"pb": {"prompt": "Ok, I know, that's Spanish"}}, + expected_run_order=["router", "pb"], + ), + ], + ) + + +@given( + "a pipeline that has a component with default inputs that doesn't receive anything from its sender but receives input from user", + target_fixture="pipeline_data", +) +def pipeline_that_has_a_component_with_default_inputs_that_doesnt_receive_anything_from_its_sender_but_receives_input_from_user(): + prompt = PromptBuilder( + template="""Please generate an SQL query. 
The query should answer the following Question: {{ question }}; + If the question cannot be answered given the provided table and columns, return 'no_answer' + The query is to be answered for the table is called 'absenteeism' with the following + Columns: {{ columns }}; + Answer:""" + ) + + @component + class FakeGenerator: + @component.output_types(replies=List[str]) + def run(self, prompt: str): + if "no_answer" in prompt: + return {"replies": ["There's simply no_answer to this question"]} + return {"replies": ["Some SQL query"]} + + @component + class FakeSQLQuerier: + @component.output_types(results=str) + def run(self, query: str): + return {"results": "This is the query result", "query": query} + + llm = FakeGenerator() + sql_querier = FakeSQLQuerier() + + routes = [ + { + "condition": "{{'no_answer' not in replies[0]}}", + "output": "{{replies[0]}}", + "output_name": "sql", + "output_type": str, + }, + { + "condition": "{{'no_answer' in replies[0]}}", + "output": "{{question}}", + "output_name": "go_to_fallback", + "output_type": str, + }, + ] + + router = ConditionalRouter(routes) + + fallback_prompt = PromptBuilder( + template="""User entered a query that cannot be answered with the given table. + The query was: {{ question }} and the table had columns: {{ columns }}. + Let the user know why the question cannot be answered""" + ) + fallback_llm = FakeGenerator() + + pipeline = AsyncPipeline(max_runs_per_component=1) + pipeline.add_component("prompt", prompt) + pipeline.add_component("llm", llm) + pipeline.add_component("router", router) + pipeline.add_component("fallback_prompt", fallback_prompt) + pipeline.add_component("fallback_llm", fallback_llm) + pipeline.add_component("sql_querier", sql_querier) + + pipeline.connect("prompt", "llm") + pipeline.connect("llm.replies", "router.replies") + pipeline.connect("router.sql", "sql_querier.query") + pipeline.connect("router.go_to_fallback", "fallback_prompt.question") + pipeline.connect("fallback_prompt", "fallback_llm") + + columns = "Age, Absenteeism_time_in_hours, Days, Disciplinary_failure" + return ( + pipeline, + [ + PipelineRunData( + inputs={ + "prompt": { + "question": "This is a question with no_answer", + "columns": columns, + }, + "router": {"question": "This is a question with no_answer"}, + }, + expected_outputs={ + "fallback_llm": { + "replies": ["There's simply no_answer to this question"] + } + }, + expected_run_order=[ + "prompt", + "llm", + "router", + "fallback_prompt", + "fallback_llm", + ], + ) + ], + [ + PipelineRunData( + inputs={ + "prompt": { + "question": "This is a question that has an answer", + "columns": columns, + }, + "router": {"question": "This is a question that has an answer"}, + }, + expected_outputs={ + "sql_querier": { + "results": "This is the query result", + "query": "Some SQL query", + } + }, + expected_run_order=["prompt", "llm", "router", "sql_querier"], + ) + ], + ) + + +@given( + "a pipeline that has a loop and a component with default inputs that doesn't receive anything from its sender but receives input from user", + target_fixture="pipeline_data", +) +def pipeline_that_has_a_loop_and_a_component_with_default_inputs_that_doesnt_receive_anything_from_its_sender_but_receives_input_from_user(): + template = """ + You are an experienced and accurate Turkish CX speacialist that classifies customer comments into pre-defined categories below:\n + Negative experience labels: + - Late delivery + - Rotten/spoilt item + - Bad Courier behavior + + Positive experience labels: + - Good courier 
behavior + - Thanks & appreciation + - Love message to courier + - Fast delivery + - Quality of products + + Create a JSON object as a response. The fields are: 'positive_experience', 'negative_experience'. + Assign at least one of the pre-defined labels to the given customer comment under positive and negative experience fields. + If the comment has a positive experience, list the label under 'positive_experience' field. + If the comments has a negative_experience, list it under the 'negative_experience' field. + Here is the comment:\n{{ comment }}\n. Just return the category names in the list. If there aren't any, return an empty list. + + {% if invalid_replies and error_message %} + You already created the following output in a previous attempt: {{ invalid_replies }} + However, this doesn't comply with the format requirements from above and triggered this Python exception: {{ error_message }} + Correct the output and try again. Just return the corrected output without any extra explanations. + {% endif %} + """ + prompt_builder = PromptBuilder(template=template) + + @component + class FakeOutputValidator: + @component.output_types( + valid_replies=List[str], + invalid_replies=Optional[List[str]], + error_message=Optional[str], + ) + def run(self, replies: List[str]): + if not getattr(self, "called", False): + self.called = True + return { + "invalid_replies": ["This is an invalid reply"], + "error_message": "this is an error message", + } + return {"valid_replies": replies} + + @component + class FakeGenerator: + @component.output_types(replies=List[str]) + def run(self, prompt: str): + return {"replies": ["This is a valid reply"]} + + llm = FakeGenerator() + validator = FakeOutputValidator() + + pipeline = AsyncPipeline(max_runs_per_component=1) + pipeline.add_component("prompt_builder", prompt_builder) + + pipeline.add_component("llm", llm) + pipeline.add_component("output_validator", validator) + + pipeline.connect("prompt_builder.prompt", "llm.prompt") + pipeline.connect("llm.replies", "output_validator.replies") + pipeline.connect( + "output_validator.invalid_replies", "prompt_builder.invalid_replies" + ) + + pipeline.connect("output_validator.error_message", "prompt_builder.error_message") + + comment = "I loved the quality of the meal but the courier was rude" + return ( + pipeline, + [ + PipelineRunData( + inputs={"prompt_builder": {"template_variables": {"comment": comment}}}, + expected_outputs={ + "output_validator": {"valid_replies": ["This is a valid reply"]} + }, + expected_run_order=[ + "prompt_builder", + "llm", + "output_validator", + "prompt_builder", + "llm", + "output_validator", + ], + ) + ], + ) + + +@given( + "a pipeline that has multiple components with only default inputs and are added in a different order from the order of execution", + target_fixture="pipeline_data", +) +def pipeline_that_has_multiple_components_with_only_default_inputs_and_are_added_in_a_different_order_from_the_order_of_execution(): + prompt_builder1 = PromptBuilder( + template=""" + You are a spellchecking system. Check the given query and fill in the corrected query. 
+ + Question: {{question}} + Corrected question: + """ + ) + prompt_builder2 = PromptBuilder( + template=""" + According to these documents: + + {% for doc in documents %} + {{ doc.content }} + {% endfor %} + + Answer the given question: {{question}} + Answer: + """ + ) + prompt_builder3 = PromptBuilder( + template=""" + {% for ans in replies %} + {{ ans }} + {% endfor %} + """ + ) + + @component + class FakeRetriever: + @component.output_types(documents=List[Document]) + def run( + self, + query: str, + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + scale_score: Optional[bool] = None, + ): + return {"documents": [Document(content="This is a document")]} + + @component + class FakeRanker: + @component.output_types(documents=List[Document]) + def run( + self, + query: str, + documents: List[Document], + top_k: Optional[int] = None, + scale_score: Optional[bool] = None, + calibration_factor: Optional[float] = None, + score_threshold: Optional[float] = None, + ): + return {"documents": documents} + + @component + class FakeGenerator: + @component.output_types(replies=List[str], meta=Dict[str, Any]) + def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): + return {"replies": ["This is a reply"], "meta": {"meta_key": "meta_value"}} + + pipeline = AsyncPipeline(max_runs_per_component=1) + pipeline.add_component(name="retriever", instance=FakeRetriever()) + pipeline.add_component(name="ranker", instance=FakeRanker()) + pipeline.add_component(name="prompt_builder2", instance=prompt_builder2) + pipeline.add_component(name="prompt_builder1", instance=prompt_builder1) + pipeline.add_component(name="prompt_builder3", instance=prompt_builder3) + pipeline.add_component(name="llm", instance=FakeGenerator()) + pipeline.add_component(name="spellchecker", instance=FakeGenerator()) + + pipeline.connect("prompt_builder1", "spellchecker") + pipeline.connect("spellchecker.replies", "prompt_builder3") + pipeline.connect("prompt_builder3", "retriever.query") + pipeline.connect("prompt_builder3", "ranker.query") + pipeline.connect("retriever.documents", "ranker.documents") + pipeline.connect("ranker.documents", "prompt_builder2.documents") + pipeline.connect("prompt_builder3", "prompt_builder2.question") + pipeline.connect("prompt_builder2", "llm") + + return ( + pipeline, + [ + PipelineRunData( + inputs={"prompt_builder1": {"question": "Wha i Acromegaly?"}}, + expected_outputs={ + "llm": { + "replies": ["This is a reply"], + "meta": {"meta_key": "meta_value"}, + }, + "spellchecker": {"meta": {"meta_key": "meta_value"}}, + }, + expected_run_order=[ + "prompt_builder1", + "spellchecker", + "prompt_builder3", + "retriever", + "ranker", + "prompt_builder2", + "llm", + ], + ) + ], + ) + + +@given( + "a pipeline that is linear with conditional branching and multiple joins", + target_fixture="pipeline_data", +) +def that_is_linear_with_conditional_branching_and_multiple_joins(): + pipeline = AsyncPipeline() + + @component + class FakeRouter: + @component.output_types(LEGIT=str, INJECTION=str) + def run(self, query: str): + if "injection" in query: + return {"INJECTION": query} + return {"LEGIT": query} + + @component + class FakeEmbedder: + @component.output_types(embeddings=List[float]) + def run(self, text: str): + return {"embeddings": [1.0, 2.0, 3.0]} + + @component + class FakeRanker: + @component.output_types(documents=List[Document]) + def run(self, query: str, documents: List[Document]): + return {"documents": documents} + + @component + class FakeRetriever: 
+ @component.output_types(documents=List[Document]) + def run(self, query: str): + if "injection" in query: + return {"documents": []} + return {"documents": [Document(content="This is a document")]} + + @component + class FakeEmbeddingRetriever: + @component.output_types(documents=List[Document]) + def run(self, query_embedding: List[float]): + return {"documents": [Document(content="This is another document")]} + + pipeline.add_component(name="router", instance=FakeRouter()) + pipeline.add_component(name="text_embedder", instance=FakeEmbedder()) + pipeline.add_component(name="retriever", instance=FakeEmbeddingRetriever()) + pipeline.add_component(name="emptyretriever", instance=FakeRetriever()) + pipeline.add_component(name="joinerfinal", instance=DocumentJoiner()) + pipeline.add_component(name="joinerhybrid", instance=DocumentJoiner()) + pipeline.add_component(name="ranker", instance=FakeRanker()) + pipeline.add_component(name="bm25retriever", instance=FakeRetriever()) + + pipeline.connect("router.INJECTION", "emptyretriever.query") + pipeline.connect("router.LEGIT", "text_embedder.text") + pipeline.connect("text_embedder", "retriever.query_embedding") + pipeline.connect("router.LEGIT", "ranker.query") + pipeline.connect("router.LEGIT", "bm25retriever.query") + pipeline.connect("bm25retriever", "joinerhybrid.documents") + pipeline.connect("retriever", "joinerhybrid.documents") + pipeline.connect("joinerhybrid.documents", "ranker.documents") + pipeline.connect("ranker", "joinerfinal.documents") + pipeline.connect("emptyretriever", "joinerfinal.documents") + + return ( + pipeline, + [ + PipelineRunData( + inputs={"router": {"query": "I'm a legit question"}}, + expected_outputs={ + "joinerfinal": { + "documents": [ + Document(content="This is a document"), + Document(content="This is another document"), + ] + } + }, + expected_run_order=[ + "router", + "text_embedder", + "bm25retriever", + "retriever", + "joinerhybrid", + "ranker", + "joinerfinal", + ], + ), + PipelineRunData( + inputs={"router": {"query": "I'm a nasty prompt injection"}}, + expected_outputs={"joinerfinal": {"documents": []}}, + expected_run_order=["router", "emptyretriever", "joinerfinal"], + ), + ], + ) + + +@given("a pipeline that is a simple agent", target_fixture="pipeline_data") +def that_is_a_simple_agent(): + search_message_template = """ + Given these web search results: + + {% for doc in documents %} + {{ doc.content }} + {% endfor %} + + Be as brief as possible, max one sentence. + Answer the question: {{search_query}} + """ + + react_message_template = """ + Solve a question answering task with interleaving Thought, Action, Observation steps. + + Thought reasons about the current situation + + Action can be: + google_search - Searches Google for the exact concept/entity (given in square brackets) and returns the results for you to use + finish - Returns the final answer (given in square brackets) and finishes the task + + Observation summarizes the Action outcome and helps in formulating the next + Thought in Thought, Action, Observation interleaving triplet of steps. + + After each Observation, provide the next Thought and next Action. + Don't execute multiple steps even though you know the answer. + Only generate Thought and Action, never Observation, you'll get Observation from Action. + Follow the pattern in the example below. + + Example: + ########################### + Question: Which magazine was started first Arthur’s Magazine or First for Women? 
+ Thought: I need to search Arthur’s Magazine and First for Women, and find which was started + first. + Action: google_search[When was 'Arthur’s Magazine' started?] + Observation: Arthur’s Magazine was an American literary periodical + published in Philadelphia and founded in 1844. Edited by Timothy Shay Arthur, it featured work by + Edgar A. Poe, J.H. Ingraham, Sarah Josepha Hale, Thomas G. Spear, and others. In May 1846 + it was merged into Godey’s Lady’s Book. + Thought: Arthur’s Magazine was started in 1844. I need to search First for Women founding date next + Action: google_search[When was 'First for Women' magazine started?] + Observation: First for Women is a woman’s magazine published by Bauer Media Group in the + USA. The magazine was started in 1989. It is based in Englewood Cliffs, New Jersey. In 2011 + the circulation of the magazine was 1,310,696 copies. + Thought: First for Women was started in 1989. 1844 (Arthur’s Magazine) < 1989 (First for + Women), so Arthur’s Magazine was started first. + Action: finish[Arthur’s Magazine] + ############################ + + Let's start, the question is: {{query}} + + Thought: + """ + + routes = [ + { + "condition": "{{'search' in tool_id_and_param[0]}}", + "output": "{{tool_id_and_param[1]}}", + "output_name": "search", + "output_type": str, + }, + { + "condition": "{{'finish' in tool_id_and_param[0]}}", + "output": "{{tool_id_and_param[1]}}", + "output_name": "finish", + "output_type": str, + }, + ] + + @component + class FakeThoughtActionOpenAIChatGenerator: + run_counter = 0 + + @component.output_types(replies=List[ChatMessage]) + def run( + self, + messages: List[ChatMessage], + generation_kwargs: Optional[Dict[str, Any]] = None, + ): + if self.run_counter == 0: + self.run_counter += 1 + return { + "replies": [ + ChatMessage.from_assistant( + "thinking\n Action: google_search[What is taller, Eiffel Tower or Leaning Tower of Pisa]\n" + ) + ] + } + + return { + "replies": [ + ChatMessage.from_assistant( + "thinking\n Action: finish[Eiffel Tower]\n" + ) + ] + } + + @component + class FakeConclusionOpenAIChatGenerator: + @component.output_types(replies=List[ChatMessage]) + def run( + self, + messages: List[ChatMessage], + generation_kwargs: Optional[Dict[str, Any]] = None, + ): + return { + "replies": [ + ChatMessage.from_assistant("Tower of Pisa is 55 meters tall\n") + ] + } + + @component + class FakeSerperDevWebSearch: + @component.output_types(documents=List[Document]) + def run(self, query: str): + return { + "documents": [ + Document(content="Eiffel Tower is 300 meters tall"), + Document(content="Tower of Pisa is 55 meters tall"), + ] + } + + # main part + pipeline = AsyncPipeline() + pipeline.add_component("main_input", BranchJoiner(List[ChatMessage])) + pipeline.add_component("prompt_builder", ChatPromptBuilder(variables=["query"])) + pipeline.add_component("llm", FakeThoughtActionOpenAIChatGenerator()) + + @component + class ToolExtractor: + @component.output_types(output=List[str]) + def run(self, messages: List[ChatMessage]): + prompt: str = messages[-1].content + lines = prompt.strip().split("\n") + for line in reversed(lines): + pattern = r"Action:\s*(\w+)\[(.*?)\]" + + match = re.search(pattern, line) + if match: + action_name = match.group(1) + parameter = match.group(2) + return {"output": [action_name, parameter]} + return {"output": [None, None]} + + pipeline.add_component("tool_extractor", ToolExtractor()) + + @component + class PromptConcatenator: + def __init__(self, suffix: str = ""): + self._suffix = suffix + +
@component.output_types(output=List[ChatMessage]) + def run(self, replies: List[ChatMessage], current_prompt: List[ChatMessage]): + content = current_prompt[-1].content + replies[-1].content + self._suffix + return {"output": [ChatMessage.from_user(content)]} + + @component + class SearchOutputAdapter: + @component.output_types(output=List[ChatMessage]) + def run(self, replies: List[ChatMessage]): + content = f"Observation: {replies[-1].content}\n" + return {"output": [ChatMessage.from_assistant(content)]} + + pipeline.add_component("prompt_concatenator_after_action", PromptConcatenator()) + + pipeline.add_component("router", ConditionalRouter(routes)) + pipeline.add_component("router_search", FakeSerperDevWebSearch()) + pipeline.add_component( + "search_prompt_builder", + ChatPromptBuilder(variables=["documents", "search_query"]), + ) + pipeline.add_component("search_llm", FakeConclusionOpenAIChatGenerator()) + + pipeline.add_component("search_output_adapter", SearchOutputAdapter()) + pipeline.add_component( + "prompt_concatenator_after_observation", + PromptConcatenator(suffix="\nThought: "), + ) + + # main + pipeline.connect("main_input", "prompt_builder.template") + pipeline.connect("prompt_builder.prompt", "llm.messages") + pipeline.connect("llm.replies", "prompt_concatenator_after_action.replies") + + # tools + pipeline.connect( + "prompt_builder.prompt", "prompt_concatenator_after_action.current_prompt" + ) + pipeline.connect("prompt_concatenator_after_action", "tool_extractor.messages") + + pipeline.connect("tool_extractor", "router") + pipeline.connect("router.search", "router_search.query") + pipeline.connect("router_search.documents", "search_prompt_builder.documents") + pipeline.connect("router.search", "search_prompt_builder.search_query") + pipeline.connect("search_prompt_builder.prompt", "search_llm.messages") + + pipeline.connect("search_llm.replies", "search_output_adapter.replies") + pipeline.connect( + "search_output_adapter", "prompt_concatenator_after_observation.replies" + ) + pipeline.connect( + "prompt_concatenator_after_action", + "prompt_concatenator_after_observation.current_prompt", + ) + pipeline.connect("prompt_concatenator_after_observation", "main_input") + + search_message = [ChatMessage.from_user(search_message_template)] + messages = [ChatMessage.from_user(react_message_template)] + question = "which tower is taller: eiffel tower or tower of pisa?" 
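+    # The expected_run_order below reflects two passes through the agent loop: the
+    # first fake LLM reply takes the google_search route, its observation is routed
+    # back into "main_input", and the second reply takes the finish route.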
+ + return pipeline, [ + PipelineRunData( + inputs={ + "main_input": {"value": messages}, + "prompt_builder": {"query": question}, + "search_prompt_builder": {"template": search_message}, + }, + expected_outputs={"router": {"finish": "Eiffel Tower"}}, + expected_run_order=[ + "main_input", + "prompt_builder", + "llm", + "prompt_concatenator_after_action", + "tool_extractor", + "router", + "router_search", + "search_prompt_builder", + "search_llm", + "search_output_adapter", + "prompt_concatenator_after_observation", + "main_input", + "prompt_builder", + "llm", + "prompt_concatenator_after_action", + "tool_extractor", + "router", + ], + ) + ] + + +@given( + "a pipeline that has a variadic component that receives partial inputs", + target_fixture="pipeline_data", +) +def that_has_a_variadic_component_that_receives_partial_inputs(): + @component + class ConditionalDocumentCreator: + def __init__(self, content: str): + self._content = content + + @component.output_types(documents=List[Document], noop=None) + def run(self, create_document: bool = False): + if create_document: + return { + "documents": [Document(id=self._content, content=self._content)] + } + return {"noop": None} + + pipeline = AsyncPipeline(max_runs_per_component=1) + pipeline.add_component( + "first_creator", ConditionalDocumentCreator(content="First document") + ) + pipeline.add_component( + "second_creator", ConditionalDocumentCreator(content="Second document") + ) + pipeline.add_component( + "third_creator", ConditionalDocumentCreator(content="Third document") + ) + pipeline.add_component("documents_joiner", DocumentJoiner()) + + pipeline.connect("first_creator.documents", "documents_joiner.documents") + pipeline.connect("second_creator.documents", "documents_joiner.documents") + pipeline.connect("third_creator.documents", "documents_joiner.documents") + + return ( + pipeline, + [ + PipelineRunData( + inputs={ + "first_creator": {"create_document": True}, + "third_creator": {"create_document": True}, + }, + expected_outputs={ + "second_creator": {"noop": None}, + "documents_joiner": { + "documents": [ + Document(id="First document", content="First document"), + Document(id="Third document", content="Third document"), + ] + }, + }, + expected_run_order=[ + "first_creator", + "second_creator", + "third_creator", + "documents_joiner", + ], + ), + PipelineRunData( + inputs={ + "first_creator": {"create_document": True}, + "second_creator": {"create_document": True}, + }, + expected_outputs={ + "third_creator": {"noop": None}, + "documents_joiner": { + "documents": [ + Document(id="First document", content="First document"), + Document(id="Second document", content="Second document"), + ] + }, + }, + expected_run_order=[ + "first_creator", + "second_creator", + "third_creator", + "documents_joiner", + ], + ), + ], + ) + + +@given( + "a pipeline that has an answer joiner variadic component", + target_fixture="pipeline_data", +) +def that_has_an_answer_joiner_variadic_component(): + query = "What's Natural Language Processing?" 
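+    # Both AnswerBuilder outputs feed the variadic input of AnswerJoiner, so the
+    # joined result below contains one GeneratedAnswer per builder.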
+ + pipeline = AsyncPipeline(max_runs_per_component=1) + pipeline.add_component("answer_builder_1", AnswerBuilder()) + pipeline.add_component("answer_builder_2", AnswerBuilder()) + pipeline.add_component("answer_joiner", AnswerJoiner()) + + pipeline.connect("answer_builder_1.answers", "answer_joiner") + pipeline.connect("answer_builder_2.answers", "answer_joiner") + + return ( + pipeline, + [ + PipelineRunData( + inputs={ + "answer_builder_1": { + "query": query, + "replies": ["This is a test answer"], + }, + "answer_builder_2": { + "query": query, + "replies": ["This is a second test answer"], + }, + }, + expected_outputs={ + "answer_joiner": { + "answers": [ + GeneratedAnswer( + data="This is a test answer", + query="What's Natural Language Processing?", + documents=[], + meta={}, + ), + GeneratedAnswer( + data="This is a second test answer", + query="What's Natural Language Processing?", + documents=[], + meta={}, + ), + ] + } + }, + expected_run_order=[ + "answer_builder_1", + "answer_builder_2", + "answer_joiner", + ], + ) + ], + ) + + +@given( + "a pipeline that is linear and a component in the middle receives optional input from other components and input from the user", + target_fixture="pipeline_data", +) +def that_is_linear_and_a_component_in_the_middle_receives_optional_input_from_other_components_and_input_from_the_user(): + @component + class QueryMetadataExtractor: + @component.output_types(filters=Dict[str, str]) + def run(self, prompt: str): + metadata = json.loads(prompt) + filters = [] + for key, value in metadata.items(): + filters.append( + {"field": f"meta.{key}", "operator": "==", "value": value} + ) + + return {"filters": {"operator": "AND", "conditions": filters}} + + documents = [ + Document( + content="some publication about Alzheimer prevention research done over 2023 patients study", + meta={"year": 2022, "disease": "Alzheimer", "author": "Michael Butter"}, + id="doc1", + ), + Document( + content="some text about investigation and treatment of Alzheimer disease", + meta={"year": 2023, "disease": "Alzheimer", "author": "John Bread"}, + id="doc2", + ), + Document( + content="A study on the effectiveness of new therapies for Parkinson's disease", + meta={"year": 2022, "disease": "Parkinson", "author": "Alice Smith"}, + id="doc3", + ), + Document( + content="An overview of the latest research on the genetics of Parkinson's disease and its implications for treatment", + meta={"year": 2023, "disease": "Parkinson", "author": "David Jones"}, + id="doc4", + ), + ] + document_store = InMemoryDocumentStore(bm25_algorithm="BM25Plus") + document_store.write_documents( + documents=documents, policy=DuplicatePolicy.OVERWRITE + ) + + pipeline = AsyncPipeline() + pipeline.add_component( + instance=PromptBuilder('{"disease": "Alzheimer", "year": 2023}'), name="builder" + ) + pipeline.add_component(instance=QueryMetadataExtractor(), name="metadata_extractor") + pipeline.add_component( + instance=InMemoryBM25Retriever(document_store=document_store), name="retriever" + ) + pipeline.add_component(instance=DocumentJoiner(), name="document_joiner") + + pipeline.connect("builder.prompt", "metadata_extractor.prompt") + pipeline.connect("metadata_extractor.filters", "retriever.filters") + pipeline.connect("retriever.documents", "document_joiner.documents") + + query = "publications 2023 Alzheimer's disease" + + return ( + pipeline, + [ + PipelineRunData( + inputs={"retriever": {"query": query}}, + expected_outputs={ + "document_joiner": { + "documents": [ + Document( + content="some text 
about investigation and treatment of Alzheimer disease", + meta={ + "year": 2023, + "disease": "Alzheimer", + "author": "John Bread", + }, + id="doc2", + score=3.324112496100923, + ) + ] + } + }, + expected_run_order=[ + "builder", + "metadata_extractor", + "retriever", + "document_joiner", + ], + ) + ], + ) + + +@given( + "a pipeline that has a cycle that would get it stuck", + target_fixture="pipeline_data", +) +def that_has_a_cycle_that_would_get_it_stuck(): + template = """ + You are an experienced and accurate Turkish CX speacialist that classifies customer comments into pre-defined categories below:\n + Negative experience labels: + - Late delivery + - Rotten/spoilt item + - Bad Courier behavior + + Positive experience labels: + - Good courier behavior + - Thanks & appreciation + - Love message to courier + - Fast delivery + - Quality of products + + Create a JSON object as a response. The fields are: 'positive_experience', 'negative_experience'. + Assign at least one of the pre-defined labels to the given customer comment under positive and negative experience fields. + If the comment has a positive experience, list the label under 'positive_experience' field. + If the comments has a negative_experience, list it under the 'negative_experience' field. + Here is the comment:\n{{ comment }}\n. Just return the category names in the list. If there aren't any, return an empty list. + + {% if invalid_replies and error_message %} + You already created the following output in a previous attempt: {{ invalid_replies }} + However, this doesn't comply with the format requirements from above and triggered this Python exception: {{ error_message }} + Correct the output and try again. Just return the corrected output without any extra explanations. + {% endif %} + """ + prompt_builder = PromptBuilder( + template=template, + required_variables=["comment", "invalid_replies", "error_message"], + ) + + @component + class FakeOutputValidator: + @component.output_types( + valid_replies=List[str], + invalid_replies=Optional[List[str]], + error_message=Optional[str], + ) + def run(self, replies: List[str]): + if not getattr(self, "called", False): + self.called = True + return { + "invalid_replies": ["This is an invalid reply"], + "error_message": "this is an error message", + } + return {"valid_replies": replies} + + @component + class FakeGenerator: + @component.output_types(replies=List[str]) + def run(self, prompt: str): + return {"replies": ["This is a valid reply"]} + + llm = FakeGenerator() + validator = FakeOutputValidator() + + pipeline = AsyncPipeline(max_runs_per_component=1) + pipeline.add_component("prompt_builder", prompt_builder) + + pipeline.add_component("llm", llm) + pipeline.add_component("output_validator", validator) + + pipeline.connect("prompt_builder.prompt", "llm.prompt") + pipeline.connect("llm.replies", "output_validator.replies") + pipeline.connect( + "output_validator.invalid_replies", "prompt_builder.invalid_replies" + ) + + pipeline.connect("output_validator.error_message", "prompt_builder.error_message") + + comment = "I loved the quality of the meal but the courier was rude" + return ( + pipeline, + [PipelineRunData(inputs={"prompt_builder": {"comment": comment}})], + ) + + +@given("a pipeline that has a loop in the middle", target_fixture="pipeline_data") +def that_has_a_loop_in_the_middle(): + @component + class FakeGenerator: + @component.output_types(replies=List[str]) + def run(self, prompt: str): + replies = [] + if getattr(self, "first_run", True): + self.first_run = 
False + replies.append("No answer") + else: + replies.append("42") + return {"replies": replies} + + @component + class PromptCleaner: + @component.output_types(clean_prompt=str) + def run(self, prompt: str): + return {"clean_prompt": prompt.strip()} + + routes = [ + { + "condition": "{{ 'No answer' in replies }}", + "output": "{{ replies }}", + "output_name": "invalid_replies", + "output_type": List[str], + }, + { + "condition": "{{ 'No answer' not in replies }}", + "output": "{{ replies }}", + "output_name": "valid_replies", + "output_type": List[str], + }, + ] + + pipeline = AsyncPipeline(max_runs_per_component=20) + pipeline.add_component("prompt_cleaner", PromptCleaner()) + pipeline.add_component( + "prompt_builder", + PromptBuilder(template="", variables=["question", "invalid_replies"]), + ) + pipeline.add_component("llm", FakeGenerator()) + pipeline.add_component("answer_validator", ConditionalRouter(routes=routes)) + pipeline.add_component("answer_builder", AnswerBuilder()) + + pipeline.connect("prompt_cleaner.clean_prompt", "prompt_builder.template") + pipeline.connect("prompt_builder.prompt", "llm.prompt") + pipeline.connect("llm.replies", "answer_validator.replies") + pipeline.connect( + "answer_validator.invalid_replies", "prompt_builder.invalid_replies" + ) + pipeline.connect("answer_validator.valid_replies", "answer_builder.replies") + + question = "What is the answer?" + return ( + pipeline, + [ + PipelineRunData( + inputs={ + "prompt_cleaner": {"prompt": "Random template"}, + "prompt_builder": {"question": question}, + "answer_builder": {"query": question}, + }, + expected_outputs={ + "answer_builder": { + "answers": [ + GeneratedAnswer(data="42", query=question, documents=[]) + ] + } + }, + expected_run_order=[ + "prompt_cleaner", + "prompt_builder", + "llm", + "answer_validator", + "prompt_builder", + "llm", + "answer_validator", + "answer_builder", + ], + ) + ], + ) + + +@given( + "a pipeline that has variadic component that receives a conditional input", + target_fixture="pipeline_data", +) +def that_has_variadic_component_that_receives_a_conditional_input(): + pipe = AsyncPipeline(max_runs_per_component=1) + routes = [ + { + "condition": "{{ documents|length > 1 }}", + "output": "{{ documents }}", + "output_name": "long", + "output_type": List[Document], + }, + { + "condition": "{{ documents|length <= 1 }}", + "output": "{{ documents }}", + "output_name": "short", + "output_type": List[Document], + }, + ] + + @component + class NoOp: + @component.output_types(documents=List[Document]) + def run(self, documents: List[Document]): + return {"documents": documents} + + @component + class CommaSplitter: + @component.output_types(documents=List[Document]) + def run(self, documents: List[Document]): + res = [] + current_id = 0 + for doc in documents: + for split in doc.content.split(","): + res.append(Document(content=split, id=str(current_id))) + current_id += 1 + return {"documents": res} + + pipe.add_component("conditional_router", ConditionalRouter(routes, unsafe=True)) + pipe.add_component( + "empty_lines_cleaner", + DocumentCleaner( + remove_empty_lines=True, remove_extra_whitespaces=False, keep_id=True + ), + ) + pipe.add_component("comma_splitter", CommaSplitter()) + pipe.add_component("document_cleaner", DocumentCleaner(keep_id=True)) + pipe.add_component("document_joiner", DocumentJoiner()) + + pipe.add_component("noop2", NoOp()) + pipe.add_component("noop3", NoOp()) + + pipe.connect("noop2", "noop3") + pipe.connect("noop3", "conditional_router") + + 
pipe.connect("conditional_router.long", "empty_lines_cleaner") + pipe.connect("empty_lines_cleaner", "document_joiner") + + pipe.connect("comma_splitter", "document_cleaner") + pipe.connect("document_cleaner", "document_joiner") + pipe.connect("comma_splitter", "document_joiner") + + document = Document( + id="1000", + content="This document has so many, sentences. Like this one, or this one. Or even this other one.", + ) + + return pipe, [ + PipelineRunData( + inputs={ + "noop2": {"documents": [document]}, + "comma_splitter": {"documents": [document]}, + }, + expected_outputs={ + "conditional_router": { + "short": [ + Document( + id="1000", + content="This document has so many, sentences. Like this one, or this one. Or even this other one.", + ) + ] + }, + "document_joiner": { + "documents": [ + Document(id="0", content="This document has so many"), + Document(id="1", content=" sentences. Like this one"), + Document( + id="2", content=" or this one. Or even this other one." + ), + ] + }, + }, + expected_run_order=[ + "comma_splitter", + "noop2", + "document_cleaner", + "noop3", + "conditional_router", + "document_joiner", + ], + ), + PipelineRunData( + inputs={ + "noop2": {"documents": [document, document]}, + "comma_splitter": {"documents": [document, document]}, + }, + expected_outputs={ + "document_joiner": { + "documents": [ + Document(id="0", content="This document has so many"), + Document(id="1", content=" sentences. Like this one"), + Document( + id="2", content=" or this one. Or even this other one." + ), + Document(id="3", content="This document has so many"), + Document(id="4", content=" sentences. Like this one"), + Document( + id="5", content=" or this one. Or even this other one." + ), + Document( + id="1000", + content="This document has so many, sentences. Like this one, or this one. Or even this other one.", + ), + ] + } + }, + expected_run_order=[ + "comma_splitter", + "noop2", + "document_cleaner", + "noop3", + "conditional_router", + "empty_lines_cleaner", + "document_joiner", + ], + ), + ] + + +@given( + "a pipeline that has a string variadic component", target_fixture="pipeline_data" +) +def that_has_a_string_variadic_component(): + string_1 = "What's Natural Language Processing?" + string_2 = "What's is life?" 
+ + pipeline = AsyncPipeline() + pipeline.add_component("prompt_builder_1", PromptBuilder("Builder 1: {{query}}")) + pipeline.add_component("prompt_builder_2", PromptBuilder("Builder 2: {{query}}")) + pipeline.add_component("string_joiner", StringJoiner()) + + pipeline.connect("prompt_builder_1.prompt", "string_joiner.strings") + pipeline.connect("prompt_builder_2.prompt", "string_joiner.strings") + + return ( + pipeline, + [ + PipelineRunData( + inputs={ + "prompt_builder_1": {"query": string_1}, + "prompt_builder_2": {"query": string_2}, + }, + expected_outputs={ + "string_joiner": { + "strings": [ + "Builder 1: What's Natural Language Processing?", + "Builder 2: What's is life?", + ] + } + }, + expected_run_order=[ + "prompt_builder_1", + "prompt_builder_2", + "string_joiner", + ], + ) + ], + ) diff --git a/test/core/pipeline/test_async_pipeline.py b/test/core/pipeline/test_async_pipeline.py new file mode 100644 index 00000000..19a99ddf --- /dev/null +++ b/test/core/pipeline/test_async_pipeline.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from pathlib import Path +import pytest + +from haystack import component +from haystack.testing.sample_components import AddFixedValue +from haystack_experimental.core import AsyncPipeline, run_async_pipeline + + +@component +class AsyncDoubleWithOriginal: + """ + Doubles the input value and returns the original value as well. + """ + + def __init__(self) -> None: + self.async_executed = False + + @component.output_types(value=int, original=int) + def run(self, value: int): + raise NotImplementedError() + + @component.output_types(value=int, original=int) + async def run_async(self, value: int): + self.async_executed = True + return {"value": value * 2, "original": value} + + +@pytest.mark.asyncio +async def test_async_pipeline(): + pipeline = AsyncPipeline() + pipeline.add_component("first_addition", AddFixedValue(add=2)) + pipeline.add_component("second_addition", AddFixedValue()) + pipeline.add_component("double", AsyncDoubleWithOriginal()) + pipeline.connect("first_addition", "double") + pipeline.connect("double.value", "second_addition") + + outputs = {} + # since enumerate doesn't work with async generators + expected_intermediate_outputs = [ + {"first_addition": {"result": 5}}, + {"double": {"value": 10, "original": 5}}, + {"second_addition": {"result": 11}}, + ] + + outputs = [o async for o in pipeline.run({"first_addition": {"value": 3}})] + intermediate_outputs = outputs[:-1] + final_output = outputs[-1] + + assert expected_intermediate_outputs == intermediate_outputs + assert final_output == { + "double": {"original": 5}, + "second_addition": {"result": 11}, + } + assert pipeline.get_component("double").async_executed is True + pipeline.get_component("double").async_executed = False + + other_final_outputs = await run_async_pipeline( + pipeline, + {"first_addition": {"value": 3}}, + include_outputs_from={"double", "second_addition", "first_addition"}, + ) + assert other_final_outputs == { + "first_addition": {"result": 5}, + "double": {"value": 10, "original": 5}, + "second_addition": {"result": 11}, + } + assert pipeline.get_component("double").async_executed is True diff --git a/test/document_stores/__init__.py b/test/document_stores/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/test/document_stores/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git 
a/test/document_stores/opensearch/__init__.py b/test/document_stores/opensearch/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/test/document_stores/opensearch/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/test/document_stores/opensearch/test_document_store.py b/test/document_stores/opensearch/test_document_store.py new file mode 100644 index 00000000..6fdeea93 --- /dev/null +++ b/test/document_stores/opensearch/test_document_store.py @@ -0,0 +1,1513 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import random +from typing import List +from unittest.mock import patch + +import pytest +from haystack.dataclasses.document import Document +from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError +from haystack.document_stores.types import DuplicatePolicy +from haystack.testing.document_store import DocumentStoreBaseTests +from haystack.utils.auth import Secret +from opensearchpy.exceptions import RequestError + +from haystack_experimental.document_stores.opensearch import OpenSearchDocumentStore +from haystack_integrations.document_stores.opensearch.auth import AWSAuth +from haystack_integrations.document_stores.opensearch.document_store import ( + DEFAULT_MAX_CHUNK_BYTES, +) + + +OPENSEARCH_URL = "https://localhost:9200" + + +def opensearch_backend_active() -> bool: + try: + store = OpenSearchDocumentStore( + hosts=[OPENSEARCH_URL], + index="test", + http_auth=("admin", "admin"), + verify_certs=False, + embedding_dim=768, + create_index=False, + ) + store._ensure_initialized() + return True + except Exception as e: + return False + + +@patch("haystack_experimental.document_stores.opensearch.document_store.OpenSearch") +def test_to_dict(_mock_opensearch_client): + document_store = OpenSearchDocumentStore(hosts="some hosts") + res = document_store.to_dict() + assert res == { + "type": "haystack_experimental.document_stores.opensearch.document_store.OpenSearchDocumentStore", + "init_parameters": { + "embedding_dim": 768, + "hosts": "some hosts", + "index": "default", + "mappings": { + "dynamic_templates": [ + { + "strings": { + "mapping": {"type": "keyword"}, + "match_mapping_type": "string", + } + } + ], + "properties": { + "content": {"type": "text"}, + "embedding": { + "dimension": 768, + "index": True, + "type": "knn_vector", + }, + }, + }, + "max_chunk_bytes": DEFAULT_MAX_CHUNK_BYTES, + "method": None, + "settings": {"index.knn": True}, + "return_embedding": False, + "create_index": True, + "http_auth": None, + "use_ssl": None, + "verify_certs": None, + "timeout": None, + }, + } + + +@patch("haystack_experimental.document_stores.opensearch.document_store.OpenSearch") +def test_from_dict(_mock_opensearch_client): + data = { + "type": "haystack_experimental.document_stores.opensearch.document_store.OpenSearchDocumentStore", + "init_parameters": { + "hosts": "some hosts", + "index": "default", + "max_chunk_bytes": 1000, + "embedding_dim": 1536, + "create_index": False, + "return_embedding": True, + "aws_service": "es", + "http_auth": ("admin", "admin"), + "use_ssl": True, + "verify_certs": True, + "timeout": 60, + }, + } + document_store = OpenSearchDocumentStore.from_dict(data) + assert document_store._hosts == "some hosts" + assert document_store._index == "default" + assert document_store._max_chunk_bytes == 1000 + assert document_store._embedding_dim == 1536 + assert document_store._method 
is None + assert document_store._mappings == { + "properties": { + "embedding": {"type": "knn_vector", "index": True, "dimension": 1536}, + "content": {"type": "text"}, + }, + "dynamic_templates": [ + { + "strings": { + "match_mapping_type": "string", + "mapping": {"type": "keyword"}, + } + } + ], + } + assert document_store._settings == {"index.knn": True} + assert document_store._return_embedding is True + assert document_store._create_index is False + assert document_store._http_auth == ("admin", "admin") + assert document_store._use_ssl is True + assert document_store._verify_certs is True + assert document_store._timeout == 60 + + +@patch("haystack_experimental.document_stores.opensearch.document_store.OpenSearch") +def test_init_is_lazy(_mock_opensearch_client): + OpenSearchDocumentStore(hosts="testhost") + _mock_opensearch_client.assert_not_called() + + +@patch("haystack_experimental.document_stores.opensearch.document_store.OpenSearch") +def test_get_default_mappings(_mock_opensearch_client): + store = OpenSearchDocumentStore( + hosts="testhost", embedding_dim=1536, method={"name": "hnsw"} + ) + assert store._mappings["properties"]["embedding"] == { + "type": "knn_vector", + "index": True, + "dimension": 1536, + "method": {"name": "hnsw"}, + } + + +class TestAuth: + @pytest.fixture(autouse=True) + def mock_boto3_session(self): + with patch("boto3.Session") as mock_client: + yield mock_client + + @patch("haystack_experimental.document_stores.opensearch.document_store.OpenSearch") + def test_init_with_basic_auth(self, _mock_opensearch_client): + document_store = OpenSearchDocumentStore( + hosts="testhost", http_auth=("user", "pw") + ) + document_store._ensure_initialized() + assert document_store._client + _mock_opensearch_client.assert_called_once() + assert _mock_opensearch_client.call_args[1]["http_auth"] == ("user", "pw") + + @patch("haystack_experimental.document_stores.opensearch.document_store.OpenSearch") + def test_init_without_auth(self, _mock_opensearch_client): + document_store = OpenSearchDocumentStore(hosts="testhost") + document_store._ensure_initialized() + assert document_store._client + _mock_opensearch_client.assert_called_once() + assert _mock_opensearch_client.call_args[1]["http_auth"] is None + + @patch("haystack_experimental.document_stores.opensearch.document_store.OpenSearch") + def test_init_aws_auth(self, _mock_opensearch_client): + document_store = OpenSearchDocumentStore( + hosts="testhost", + http_auth=AWSAuth(aws_region_name=Secret.from_token("dummy-region")), + use_ssl=True, + verify_certs=True, + ) + document_store._ensure_initialized() + assert document_store._client + _mock_opensearch_client.assert_called_once() + assert isinstance(_mock_opensearch_client.call_args[1]["http_auth"], AWSAuth) + assert _mock_opensearch_client.call_args[1]["use_ssl"] is True + assert _mock_opensearch_client.call_args[1]["verify_certs"] is True + + @patch("haystack_experimental.document_stores.opensearch.document_store.OpenSearch") + def test_from_dict_basic_auth(self, _mock_opensearch_client): + document_store = OpenSearchDocumentStore.from_dict( + { + "type": "haystack_experimental.document_stores.opensearch.document_store.OpenSearchDocumentStore", + "init_parameters": { + "hosts": "testhost", + "http_auth": ["user", "pw"], + "use_ssl": True, + "verify_certs": True, + }, + } + ) + document_store._ensure_initialized() + assert document_store._client + _mock_opensearch_client.assert_called_once() + assert _mock_opensearch_client.call_args[1]["http_auth"] == ["user", 
"pw"] + + @patch("haystack_experimental.document_stores.opensearch.document_store.OpenSearch") + def test_from_dict_aws_auth( + self, _mock_opensearch_client, monkeypatch: pytest.MonkeyPatch + ): + monkeypatch.setenv("AWS_DEFAULT_REGION", "dummy-region") + document_store = OpenSearchDocumentStore.from_dict( + { + "type": "haystack_experimental.document_stores.opensearch.document_store.OpenSearchDocumentStore", + "init_parameters": { + "hosts": "testhost", + "http_auth": { + "type": "haystack_integrations.document_stores.opensearch.auth.AWSAuth", + "init_parameters": {}, + }, + "use_ssl": True, + "verify_certs": True, + }, + } + ) + document_store._ensure_initialized() + assert document_store._client + _mock_opensearch_client.assert_called_once() + assert isinstance(_mock_opensearch_client.call_args[1]["http_auth"], AWSAuth) + assert _mock_opensearch_client.call_args[1]["use_ssl"] is True + assert _mock_opensearch_client.call_args[1]["verify_certs"] is True + + @patch("haystack_experimental.document_stores.opensearch.document_store.OpenSearch") + def test_to_dict_basic_auth(self, _mock_opensearch_client): + document_store = OpenSearchDocumentStore( + hosts="some hosts", http_auth=("user", "pw") + ) + res = document_store.to_dict() + assert res == { + "type": "haystack_experimental.document_stores.opensearch.document_store.OpenSearchDocumentStore", + "init_parameters": { + "embedding_dim": 768, + "hosts": "some hosts", + "index": "default", + "mappings": { + "dynamic_templates": [ + { + "strings": { + "mapping": {"type": "keyword"}, + "match_mapping_type": "string", + } + } + ], + "properties": { + "content": {"type": "text"}, + "embedding": { + "dimension": 768, + "index": True, + "type": "knn_vector", + }, + }, + }, + "max_chunk_bytes": DEFAULT_MAX_CHUNK_BYTES, + "method": None, + "settings": {"index.knn": True}, + "return_embedding": False, + "create_index": True, + "http_auth": ("user", "pw"), + "use_ssl": None, + "verify_certs": None, + "timeout": None, + }, + } + + @patch("haystack_experimental.document_stores.opensearch.document_store.OpenSearch") + def test_to_dict_aws_auth( + self, _mock_opensearch_client, monkeypatch: pytest.MonkeyPatch + ): + monkeypatch.setenv("AWS_DEFAULT_REGION", "dummy-region") + document_store = OpenSearchDocumentStore( + hosts="some hosts", http_auth=AWSAuth() + ) + res = document_store.to_dict() + assert res == { + "type": "haystack_experimental.document_stores.opensearch.document_store.OpenSearchDocumentStore", + "init_parameters": { + "embedding_dim": 768, + "hosts": "some hosts", + "index": "default", + "mappings": { + "dynamic_templates": [ + { + "strings": { + "mapping": {"type": "keyword"}, + "match_mapping_type": "string", + } + } + ], + "properties": { + "content": {"type": "text"}, + "embedding": { + "dimension": 768, + "index": True, + "type": "knn_vector", + }, + }, + }, + "max_chunk_bytes": DEFAULT_MAX_CHUNK_BYTES, + "method": None, + "settings": {"index.knn": True}, + "return_embedding": False, + "create_index": True, + "http_auth": { + "type": "haystack_integrations.document_stores.opensearch.auth.AWSAuth", + "init_parameters": { + "aws_access_key_id": { + "type": "env_var", + "env_vars": ["AWS_ACCESS_KEY_ID"], + "strict": False, + }, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_session_token": { + "type": "env_var", + "env_vars": ["AWS_SESSION_TOKEN"], + "strict": False, + }, + "aws_region_name": { + "type": "env_var", + "env_vars": ["AWS_DEFAULT_REGION"], + 
"strict": False, + }, + "aws_profile_name": { + "type": "env_var", + "env_vars": ["AWS_PROFILE"], + "strict": False, + }, + "aws_service": "es", + }, + }, + "use_ssl": None, + "verify_certs": None, + "timeout": None, + }, + } + + +@pytest.mark.skipif( + not opensearch_backend_active(), reason="OpenSearch backend is not active" +) +@pytest.mark.integration +class TestDocumentStore(DocumentStoreBaseTests): + """ + Common test cases will be provided by `DocumentStoreBaseTests` but + you can add more to this class. + """ + + @pytest.fixture + def document_store(self, request): + """ + This is the most basic requirement for the child class: provide + an instance of this document store so the base class can use it. + """ + hosts = [OPENSEARCH_URL] + # Use a different index for each test so we can run them in parallel + index = f"{request.node.name}" + + store = OpenSearchDocumentStore( + hosts=hosts, + index=index, + http_auth=("admin", "admin"), + verify_certs=False, + embedding_dim=768, + method={"space_type": "cosinesimil", "engine": "nmslib", "name": "hnsw"}, + ) + yield store + store._ensure_initialized() + assert store._client + store._client.indices.delete(index=index, params={"ignore": [400, 404]}) + + @pytest.fixture + def document_store_readonly(self, request): + """ + This is the most basic requirement for the child class: provide + an instance of this document store so the base class can use it. + """ + hosts = [OPENSEARCH_URL] + # Use a different index for each test so we can run them in parallel + index = f"{request.node.name}" + + store = OpenSearchDocumentStore( + hosts=hosts, + index=index, + http_auth=("admin", "admin"), + verify_certs=False, + embedding_dim=768, + method={"space_type": "cosinesimil", "engine": "nmslib", "name": "hnsw"}, + create_index=False, + ) + store._ensure_initialized() + assert store._client + store._client.cluster.put_settings( + body={"transient": {"action.auto_create_index": False}} + ) + yield store + store._client.cluster.put_settings( + body={"transient": {"action.auto_create_index": True}} + ) + store._client.indices.delete(index=index, params={"ignore": [400, 404]}) + + @pytest.fixture + def document_store_embedding_dim_4(self, request): + """ + This is the most basic requirement for the child class: provide + an instance of this document store so the base class can use it. + """ + hosts = [OPENSEARCH_URL] + # Use a different index for each test so we can run them in parallel + index = f"{request.node.name}" + + store = OpenSearchDocumentStore( + hosts=hosts, + index=index, + http_auth=("admin", "admin"), + verify_certs=False, + embedding_dim=4, + method={"space_type": "cosinesimil", "engine": "nmslib", "name": "hnsw"}, + ) + yield store + store._ensure_initialized() + assert store._client + store._client.indices.delete(index=index, params={"ignore": [400, 404]}) + + def assert_documents_are_equal( + self, received: List[Document], expected: List[Document] + ): + """ + The OpenSearchDocumentStore.filter_documents() method returns a Documents with their score set. + We don't want to compare the score, so we set it to None before comparing the documents. 
+ """ + received_meta = [] + for doc in received: + r = { + "number": doc.meta.get("number"), + "name": doc.meta.get("name"), + } + received_meta.append(r) + + expected_meta = [] + for doc in expected: + r = { + "number": doc.meta.get("number"), + "name": doc.meta.get("name"), + } + expected_meta.append(r) + for doc in received: + doc.score = None + + super().assert_documents_are_equal(received, expected) + + def test_write_documents(self, document_store: OpenSearchDocumentStore): + docs = [Document(id="1")] + assert document_store.write_documents(docs) == 1 + with pytest.raises(DuplicateDocumentError): + document_store.write_documents(docs, DuplicatePolicy.FAIL) + + def test_write_documents_readonly( + self, document_store_readonly: OpenSearchDocumentStore + ): + docs = [Document(id="1")] + with pytest.raises(DocumentStoreError, match="index_not_found_exception"): + document_store_readonly.write_documents(docs) + + def test_create_index(self, document_store_readonly: OpenSearchDocumentStore): + document_store_readonly.create_index() + assert document_store_readonly._client.indices.exists( + index=document_store_readonly._index + ) + + def test_bm25_retrieval(self, document_store: OpenSearchDocumentStore): + document_store.write_documents( + [ + Document(content="Haskell is a functional programming language"), + Document(content="Lisp is a functional programming language"), + Document(content="Exilir is a functional programming language"), + Document(content="F# is a functional programming language"), + Document(content="C# is a functional programming language"), + Document(content="C++ is an object oriented programming language"), + Document(content="Dart is an object oriented programming language"), + Document(content="Go is an object oriented programming language"), + Document(content="Python is a object oriented programming language"), + Document(content="Ruby is a object oriented programming language"), + Document(content="PHP is a object oriented programming language"), + ] + ) + + res = document_store._bm25_retrieval("functional", top_k=3) + assert len(res) == 3 + assert "functional" in res[0].content + assert "functional" in res[1].content + assert "functional" in res[2].content + + def test_bm25_retrieval_pagination(self, document_store: OpenSearchDocumentStore): + """ + Test that handling of pagination works as expected, when the matching documents are > 10. 
+ """ + document_store.write_documents( + [ + Document(content="Haskell is a functional programming language"), + Document(content="Lisp is a functional programming language"), + Document(content="Exilir is a functional programming language"), + Document(content="F# is a functional programming language"), + Document(content="C# is a functional programming language"), + Document(content="C++ is an object oriented programming language"), + Document(content="Dart is an object oriented programming language"), + Document(content="Go is an object oriented programming language"), + Document(content="Python is a object oriented programming language"), + Document(content="Ruby is a object oriented programming language"), + Document(content="PHP is a object oriented programming language"), + Document(content="Java is an object oriented programming language"), + Document(content="Javascript is a programming language"), + Document(content="Typescript is a programming language"), + Document(content="C is a programming language"), + ] + ) + + res = document_store._bm25_retrieval("programming", top_k=11) + assert len(res) == 11 + assert all("programming" in doc.content for doc in res) + + def test_bm25_retrieval_all_terms_must_match( + self, document_store: OpenSearchDocumentStore + ): + document_store.write_documents( + [ + Document(content="Haskell is a functional programming language"), + Document(content="Lisp is a functional programming language"), + Document(content="Exilir is a functional programming language"), + Document(content="F# is a functional programming language"), + Document(content="C# is a functional programming language"), + Document(content="C++ is an object oriented programming language"), + Document(content="Dart is an object oriented programming language"), + Document(content="Go is an object oriented programming language"), + Document(content="Python is a object oriented programming language"), + Document(content="Ruby is a object oriented programming language"), + Document(content="PHP is a object oriented programming language"), + ] + ) + + res = document_store._bm25_retrieval( + "functional Haskell", top_k=3, all_terms_must_match=True + ) + assert len(res) == 1 + assert "Haskell is a functional programming language" in res[0].content + + def test_bm25_retrieval_all_terms_must_match_false( + self, document_store: OpenSearchDocumentStore + ): + document_store.write_documents( + [ + Document(content="Haskell is a functional programming language"), + Document(content="Lisp is a functional programming language"), + Document(content="Exilir is a functional programming language"), + Document(content="F# is a functional programming language"), + Document(content="C# is a functional programming language"), + Document(content="C++ is an object oriented programming language"), + Document(content="Dart is an object oriented programming language"), + Document(content="Go is an object oriented programming language"), + Document(content="Python is a object oriented programming language"), + Document(content="Ruby is a object oriented programming language"), + Document(content="PHP is a object oriented programming language"), + ] + ) + + res = document_store._bm25_retrieval( + "functional Haskell", top_k=10, all_terms_must_match=False + ) + assert len(res) == 5 + assert "functional" in res[0].content + assert "functional" in res[1].content + assert "functional" in res[2].content + assert "functional" in res[3].content + assert "functional" in res[4].content + + def 
test_bm25_retrieval_with_fuzziness( + self, document_store: OpenSearchDocumentStore + ): + document_store.write_documents( + [ + Document(content="Haskell is a functional programming language"), + Document(content="Lisp is a functional programming language"), + Document(content="Exilir is a functional programming language"), + Document(content="F# is a functional programming language"), + Document(content="C# is a functional programming language"), + Document(content="C++ is an object oriented programming language"), + Document(content="Dart is an object oriented programming language"), + Document(content="Go is an object oriented programming language"), + Document(content="Python is a object oriented programming language"), + Document(content="Ruby is a object oriented programming language"), + Document(content="PHP is a object oriented programming language"), + ] + ) + + query_with_typo = "functinal" + # Query without fuzziness to search for the exact match + res = document_store._bm25_retrieval(query_with_typo, top_k=3, fuzziness="0") + # Nothing is found as the query contains a typo + assert res == [] + + # Query with fuzziness with the same query + res = document_store._bm25_retrieval(query_with_typo, top_k=3, fuzziness="1") + assert len(res) == 3 + assert "functional" in res[0].content + assert "functional" in res[1].content + assert "functional" in res[2].content + + def test_bm25_retrieval_with_filters(self, document_store: OpenSearchDocumentStore): + document_store.write_documents( + [ + Document( + content="Haskell is a functional programming language", + meta={"likes": 100000, "language_type": "functional"}, + id="1", + ), + Document( + content="Lisp is a functional programming language", + meta={"likes": 10000, "language_type": "functional"}, + id="2", + ), + Document( + content="Exilir is a functional programming language", + meta={"likes": 1000, "language_type": "functional"}, + id="3", + ), + Document( + content="F# is a functional programming language", + meta={"likes": 100, "language_type": "functional"}, + id="4", + ), + Document( + content="C# is a functional programming language", + meta={"likes": 10, "language_type": "functional"}, + id="5", + ), + Document( + content="C++ is an object oriented programming language", + meta={"likes": 100000, "language_type": "object_oriented"}, + id="6", + ), + Document( + content="Dart is an object oriented programming language", + meta={"likes": 10000, "language_type": "object_oriented"}, + id="7", + ), + Document( + content="Go is an object oriented programming language", + meta={"likes": 1000, "language_type": "object_oriented"}, + id="8", + ), + Document( + content="Python is a object oriented programming language", + meta={"likes": 100, "language_type": "object_oriented"}, + id="9", + ), + Document( + content="Ruby is a object oriented programming language", + meta={"likes": 10, "language_type": "object_oriented"}, + id="10", + ), + Document( + content="PHP is a object oriented programming language", + meta={"likes": 1, "language_type": "object_oriented"}, + id="11", + ), + ] + ) + + res = document_store._bm25_retrieval( + "programming", + top_k=10, + filters={"field": "language_type", "operator": "==", "value": "functional"}, + ) + assert len(res) == 5 + retrieved_ids = sorted([doc.id for doc in res]) + assert retrieved_ids == ["1", "2", "3", "4", "5"] + + def test_bm25_retrieval_with_custom_query( + self, document_store: OpenSearchDocumentStore + ): + document_store.write_documents( + [ + Document( + content="Haskell is a 
functional programming language", + meta={"likes": 100000, "language_type": "functional"}, + id="1", + ), + Document( + content="Lisp is a functional programming language", + meta={"likes": 10000, "language_type": "functional"}, + id="2", + ), + Document( + content="Exilir is a functional programming language", + meta={"likes": 1000, "language_type": "functional"}, + id="3", + ), + Document( + content="F# is a functional programming language", + meta={"likes": 100, "language_type": "functional"}, + id="4", + ), + Document( + content="C# is a functional programming language", + meta={"likes": 10, "language_type": "functional"}, + id="5", + ), + Document( + content="C++ is an object oriented programming language", + meta={"likes": 100000, "language_type": "object_oriented"}, + id="6", + ), + Document( + content="Dart is an object oriented programming language", + meta={"likes": 10000, "language_type": "object_oriented"}, + id="7", + ), + Document( + content="Go is an object oriented programming language", + meta={"likes": 1000, "language_type": "object_oriented"}, + id="8", + ), + Document( + content="Python is a object oriented programming language", + meta={"likes": 100, "language_type": "object_oriented"}, + id="9", + ), + Document( + content="Ruby is a object oriented programming language", + meta={"likes": 10, "language_type": "object_oriented"}, + id="10", + ), + Document( + content="PHP is a object oriented programming language", + meta={"likes": 1, "language_type": "object_oriented"}, + id="11", + ), + ] + ) + + custom_query = { + "query": { + "function_score": { + "query": { + "bool": { + "must": {"match": {"content": "$query"}}, + "filter": "$filters", + } + }, + "field_value_factor": { + "field": "likes", + "factor": 0.1, + "modifier": "log1p", + "missing": 0, + }, + } + } + } + + res = document_store._bm25_retrieval( + "functional", + top_k=3, + custom_query=custom_query, + filters={"field": "language_type", "operator": "==", "value": "functional"}, + ) + assert len(res) == 3 + assert "1" == res[0].id + assert "2" == res[1].id + assert "3" == res[2].id + + def test_embedding_retrieval( + self, document_store_embedding_dim_4: OpenSearchDocumentStore + ): + docs = [ + Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]), + Document( + content="Not very similar document", embedding=[0.0, 0.8, 0.3, 0.9] + ), + ] + document_store_embedding_dim_4.write_documents(docs) + results = document_store_embedding_dim_4._embedding_retrieval( + query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=2, filters={} + ) + assert len(results) == 2 + assert results[0].content == "Most similar document" + assert results[1].content == "2nd best document" + + def test_embedding_retrieval_with_filters( + self, document_store_embedding_dim_4: OpenSearchDocumentStore + ): + docs = [ + Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]), + Document( + content="Not very similar document with meta field", + embedding=[0.0, 0.8, 0.3, 0.9], + meta={"meta_field": "custom_value"}, + ), + ] + document_store_embedding_dim_4.write_documents(docs) + + filters = {"field": "meta_field", "operator": "==", "value": "custom_value"} + # we set top_k=3, to make the test pass as we are not sure whether efficient filtering is supported for nmslib + # TODO: remove top_k=3, when efficient filtering is supported for nmslib + results = 
document_store_embedding_dim_4._embedding_retrieval( + query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=3, filters=filters + ) + assert len(results) == 1 + assert results[0].content == "Not very similar document with meta field" + + def test_embedding_retrieval_pagination( + self, document_store_embedding_dim_4: OpenSearchDocumentStore + ): + """ + Test that handling of pagination works as expected, when the matching documents are > 10. + """ + + docs = [ + Document( + content=f"Document {i}", embedding=[random.random() for _ in range(4)] + ) # noqa: S311 + for i in range(20) + ] + + document_store_embedding_dim_4.write_documents(docs) + results = document_store_embedding_dim_4._embedding_retrieval( + query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=11, filters={} + ) + assert len(results) == 11 + + def test_embedding_retrieval_with_custom_query( + self, document_store_embedding_dim_4: OpenSearchDocumentStore + ): + docs = [ + Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]), + Document( + content="Not very similar document with meta field", + embedding=[0.0, 0.8, 0.3, 0.9], + meta={"meta_field": "custom_value"}, + ), + ] + document_store_embedding_dim_4.write_documents(docs) + + custom_query = { + "query": { + "bool": { + "must": [ + {"knn": {"embedding": {"vector": "$query_embedding", "k": 3}}} + ], + "filter": "$filters", + } + } + } + + filters = {"field": "meta_field", "operator": "==", "value": "custom_value"} + results = document_store_embedding_dim_4._embedding_retrieval( + query_embedding=[0.1, 0.1, 0.1, 0.1], + top_k=1, + filters=filters, + custom_query=custom_query, + ) + assert len(results) == 1 + assert results[0].content == "Not very similar document with meta field" + + def test_embedding_retrieval_query_documents_different_embedding_sizes( + self, document_store_embedding_dim_4: OpenSearchDocumentStore + ): + """ + Test that the retrieval fails if the query embedding and the documents have different embedding sizes. + """ + docs = [Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4])] + document_store_embedding_dim_4.write_documents(docs) + + with pytest.raises(RequestError): + document_store_embedding_dim_4._embedding_retrieval( + query_embedding=[0.1, 0.1] + ) + + def test_write_documents_different_embedding_sizes_fail( + self, document_store_embedding_dim_4: OpenSearchDocumentStore + ): + """ + Test that write_documents fails if the documents have different embedding sizes. 
+ """ + docs = [ + Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]), + Document(content="Hello world", embedding=[0.1, 0.2]), + ] + + with pytest.raises(DocumentStoreError): + document_store_embedding_dim_4.write_documents(docs) + + @patch("haystack_experimental.document_stores.opensearch.document_store.bulk") + def test_write_documents_with_badly_formatted_bulk_errors( + self, mock_bulk, document_store + ): + error = {"some_key": "some_value"} + mock_bulk.return_value = ([], [error]) + + with pytest.raises(DocumentStoreError) as e: + document_store.write_documents([Document(content="Hello world")]) + e.match(f"{error}") + + @patch("haystack_experimental.document_stores.opensearch.document_store.bulk") + def test_write_documents_max_chunk_bytes(self, mock_bulk, document_store): + mock_bulk.return_value = (1, []) + document_store.write_documents([Document(content="Hello world")]) + + assert mock_bulk.call_args.kwargs["max_chunk_bytes"] == DEFAULT_MAX_CHUNK_BYTES + + @pytest.fixture + def document_store_no_embbding_returned(self, request): + """ + This is the most basic requirement for the child class: provide + an instance of this document store so the base class can use it. + """ + hosts = [OPENSEARCH_URL] + # Use a different index for each test so we can run them in parallel + index = f"{request.node.name}" + + store = OpenSearchDocumentStore( + hosts=hosts, + index=index, + http_auth=("admin", "admin"), + verify_certs=False, + embedding_dim=4, + return_embedding=False, + method={"space_type": "cosinesimil", "engine": "nmslib", "name": "hnsw"}, + ) + store._ensure_initialized() + yield store + + store._client.indices.delete(index=index, params={"ignore": [400, 404]}) + + def test_embedding_retrieval_but_dont_return_embeddings_for_embedding_retrieval( + self, document_store_no_embbding_returned: OpenSearchDocumentStore + ): + docs = [ + Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]), + Document( + content="Not very similar document", embedding=[0.0, 0.8, 0.3, 0.9] + ), + ] + document_store_no_embbding_returned.write_documents(docs) + results = document_store_no_embbding_returned._embedding_retrieval( + query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=2, filters={} + ) + assert len(results) == 2 + assert results[0].embedding is None + + def test_embedding_retrieval_but_dont_return_embeddings_for_bm25_retrieval( + self, document_store_no_embbding_returned: OpenSearchDocumentStore + ): + docs = [ + Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]), + Document( + content="Not very similar document", embedding=[0.0, 0.8, 0.3, 0.9] + ), + ] + document_store_no_embbding_returned.write_documents(docs) + results = document_store_no_embbding_returned._bm25_retrieval( + "document", top_k=2 + ) + assert len(results) == 2 + assert results[0].embedding is None + + +# Needs to a separate class due to the fixtures requiring async. +@pytest.mark.skipif( + not opensearch_backend_active(), reason="OpenSearch backend is not active" +) +@pytest.mark.integration +class TestDocumentStoreAsync: + + @pytest.fixture + async def document_store(self, request): + """ + This is the most basic requirement for the child class: provide + an instance of this document store so the base class can use it. 
+ """ + hosts = [OPENSEARCH_URL] + # Use a different index for each test so we can run them in parallel + index = f"{request.node.name}" + + store = OpenSearchDocumentStore( + hosts=hosts, + index=index, + http_auth=("admin", "admin"), + verify_certs=False, + embedding_dim=768, + method={"space_type": "cosinesimil", "engine": "nmslib", "name": "hnsw"}, + ) + yield store + store._ensure_initialized() + assert store._client + store._client.indices.delete(index=index, params={"ignore": [400, 404]}) + await store._async_client.close() + + @pytest.fixture + async def document_store_readonly(self, request): + """ + This is the most basic requirement for the child class: provide + an instance of this document store so the base class can use it. + """ + hosts = [OPENSEARCH_URL] + # Use a different index for each test so we can run them in parallel + index = f"{request.node.name}" + + store = OpenSearchDocumentStore( + hosts=hosts, + index=index, + http_auth=("admin", "admin"), + verify_certs=False, + embedding_dim=768, + method={"space_type": "cosinesimil", "engine": "nmslib", "name": "hnsw"}, + create_index=False, + ) + store._ensure_initialized() + assert store._client + store._client.cluster.put_settings( + body={"transient": {"action.auto_create_index": False}} + ) + yield store + store._client.cluster.put_settings( + body={"transient": {"action.auto_create_index": True}} + ) + store._client.indices.delete(index=index, params={"ignore": [400, 404]}) + await store._async_client.close() + + @pytest.fixture + async def document_store_embedding_dim_4(self, request): + """ + This is the most basic requirement for the child class: provide + an instance of this document store so the base class can use it. + """ + hosts = [OPENSEARCH_URL] + # Use a different index for each test so we can run them in parallel + index = f"{request.node.name}" + + store = OpenSearchDocumentStore( + hosts=hosts, + index=index, + http_auth=("admin", "admin"), + verify_certs=False, + embedding_dim=4, + method={"space_type": "cosinesimil", "engine": "nmslib", "name": "hnsw"}, + ) + yield store + store._ensure_initialized() + assert store._client + store._client.indices.delete(index=index, params={"ignore": [400, 404]}) + await store._async_client.close() + + @pytest.fixture + async def document_store_no_embbding_returned(self, request): + """ + This is the most basic requirement for the child class: provide + an instance of this document store so the base class can use it. 
+ """ + hosts = [OPENSEARCH_URL] + # Use a different index for each test so we can run them in parallel + index = f"{request.node.name}" + + store = OpenSearchDocumentStore( + hosts=hosts, + index=index, + http_auth=("admin", "admin"), + verify_certs=False, + embedding_dim=4, + return_embedding=False, + method={"space_type": "cosinesimil", "engine": "nmslib", "name": "hnsw"}, + ) + store._ensure_initialized() + yield store + store._client.indices.delete(index=index, params={"ignore": [400, 404]}) + + @pytest.mark.asyncio + async def test_write_documents(self, document_store: OpenSearchDocumentStore): + docs = [Document(id="1")] + assert await document_store.write_documents_async([Document(id="2")]) == 1 + + @pytest.mark.asyncio + async def test_bm25_retrieval(self, document_store: OpenSearchDocumentStore): + document_store.write_documents( + [ + Document(content="Haskell is a functional programming language"), + Document(content="Lisp is a functional programming language"), + Document(content="Exilir is a functional programming language"), + Document(content="F# is a functional programming language"), + Document(content="C# is a functional programming language"), + Document(content="C++ is an object oriented programming language"), + Document(content="Dart is an object oriented programming language"), + Document(content="Go is an object oriented programming language"), + Document(content="Python is a object oriented programming language"), + Document(content="Ruby is a object oriented programming language"), + Document(content="PHP is a object oriented programming language"), + ] + ) + res = await document_store._bm25_retrieval_async("functional", top_k=3) + assert len(res) == 3 + assert "functional" in res[0].content + assert "functional" in res[1].content + assert "functional" in res[2].content + + @pytest.mark.asyncio + async def test_bm25_retrieval_pagination( + self, document_store: OpenSearchDocumentStore + ): + """ + Test that handling of pagination works as expected, when the matching documents are > 10. 
+ """ + document_store.write_documents( + [ + Document(content="Haskell is a functional programming language"), + Document(content="Lisp is a functional programming language"), + Document(content="Exilir is a functional programming language"), + Document(content="F# is a functional programming language"), + Document(content="C# is a functional programming language"), + Document(content="C++ is an object oriented programming language"), + Document(content="Dart is an object oriented programming language"), + Document(content="Go is an object oriented programming language"), + Document(content="Python is a object oriented programming language"), + Document(content="Ruby is a object oriented programming language"), + Document(content="PHP is a object oriented programming language"), + Document(content="Java is an object oriented programming language"), + Document(content="Javascript is a programming language"), + Document(content="Typescript is a programming language"), + Document(content="C is a programming language"), + ] + ) + + res = await document_store._bm25_retrieval_async("programming", top_k=11) + assert len(res) == 11 + assert all("programming" in doc.content for doc in res) + + @pytest.mark.asyncio + async def test_bm25_retrieval_all_terms_must_match( + self, document_store: OpenSearchDocumentStore + ): + document_store.write_documents( + [ + Document(content="Haskell is a functional programming language"), + Document(content="Lisp is a functional programming language"), + Document(content="Exilir is a functional programming language"), + Document(content="F# is a functional programming language"), + Document(content="C# is a functional programming language"), + Document(content="C++ is an object oriented programming language"), + Document(content="Dart is an object oriented programming language"), + Document(content="Go is an object oriented programming language"), + Document(content="Python is a object oriented programming language"), + Document(content="Ruby is a object oriented programming language"), + Document(content="PHP is a object oriented programming language"), + ] + ) + + res = await document_store._bm25_retrieval_async( + "functional Haskell", top_k=3, all_terms_must_match=True + ) + assert len(res) == 1 + assert "Haskell is a functional programming language" in res[0].content + + @pytest.mark.asyncio + async def test_bm25_retrieval_all_terms_must_match_false( + self, document_store: OpenSearchDocumentStore + ): + document_store.write_documents( + [ + Document(content="Haskell is a functional programming language"), + Document(content="Lisp is a functional programming language"), + Document(content="Exilir is a functional programming language"), + Document(content="F# is a functional programming language"), + Document(content="C# is a functional programming language"), + Document(content="C++ is an object oriented programming language"), + Document(content="Dart is an object oriented programming language"), + Document(content="Go is an object oriented programming language"), + Document(content="Python is a object oriented programming language"), + Document(content="Ruby is a object oriented programming language"), + Document(content="PHP is a object oriented programming language"), + ] + ) + + res = await document_store._bm25_retrieval_async( + "functional Haskell", top_k=10, all_terms_must_match=False + ) + assert len(res) == 5 + assert "functional" in res[0].content + assert "functional" in res[1].content + assert "functional" in res[2].content + assert "functional" 
in res[3].content + assert "functional" in res[4].content + + @pytest.mark.asyncio + async def test_bm25_retrieval_with_filters( + self, document_store: OpenSearchDocumentStore + ): + document_store.write_documents( + [ + Document( + content="Haskell is a functional programming language", + meta={"likes": 100000, "language_type": "functional"}, + id="1", + ), + Document( + content="Lisp is a functional programming language", + meta={"likes": 10000, "language_type": "functional"}, + id="2", + ), + Document( + content="Exilir is a functional programming language", + meta={"likes": 1000, "language_type": "functional"}, + id="3", + ), + Document( + content="F# is a functional programming language", + meta={"likes": 100, "language_type": "functional"}, + id="4", + ), + Document( + content="C# is a functional programming language", + meta={"likes": 10, "language_type": "functional"}, + id="5", + ), + Document( + content="C++ is an object oriented programming language", + meta={"likes": 100000, "language_type": "object_oriented"}, + id="6", + ), + Document( + content="Dart is an object oriented programming language", + meta={"likes": 10000, "language_type": "object_oriented"}, + id="7", + ), + Document( + content="Go is an object oriented programming language", + meta={"likes": 1000, "language_type": "object_oriented"}, + id="8", + ), + Document( + content="Python is a object oriented programming language", + meta={"likes": 100, "language_type": "object_oriented"}, + id="9", + ), + Document( + content="Ruby is a object oriented programming language", + meta={"likes": 10, "language_type": "object_oriented"}, + id="10", + ), + Document( + content="PHP is a object oriented programming language", + meta={"likes": 1, "language_type": "object_oriented"}, + id="11", + ), + ] + ) + res = await document_store._bm25_retrieval_async( + "programming", + top_k=10, + filters={"field": "language_type", "operator": "==", "value": "functional"}, + ) + assert len(res) == 5 + retrieved_ids = sorted([doc.id for doc in res]) + assert retrieved_ids == ["1", "2", "3", "4", "5"] + + @pytest.mark.asyncio + async def test_bm25_retrieval_with_custom_query( + self, document_store: OpenSearchDocumentStore + ): + document_store.write_documents( + [ + Document( + content="Haskell is a functional programming language", + meta={"likes": 100000, "language_type": "functional"}, + id="1", + ), + Document( + content="Lisp is a functional programming language", + meta={"likes": 10000, "language_type": "functional"}, + id="2", + ), + Document( + content="Exilir is a functional programming language", + meta={"likes": 1000, "language_type": "functional"}, + id="3", + ), + Document( + content="F# is a functional programming language", + meta={"likes": 100, "language_type": "functional"}, + id="4", + ), + Document( + content="C# is a functional programming language", + meta={"likes": 10, "language_type": "functional"}, + id="5", + ), + Document( + content="C++ is an object oriented programming language", + meta={"likes": 100000, "language_type": "object_oriented"}, + id="6", + ), + Document( + content="Dart is an object oriented programming language", + meta={"likes": 10000, "language_type": "object_oriented"}, + id="7", + ), + Document( + content="Go is an object oriented programming language", + meta={"likes": 1000, "language_type": "object_oriented"}, + id="8", + ), + Document( + content="Python is a object oriented programming language", + meta={"likes": 100, "language_type": "object_oriented"}, + id="9", + ), + Document( + 
content="Ruby is a object oriented programming language", + meta={"likes": 10, "language_type": "object_oriented"}, + id="10", + ), + Document( + content="PHP is a object oriented programming language", + meta={"likes": 1, "language_type": "object_oriented"}, + id="11", + ), + ] + ) + + custom_query = { + "query": { + "function_score": { + "query": { + "bool": { + "must": {"match": {"content": "$query"}}, + "filter": "$filters", + } + }, + "field_value_factor": { + "field": "likes", + "factor": 0.1, + "modifier": "log1p", + "missing": 0, + }, + } + } + } + res = await document_store._bm25_retrieval_async( + "functional", + top_k=3, + custom_query=custom_query, + filters={"field": "language_type", "operator": "==", "value": "functional"}, + ) + assert len(res) == 3 + assert "1" == res[0].id + assert "2" == res[1].id + assert "3" == res[2].id + + @pytest.mark.asyncio + async def test_embedding_retrieval( + self, document_store_embedding_dim_4: OpenSearchDocumentStore + ): + docs = [ + Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]), + Document( + content="Not very similar document", embedding=[0.0, 0.8, 0.3, 0.9] + ), + ] + document_store_embedding_dim_4.write_documents(docs) + + results = await document_store_embedding_dim_4._embedding_retrieval_async( + query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=2, filters={} + ) + assert len(results) == 2 + assert results[0].content == "Most similar document" + assert results[1].content == "2nd best document" + + @pytest.mark.asyncio + async def test_embedding_retrieval_with_filters( + self, document_store_embedding_dim_4: OpenSearchDocumentStore + ): + docs = [ + Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]), + Document( + content="Not very similar document with meta field", + embedding=[0.0, 0.8, 0.3, 0.9], + meta={"meta_field": "custom_value"}, + ), + ] + document_store_embedding_dim_4.write_documents(docs) + + filters = {"field": "meta_field", "operator": "==", "value": "custom_value"} + + results = await document_store_embedding_dim_4._embedding_retrieval_async( + query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=3, filters=filters + ) + assert len(results) == 1 + assert results[0].content == "Not very similar document with meta field" + + @pytest.mark.asyncio + async def test_embedding_retrieval_with_custom_query( + self, document_store_embedding_dim_4: OpenSearchDocumentStore + ): + docs = [ + Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]), + Document( + content="Not very similar document with meta field", + embedding=[0.0, 0.8, 0.3, 0.9], + meta={"meta_field": "custom_value"}, + ), + ] + document_store_embedding_dim_4.write_documents(docs) + + custom_query = { + "query": { + "bool": { + "must": [ + {"knn": {"embedding": {"vector": "$query_embedding", "k": 3}}} + ], + "filter": "$filters", + } + } + } + + filters = {"field": "meta_field", "operator": "==", "value": "custom_value"} + + results = await document_store_embedding_dim_4._embedding_retrieval_async( + query_embedding=[0.1, 0.1, 0.1, 0.1], + top_k=1, + filters=filters, + custom_query=custom_query, + ) + assert len(results) == 1 + assert results[0].content == "Not very similar document with meta field" + + @pytest.mark.asyncio + async def 
test_embedding_retrieval_but_dont_return_embeddings_for_embedding_retrieval( + self, document_store_no_embbding_returned: OpenSearchDocumentStore + ): + docs = [ + Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]), + Document( + content="Not very similar document", embedding=[0.0, 0.8, 0.3, 0.9] + ), + ] + document_store_no_embbding_returned.write_documents(docs) + + results = await document_store_no_embbding_returned._embedding_retrieval_async( + query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=2, filters={} + ) + assert len(results) == 2 + assert results[0].embedding is None + + @pytest.mark.asyncio + async def test_count_documents(self, document_store: OpenSearchDocumentStore): + document_store.write_documents( + [ + Document(content="test doc 1"), + Document(content="test doc 2"), + Document(content="test doc 3"), + ] + ) + assert await document_store.count_documents_async() == 3 + + @pytest.mark.asyncio + async def test_filter_documents(self, document_store: OpenSearchDocumentStore): + filterable_docs = [ + Document( + content=f"1", + meta={ + "number": -10, + }, + ), + Document( + content=f"2", + meta={ + "number": 100, + }, + ), + ] + await document_store.write_documents_async(filterable_docs) + result = await document_store.filter_documents_async( + filters={"field": "meta.number", "operator": "==", "value": 100} + ) + TestDocumentStore().assert_documents_are_equal( + result, [d for d in filterable_docs if d.meta.get("number") == 100] + ) + + @pytest.mark.asyncio + async def test_delete_documents(self, document_store: OpenSearchDocumentStore): + doc = Document(content="test doc") + await document_store.write_documents_async([doc]) + assert document_store.count_documents() == 1 + + await document_store.delete_documents_async([doc.id]) + assert await document_store.count_documents_async() == 0 diff --git a/test/document_stores/test_in_memory.py b/test/document_stores/test_in_memory.py new file mode 100644 index 00000000..d1c07ef4 --- /dev/null +++ b/test/document_stores/test_in_memory.py @@ -0,0 +1,163 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import patch + +import pytest +import tempfile + +from haystack import Document +from haystack.document_stores.errors import DuplicateDocumentError +from haystack_experimental.document_stores.in_memory import InMemoryDocumentStore +from haystack.testing.document_store import DocumentStoreBaseTests + + +class TestMemoryDocumentStoreAsync(DocumentStoreBaseTests): # pylint: disable=R0904 + """ + Test InMemoryDocumentStore's specific features + """ + + @pytest.fixture + def tmp_dir(self): + with tempfile.TemporaryDirectory() as tmp_dir: + yield tmp_dir + + @pytest.fixture + def document_store(self) -> InMemoryDocumentStore: + return InMemoryDocumentStore(bm25_algorithm="BM25L") + + def test_to_dict(self): + store = InMemoryDocumentStore() + data = store.to_dict() + assert data == { + "type": "haystack_experimental.document_stores.in_memory.document_store.InMemoryDocumentStore", + "init_parameters": { + "bm25_tokenization_regex": r"(?u)\b\w\w+\b", + "bm25_algorithm": "BM25L", + "bm25_parameters": {}, + "embedding_similarity_function": "dot_product", + "index": store.index, + }, + } + + def test_to_dict_with_custom_init_parameters(self): + store = InMemoryDocumentStore( + bm25_tokenization_regex="custom_regex", + bm25_algorithm="BM25Plus", + bm25_parameters={"key": "value"}, + 
embedding_similarity_function="cosine", + index="my_cool_index", + ) + data = store.to_dict() + assert data == { + "type": "haystack_experimental.document_stores.in_memory.document_store.InMemoryDocumentStore", + "init_parameters": { + "bm25_tokenization_regex": "custom_regex", + "bm25_algorithm": "BM25Plus", + "bm25_parameters": {"key": "value"}, + "embedding_similarity_function": "cosine", + "index": "my_cool_index", + }, + } + + @patch("haystack.document_stores.in_memory.document_store.re") + def test_from_dict(self, mock_regex): + data = { + "type": "haystack_experimental.document_stores.in_memory.document_store.InMemoryDocumentStore", + "init_parameters": { + "bm25_tokenization_regex": "custom_regex", + "bm25_algorithm": "BM25Plus", + "bm25_parameters": {"key": "value"}, + "index": "my_cool_index", + }, + } + store = InMemoryDocumentStore.from_dict(data) + mock_regex.compile.assert_called_with("custom_regex") + assert store.tokenizer + assert store.bm25_algorithm == "BM25Plus" + assert store.bm25_parameters == {"key": "value"} + assert store.index == "my_cool_index" + + @pytest.mark.asyncio + async def test_write_documents(self, document_store: InMemoryDocumentStore): + docs = [Document(id="1")] + assert await document_store.write_documents_async(docs) == 1 + with pytest.raises(DuplicateDocumentError): + await document_store.write_documents_async(docs) + + @pytest.mark.asyncio + async def test_count_documents(self, document_store: InMemoryDocumentStore): + await document_store.write_documents_async( + [ + Document(content="test doc 1"), + Document(content="test doc 2"), + Document(content="test doc 3"), + ] + ) + assert await document_store.count_documents_async() == 3 + + @pytest.mark.asyncio + async def test_filter_documents(self, document_store: InMemoryDocumentStore): + filterable_docs = [ + Document( + content=f"1", + meta={ + "number": -10, + }, + ), + Document( + content=f"2", + meta={ + "number": 100, + }, + ), + ] + await document_store.write_documents_async(filterable_docs) + result = await document_store.filter_documents_async( + filters={"field": "meta.number", "operator": "==", "value": 100} + ) + DocumentStoreBaseTests().assert_documents_are_equal( + result, [d for d in filterable_docs if d.meta.get("number") == 100] + ) + + @pytest.mark.asyncio + async def test_delete_documents(self, document_store: InMemoryDocumentStore): + doc = Document(content="test doc") + await document_store.write_documents_async([doc]) + assert document_store.count_documents() == 1 + + await document_store.delete_documents_async([doc.id]) + assert await document_store.count_documents_async() == 0 + + @pytest.mark.asyncio + async def test_bm25_retrieval(self, document_store: InMemoryDocumentStore): + # Tests if the bm25_retrieval method returns the correct document based on the input query. + docs = [ + Document(content="Hello world"), + Document(content="Haystack supports multiple languages"), + ] + await document_store.write_documents_async(docs) + results = await document_store.bm25_retrieval_async( + query="What languages?", top_k=1 + ) + assert len(results) == 1 + assert results[0].content == "Haystack supports multiple languages" + + @pytest.mark.asyncio + async def test_embedding_retrieval(self): + docstore = InMemoryDocumentStore(embedding_similarity_function="cosine") + # Tests if the embedding retrieval method returns the correct document based on the input query embedding. 
+ docs = [ + Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]), + Document( + content="Haystack supports multiple languages", + embedding=[1.0, 1.0, 1.0, 1.0], + ), + ] + await docstore.write_documents_async(docs) + results = await docstore.embedding_retrieval_async( + query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=1, filters={}, scale_score=False + ) + assert len(results) == 1 + assert results[0].content == "Haystack supports multiple languages"
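
The tests above exercise the OpenSearchDocumentStore retrieval options end to end: `fuzziness` for typo-tolerant keyword search, the standard comparison-filter syntax, and `custom_query` templates in which the store substitutes the `$query`, `$filters`, and `$query_embedding` placeholders at search time. The following is a minimal editor sketch (not part of the diff) of how these options fit together outside the test suite. The package-level import path, the local OpenSearch URL, the index name, and the admin/admin credentials are assumptions mirroring the test fixtures, and `_bm25_retrieval`/`_embedding_retrieval` are the store-internal methods the tests call directly.

# Sketch only: assumes a local OpenSearch reachable at http://localhost:9200
# with the same admin/admin credentials used by the test fixtures above.
from haystack import Document
from haystack_experimental.document_stores.opensearch import OpenSearchDocumentStore

store = OpenSearchDocumentStore(
    hosts=["http://localhost:9200"],
    index="custom_query_demo",  # hypothetical index name
    http_auth=("admin", "admin"),
    verify_certs=False,
    embedding_dim=4,
)
store.write_documents(
    [
        Document(
            content="Haskell is a functional programming language",
            meta={"likes": 100000, "language_type": "functional"},
            embedding=[1.0, 1.0, 1.0, 1.0],
        ),
        Document(
            content="C++ is an object oriented programming language",
            meta={"likes": 100000, "language_type": "object_oriented"},
            embedding=[0.0, 0.8, 0.3, 0.9],
        ),
    ]
)

# Typo-tolerant keyword search: fuzziness="1" lets "functinal" match "functional".
docs = store._bm25_retrieval("functinal", top_k=3, fuzziness="1")

# custom_query template: the store replaces "$query" with the query string and
# "$filters" with the converted comparison filter at search time.
bm25_custom_query = {
    "query": {
        "function_score": {
            "query": {
                "bool": {
                    "must": {"match": {"content": "$query"}},
                    "filter": "$filters",
                }
            },
            "field_value_factor": {
                "field": "likes",
                "factor": 0.1,
                "modifier": "log1p",
                "missing": 0,
            },
        }
    }
}
docs = store._bm25_retrieval(
    "functional",
    top_k=3,
    custom_query=bm25_custom_query,
    filters={"field": "language_type", "operator": "==", "value": "functional"},
)

# For embedding retrieval, "$query_embedding" is replaced with the query vector.
knn_custom_query = {
    "query": {
        "bool": {
            "must": [{"knn": {"embedding": {"vector": "$query_embedding", "k": 3}}}],
            "filter": "$filters",
        }
    }
}
docs = store._embedding_retrieval(
    query_embedding=[0.1, 0.1, 0.1, 0.1],
    top_k=3,
    filters={"field": "language_type", "operator": "==", "value": "functional"},
    custom_query=knn_custom_query,
)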
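
The async test classes above (TestDocumentStoreAsync and TestMemoryDocumentStoreAsync) rely on the *_async counterparts of the usual document-store methods: write_documents_async, count_documents_async, filter_documents_async, delete_documents_async, and the async BM25/embedding retrieval calls. The sketch below shows these calls used end to end outside pytest; it is an editor illustration based only on the signatures visible in the tests, using the experimental InMemoryDocumentStore so no OpenSearch instance is needed.

# Sketch only: exercises the async API surface shown in the tests above.
import asyncio

from haystack import Document
from haystack_experimental.document_stores.in_memory import InMemoryDocumentStore


async def main() -> None:
    store = InMemoryDocumentStore(embedding_similarity_function="cosine")

    # write_documents_async returns the number of documents written.
    written = await store.write_documents_async(
        [
            Document(
                content="Hello world",
                embedding=[0.1, 0.2, 0.3, 0.4],
                meta={"topic": "greeting"},
            ),
            Document(
                content="Haystack supports multiple languages",
                embedding=[1.0, 1.0, 1.0, 1.0],
                meta={"topic": "features"},
            ),
        ]
    )
    assert written == 2
    assert await store.count_documents_async() == 2

    # Keyword and dense retrieval both have async variants.
    bm25_hits = await store.bm25_retrieval_async(query="What languages?", top_k=1)
    dense_hits = await store.embedding_retrieval_async(
        query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=1, filters={}, scale_score=False
    )
    print(bm25_hits[0].content, dense_hits[0].content)

    # Filtering and deletion follow the same comparison-filter pattern as the tests.
    matches = await store.filter_documents_async(
        filters={"field": "meta.topic", "operator": "==", "value": "greeting"}
    )
    await store.delete_documents_async([doc.id for doc in matches])
    assert await store.count_documents_async() == 1


asyncio.run(main())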