diff --git a/integrations/amazon_bedrock/examples/chatgenerator_example.py b/integrations/amazon_bedrock/examples/chatgenerator_example.py
index 4617a81fc..b67be9f89 100755
--- a/integrations/amazon_bedrock/examples/chatgenerator_example.py
+++ b/integrations/amazon_bedrock/examples/chatgenerator_example.py
@@ -6,9 +6,7 @@
 
 from haystack.dataclasses import ChatMessage
 
-from haystack_integrations.components.generators.amazon_bedrock import (
-    AmazonBedrockChatGenerator,
-)
+from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator
 
 generator = AmazonBedrockChatGenerator(
     model="anthropic.claude-3-haiku-20240307-v1:0",
@@ -31,9 +29,7 @@
 # which allows for more portability of code across generators
 messages = [
     ChatMessage.from_system(system_prompt),
-    ChatMessage.from_user(
-        "Which service should I use to train custom Machine Learning models?"
-    ),
+    ChatMessage.from_user("Which service should I use to train custom Machine Learning models?"),
 ]
 
 results = generator.run(messages)
diff --git a/integrations/amazon_bedrock/examples/embedders_generator_with_rag_example.py b/integrations/amazon_bedrock/examples/embedders_generator_with_rag_example.py
index 8686331c8..39f7ee4c5 100644
--- a/integrations/amazon_bedrock/examples/embedders_generator_with_rag_example.py
+++ b/integrations/amazon_bedrock/examples/embedders_generator_with_rag_example.py
@@ -12,9 +12,7 @@
     AmazonBedrockDocumentEmbedder,
     AmazonBedrockTextEmbedder,
 )
-from haystack_integrations.components.generators.amazon_bedrock import (
-    AmazonBedrockGenerator,
-)
+from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator
 
 generator_model_name = "amazon.titan-text-lite-v1"
 embedder_model_name = "amazon.titan-embed-text-v1"
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/common/amazon_bedrock/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/common/amazon_bedrock/utils.py
index 5a4e8e13b..3148818c1 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/common/amazon_bedrock/utils.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/common/amazon_bedrock/utils.py
@@ -45,9 +45,7 @@ def get_aws_session(
             profile_name=aws_profile_name,
         )
     except BotoCoreError as e:
-        provided_aws_config = {
-            k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS
-        }
+        provided_aws_config = {k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS}
         msg = f"Failed to initialize the session with provided AWS credentials {provided_aws_config}"
         raise AWSConfigurationError(msg) from e
 
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py
index 804f5cccd..1b8fde124 100755
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py
@@ -62,21 +62,13 @@ def __init__(
             "cohere.embed-multilingual-v3",
             "amazon.titan-embed-text-v2:0",
         ],
-        aws_access_key_id: Optional[Secret] = Secret.from_env_var(
-            "AWS_ACCESS_KEY_ID", strict=False
-        ),  # noqa: B008
+        aws_access_key_id: Optional[Secret] = Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False),  # noqa: B008
         aws_secret_access_key: Optional[Secret] = Secret.from_env_var(  # noqa: B008
             "AWS_SECRET_ACCESS_KEY", strict=False
         ),
-        aws_session_token: Optional[Secret] = Secret.from_env_var(
-            "AWS_SESSION_TOKEN", strict=False
-        ),  # noqa: B008
-        aws_region_name: Optional[Secret] = Secret.from_env_var(
-            "AWS_DEFAULT_REGION", strict=False
-        ),  # noqa: B008
-        aws_profile_name: Optional[Secret] = Secret.from_env_var(
-            "AWS_PROFILE", strict=False
-        ),  # noqa: B008
+        aws_session_token: Optional[Secret] = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False),  # noqa: B008
+        aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False),  # noqa: B008
+        aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False),  # noqa: B008
         batch_size: int = 32,
         progress_bar: bool = True,
         meta_fields_to_embed: Optional[List[str]] = None,
@@ -113,9 +105,8 @@ def __init__(
         """
 
         if not model or model not in SUPPORTED_EMBEDDING_MODELS:
-            msg = (
-                "Please provide a valid model from the list of supported models: "
-                + ", ".join(SUPPORTED_EMBEDDING_MODELS)
+            msg = "Please provide a valid model from the list of supported models: " + ", ".join(
+                SUPPORTED_EMBEDDING_MODELS
             )
             raise ValueError(msg)
 
@@ -156,15 +147,9 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
         """
         texts_to_embed = []
         for doc in documents:
-            meta_values_to_embed = [
-                str(doc.meta[key])
-                for key in self.meta_fields_to_embed
-                if doc.meta.get(key)
-            ]
-
-            text_to_embed = self.embedding_separator.join(
-                [*meta_values_to_embed, doc.content or ""]
-            )
+            meta_values_to_embed = [str(doc.meta[key]) for key in self.meta_fields_to_embed if doc.meta.get(key)]
+
+            text_to_embed = self.embedding_separator.join([*meta_values_to_embed, doc.content or ""])
 
             texts_to_embed.append(text_to_embed)
         return texts_to_embed
@@ -178,28 +163,21 @@ def _embed_cohere(self, documents: List[Document]) -> List[Document]:
         texts_to_embed = self._prepare_texts_to_embed(documents=documents)
 
         cohere_body = {
-            "input_type": self.kwargs.get(
-                "input_type", "search_document"
-            ),  # mandatory parameter for Cohere models
+            "input_type": self.kwargs.get("input_type", "search_document"),  # mandatory parameter for Cohere models
         }
         if truncate := self.kwargs.get("truncate"):
             cohere_body["truncate"] = truncate  # optional parameter for Cohere models
 
         all_embeddings = []
         for i in tqdm(
-            range(0, len(texts_to_embed), self.batch_size),
-            disable=not self.progress_bar,
-            desc="Creating embeddings",
+            range(0, len(texts_to_embed), self.batch_size), disable=not self.progress_bar, desc="Creating embeddings"
         ):
             batch = texts_to_embed[i : i + self.batch_size]
             body = {"texts": batch, **cohere_body}
 
             try:
                 response = self._client.invoke_model(
-                    body=json.dumps(body),
-                    modelId=self.model,
-                    accept="*/*",
-                    contentType="application/json",
+                    body=json.dumps(body), modelId=self.model, accept="*/*", contentType="application/json"
                 )
             except ClientError as exception:
                 msg = (
@@ -226,16 +204,11 @@ def _embed_titan(self, documents: List[Document]) -> List[Document]:
         texts_to_embed = self._prepare_texts_to_embed(documents=documents)
 
         all_embeddings = []
-        for text in tqdm(
-            texts_to_embed, disable=not self.progress_bar, desc="Creating embeddings"
-        ):
+        for text in tqdm(texts_to_embed, disable=not self.progress_bar, desc="Creating embeddings"):
            body = {"inputText": text}
             try:
                 response = self._client.invoke_model(
-                    body=json.dumps(body),
-                    modelId=self.model,
-                    accept="*/*",
-                    contentType="application/json",
+                    body=json.dumps(body), modelId=self.model, accept="*/*", contentType="application/json"
                 )
             except ClientError as exception:
                 msg = (
@@ -263,11 +236,7 @@ def run(self, documents: List[Document]):
         - `documents`: The `Document`s with the `embedding` field populated.
         :raises AmazonBedrockInferenceError: If the inference fails.
         """
-        if (
-            not isinstance(documents, list)
-            or documents
-            and not isinstance(documents[0], Document)
-        ):
+        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
             msg = (
                 "AmazonBedrockDocumentEmbedder expects a list of Documents as input. "
                 "In case you want to embed a string, please use the AmazonBedrockTextEmbedder."
             )
@@ -290,21 +259,11 @@ def to_dict(self) -> Dict[str, Any]:
         """
         return default_to_dict(
             self,
-            aws_access_key_id=self.aws_access_key_id.to_dict()
-            if self.aws_access_key_id
-            else None,
-            aws_secret_access_key=self.aws_secret_access_key.to_dict()
-            if self.aws_secret_access_key
-            else None,
-            aws_session_token=self.aws_session_token.to_dict()
-            if self.aws_session_token
-            else None,
-            aws_region_name=self.aws_region_name.to_dict()
-            if self.aws_region_name
-            else None,
-            aws_profile_name=self.aws_profile_name.to_dict()
-            if self.aws_profile_name
-            else None,
+            aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None,
+            aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None,
+            aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None,
+            aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None,
+            aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
             model=self.model,
             batch_size=self.batch_size,
             progress_bar=self.progress_bar,
@@ -325,12 +284,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockDocumentEmbedder":
         """
         deserialize_secrets_inplace(
             data["init_parameters"],
-            [
-                "aws_access_key_id",
-                "aws_secret_access_key",
-                "aws_session_token",
-                "aws_region_name",
-                "aws_profile_name",
-            ],
+            ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"],
         )
         return default_from_dict(cls, data)
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py
index 5de7e33aa..0cceda92f 100755
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py
@@ -55,21 +55,13 @@ def __init__(
             "cohere.embed-multilingual-v3",
             "amazon.titan-embed-text-v2:0",
         ],
-        aws_access_key_id: Optional[Secret] = Secret.from_env_var(
-            "AWS_ACCESS_KEY_ID", strict=False
-        ),  # noqa: B008
+        aws_access_key_id: Optional[Secret] = Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False),  # noqa: B008
         aws_secret_access_key: Optional[Secret] = Secret.from_env_var(  # noqa: B008
             "AWS_SECRET_ACCESS_KEY", strict=False
         ),
-        aws_session_token: Optional[Secret] = Secret.from_env_var(
-            "AWS_SESSION_TOKEN", strict=False
-        ),  # noqa: B008
-        aws_region_name: Optional[Secret] = Secret.from_env_var(
-            "AWS_DEFAULT_REGION", strict=False
-        ),  # noqa: B008
-        aws_profile_name: Optional[Secret] = Secret.from_env_var(
-            "AWS_PROFILE", strict=False
-        ),  # noqa: B008
+        aws_session_token: Optional[Secret] = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False),  # noqa: B008
+        aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False),  # noqa: B008
+        aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False),  # noqa: B008
         **kwargs,
     ):
         """
@@ -95,9 +87,8 @@ def __init__(
         :raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly.
         """
         if not model or model not in SUPPORTED_EMBEDDING_MODELS:
-            msg = (
-                "Please provide a valid model from the list of supported models: "
-                + ", ".join(SUPPORTED_EMBEDDING_MODELS)
+            msg = "Please provide a valid model from the list of supported models: " + ", ".join(
+                SUPPORTED_EMBEDDING_MODELS
             )
             raise ValueError(msg)
 
@@ -148,9 +139,7 @@ def run(self, text: str):
         if "cohere" in self.model:
             body = {
                 "texts": [text],
-                "input_type": self.kwargs.get(
-                    "input_type", "search_query"
-                ),  # mandatory parameter for Cohere models
+                "input_type": self.kwargs.get("input_type", "search_query"),  # mandatory parameter for Cohere models
             }
             if truncate := self.kwargs.get("truncate"):
                 body["truncate"] = truncate  # optional parameter for Cohere models
@@ -162,10 +151,7 @@ def run(self, text: str):
 
         try:
             response = self._client.invoke_model(
-                body=json.dumps(body),
-                modelId=self.model,
-                accept="*/*",
-                contentType="application/json",
+                body=json.dumps(body), modelId=self.model, accept="*/*", contentType="application/json"
             )
         except ClientError as exception:
             msg = (
@@ -193,21 +179,11 @@ def to_dict(self) -> Dict[str, Any]:
         """
         return default_to_dict(
             self,
-            aws_access_key_id=self.aws_access_key_id.to_dict()
-            if self.aws_access_key_id
-            else None,
-            aws_secret_access_key=self.aws_secret_access_key.to_dict()
-            if self.aws_secret_access_key
-            else None,
-            aws_session_token=self.aws_session_token.to_dict()
-            if self.aws_session_token
-            else None,
-            aws_region_name=self.aws_region_name.to_dict()
-            if self.aws_region_name
-            else None,
-            aws_profile_name=self.aws_profile_name.to_dict()
-            if self.aws_profile_name
-            else None,
+            aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None,
+            aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None,
+            aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None,
+            aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None,
+            aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
             model=self.model,
             **self.kwargs,
         )
@@ -224,12 +200,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockTextEmbedder":
         """
         deserialize_secrets_inplace(
             data["init_parameters"],
-            [
-                "aws_access_key_id",
-                "aws_secret_access_key",
-                "aws_session_token",
-                "aws_region_name",
-                "aws_profile_name",
-            ],
+            ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"],
         )
         return default_from_dict(cls, data)
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py
index 162a814c4..7c7fdd7ce 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py
@@ -39,9 +39,7 @@ def get_responses(self, response_body: Dict[str, Any]) -> List[str]:
         responses = [completion.lstrip() for completion in completions]
         return responses
 
-    def get_stream_responses(
-        self, stream, stream_handler: TokenStreamingHandler
-    ) -> List[str]:
+    def get_stream_responses(self, stream, stream_handler: TokenStreamingHandler) -> List[str]:
         """
         Extracts the responses from the Amazon Bedrock streaming response.
 
@@ -59,9 +57,7 @@ def get_stream_responses(
         responses = ["".join(tokens).lstrip()]
         return responses
 
-    def _get_params(
-        self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any]) -> Dict[str, Any]:
         """
         Merges the default params with the inference kwargs and model kwargs.
 
@@ -79,9 +75,7 @@ def _get_params(
         }
 
     @abstractmethod
-    def _extract_completions_from_response(
-        self, response_body: Dict[str, Any]
-    ) -> List[str]:
+    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
         """
         Extracts the responses from the Amazon Bedrock response.
 
@@ -144,9 +138,7 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
         body = {"prompt": f"\n\nHuman: {prompt}\n\nAssistant:", **params}
         return body
 
-    def _extract_completions_from_response(
-        self, response_body: Dict[str, Any]
-    ) -> List[str]:
+    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
         """
         Extracts the responses from the Amazon Bedrock response.
 
@@ -195,14 +187,10 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
         }
         params = self._get_params(inference_kwargs, default_params)
         # Add the instruction tag to the prompt if it's not already there
-        formatted_prompt = (
-            f"[INST] {prompt} [/INST]" if "INST" not in prompt else prompt
-        )
+        formatted_prompt = f"[INST] {prompt} [/INST]" if "INST" not in prompt else prompt
         return {"prompt": formatted_prompt, **params}
 
-    def _extract_completions_from_response(
-        self, response_body: Dict[str, Any]
-    ) -> List[str]:
+    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
         """
         Extracts the responses from the Amazon Bedrock response.
 
@@ -256,9 +244,7 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
         body = {"prompt": prompt, **params}
         return body
 
-    def _extract_completions_from_response(
-        self, response_body: Dict[str, Any]
-    ) -> List[str]:
+    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
         """
         Extracts the responses from the Cohere Command model response.
 
@@ -317,9 +303,7 @@ def prepare_body(self, prompt: str, **inference_kwargs: Any) -> Dict[str, Any]:
         body = {"message": prompt, **params}
         return body
 
-    def _extract_completions_from_response(
-        self, response_body: Dict[str, Any]
-    ) -> List[str]:
+    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
         """
         Extracts the responses from the Cohere Command model response.
 
@@ -369,12 +353,8 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
         body = {"prompt": prompt, **params}
         return body
 
-    def _extract_completions_from_response(
-        self, response_body: Dict[str, Any]
-    ) -> List[str]:
-        responses = [
-            completion["data"]["text"] for completion in response_body["completions"]
-        ]
+    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
+        responses = [completion["data"]["text"] for completion in response_body["completions"]]
         return responses
 
     def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
@@ -408,9 +388,7 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
         body = {"inputText": prompt, "textGenerationConfig": params}
         return body
 
-    def _extract_completions_from_response(
-        self, response_body: Dict[str, Any]
-    ) -> List[str]:
+    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
         """
         Extracts the responses from the Titan model response.
 
@@ -455,9 +433,7 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
         body = {"prompt": prompt, **params}
         return body
 
-    def _extract_completions_from_response(
-        self, response_body: Dict[str, Any]
-    ) -> List[str]:
+    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
         """
         Extracts the responses from the Llama2 model response.
 
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
index fe3897bea..162100934 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
@@ -8,9 +8,7 @@
 from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
 from transformers import AutoTokenizer, PreTrainedTokenizer
 
-from haystack_integrations.components.generators.amazon_bedrock.handlers import (
-    DefaultPromptHandler,
-)
+from haystack_integrations.components.generators.amazon_bedrock.handlers import DefaultPromptHandler
 
 logger = logging.getLogger(__name__)
 
@@ -30,9 +28,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
         self.generation_kwargs = generation_kwargs
 
     @abstractmethod
-    def prepare_body(
-        self, messages: List[ChatMessage], **inference_kwargs
-    ) -> Dict[str, Any]:
+    def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
         """
         Prepares the body for the Amazon Bedrock request.
         Subclasses should override this method to package the chat messages into the request.
@@ -61,25 +57,14 @@ def get_stream_responses(
             if chunk:
                 last_decoded_chunk = json.loads(chunk["bytes"].decode("utf-8"))
                 token = self._extract_token_from_stream(last_decoded_chunk)
-                stream_chunk = StreamingChunk(
-                    content=token
-                )  # don't extract meta, we care about tokens only
-                stream_handler(
-                    stream_chunk
-                )  # callback the stream handler with StreamingChunk
+                stream_chunk = StreamingChunk(content=token)  # don't extract meta, we care about tokens only
+                stream_handler(stream_chunk)  # callback the stream handler with StreamingChunk
                 tokens.append(token)
         responses = ["".join(tokens).lstrip()]
-        return [
-            ChatMessage.from_assistant(response, meta=last_decoded_chunk)
-            for response in responses
-        ]
+        return [ChatMessage.from_assistant(response, meta=last_decoded_chunk) for response in responses]
 
     @staticmethod
-    def _update_params(
-        target_dict: Dict[str, Any],
-        updates_dict: Dict[str, Any],
-        allowed_params: List[str],
-    ) -> None:
+    def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any], allowed_params: List[str]) -> None:
         """
         Updates target_dict with values from updates_dict. Merges lists instead of overriding them.
 
@@ -91,11 +76,7 @@ def _update_params(
             if key not in allowed_params:
                 logger.warning(f"Parameter '{key}' is not allowed and will be ignored.")
                 continue
-            if (
-                key in target_dict
-                and isinstance(target_dict[key], list)
-                and isinstance(value, list)
-            ):
+            if key in target_dict and isinstance(target_dict[key], list) and isinstance(value, list):
                 # Merge lists and remove duplicates
                 target_dict[key] = sorted(set(target_dict[key] + value))
             else:
@@ -103,10 +84,7 @@ def _update_params(
             target_dict[key] = value
 
     def _get_params(
-        self,
-        inference_kwargs: Dict[str, Any],
-        default_params: Dict[str, Any],
-        allowed_params: List[str],
+        self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any], allowed_params: List[str]
     ) -> Dict[str, Any]:
         """
         Merges params from inference_kwargs with the default params and self.generation_kwargs.
@@ -155,9 +133,7 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
         """
 
     @abstractmethod
-    def _extract_messages_from_response(
-        self, response_body: Dict[str, Any]
-    ) -> List[ChatMessage]:
+    def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
         """
         Extracts the messages from the response body.
 
@@ -215,9 +191,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]):
             max_length=self.generation_kwargs.get("max_tokens") or 512,
         )
 
-    def prepare_body(
-        self, messages: List[ChatMessage], **inference_kwargs
-    ) -> Dict[str, Any]:
+    def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
         """
         Prepares the body for the Anthropic Claude request.
 
@@ -226,16 +200,12 @@ def prepare_body(
         :returns: The prepared body.
         """
         default_params = {
-            "anthropic_version": self.generation_kwargs.get("anthropic_version")
-            or "bedrock-2023-05-31",
-            "max_tokens": self.generation_kwargs.get("max_tokens")
-            or 512,  # max_tokens is required
+            "anthropic_version": self.generation_kwargs.get("anthropic_version") or "bedrock-2023-05-31",
+            "max_tokens": self.generation_kwargs.get("max_tokens") or 512,  # max_tokens is required
         }
 
         # combine stop words with default stop sequences, remove stop_words as Anthropic does not support it
-        stop_sequences = inference_kwargs.get(
-            "stop_sequences", []
-        ) + inference_kwargs.pop("stop_words", [])
+        stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", [])
         if stop_sequences:
             inference_kwargs["stop_sequences"] = stop_sequences
         params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS)
@@ -250,15 +220,9 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]:
         :returns: The prepared chat messages as a string.
         """
         body: Dict[str, Any] = {}
-        system = (
-            messages[0].content
-            if messages and messages[0].is_from(ChatRole.SYSTEM)
-            else None
-        )
+        system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None
         body["messages"] = [
-            self._to_anthropic_message(m)
-            for m in messages
-            if m.is_from(ChatRole.USER) or m.is_from(ChatRole.ASSISTANT)
+            self._to_anthropic_message(m) for m in messages if m.is_from(ChatRole.USER) or m.is_from(ChatRole.ASSISTANT)
         ]
         if system:
             body["system"] = system
@@ -273,9 +237,7 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
         """
         return self.prompt_handler(prompt)
 
-    def _extract_messages_from_response(
-        self, response_body: Dict[str, Any]
-    ) -> List[ChatMessage]:
+    def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
         """
         Extracts the messages from the response body.
 
@@ -286,14 +248,8 @@ def _extract_messages_from_response(
         if response_body.get("type") == "message":
             for content in response_body["content"]:
                 if content.get("type") == "text":
-                    meta = {
-                        k: v
-                        for k, v in response_body.items()
-                        if k not in ["type", "content", "role"]
-                    }
-                    messages.append(
-                        ChatMessage.from_assistant(content["text"], meta=meta)
-                    )
+                    meta = {k: v for k, v in response_body.items() if k not in ["type", "content", "role"]}
+                    messages.append(ChatMessage.from_assistant(content["text"], meta=meta))
         return messages
 
     def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
@@ -303,10 +259,7 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
         :param chunk: The streaming chunk.
         :returns: The extracted token.
         """
-        if (
-            chunk.get("type") == "content_block_delta"
-            and chunk.get("delta", {}).get("type") == "text_delta"
-        ):
+        if chunk.get("type") == "content_block_delta" and chunk.get("delta", {}).get("type") == "text_delta":
             return chunk.get("delta", {}).get("text", "")
         return ""
 
@@ -383,9 +336,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]):
         # Mistral models are gated on HF Hub. If no HF_TOKEN is found we use a non-gated alternative tokenizer model.
         tokenizer: PreTrainedTokenizer
         if os.environ.get("HF_TOKEN"):
-            tokenizer = AutoTokenizer.from_pretrained(
-                "mistralai/Mistral-7B-Instruct-v0.1"
-            )
+            tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
         else:
             tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
             logger.warning(
@@ -401,9 +352,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]):
             max_length=self.generation_kwargs.get("max_tokens") or 512,
         )
 
-    def prepare_body(
-        self, messages: List[ChatMessage], **inference_kwargs
-    ) -> Dict[str, Any]:
+    def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
         """
         Prepares the body for the Mistral request.
 
@@ -412,8 +361,7 @@ def prepare_body(
         :returns: The prepared body.
         """
         default_params = {
-            "max_tokens": self.generation_kwargs.get("max_tokens")
-            or 512,  # max_tokens is required
+            "max_tokens": self.generation_kwargs.get("max_tokens") or 512,  # max_tokens is required
         }
         # replace stop_words from inference_kwargs with stop, as this is Mistral specific parameter
         stop_words = inference_kwargs.pop("stop_words", [])
@@ -435,9 +383,7 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
         # default is https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
         # but we'll use our custom chat template
         prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template(
-            conversation=[self.to_openai_format(m) for m in messages],
-            tokenize=False,
-            chat_template=self.chat_template,
+            conversation=[self.to_openai_format(m) for m in messages], tokenize=False, chat_template=self.chat_template
         )
         return self._ensure_token_limit(prepared_prompt)
 
@@ -465,9 +411,7 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
         """
         return self.prompt_handler(prompt)
 
-    def _extract_messages_from_response(
-        self, response_body: Dict[str, Any]
-    ) -> List[ChatMessage]:
+    def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
         """
         Extracts the messages from the response body.
 
@@ -542,9 +486,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
         # Use `google/flan-t5-base` as it's also BPE sentencepiece tokenizer just like llama 2
         # a) we should get good estimates for the prompt length (empirically close to llama 2)
         # b) we can use apply_chat_template with the template above to delineate ChatMessages
-        tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
-            "google/flan-t5-base"
-        )
+        tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
         tokenizer.bos_token = "<s>"
         tokenizer.eos_token = "</s>"
         tokenizer.unk_token = "<unk>"
@@ -554,18 +496,14 @@ def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
             max_length=self.generation_kwargs.get("max_gen_len") or 512,
         )
 
-    def prepare_body(
-        self, messages: List[ChatMessage], **inference_kwargs
-    ) -> Dict[str, Any]:
+    def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
         """
         Prepares the body for the Meta Llama 2 request.
 
         :param messages: The chat messages to package into the request.
         :param inference_kwargs: Additional inference kwargs to use.
         """
-        default_params = {
-            "max_gen_len": self.generation_kwargs.get("max_gen_len") or 512
-        }
+        default_params = {"max_gen_len": self.generation_kwargs.get("max_gen_len") or 512}
 
         # no support for stop words in Meta Llama 2
         params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS)
@@ -594,9 +532,7 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
         """
         return self.prompt_handler(prompt)
 
-    def _extract_messages_from_response(
-        self, response_body: Dict[str, Any]
-    ) -> List[ChatMessage]:
+    def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
         """
         Extracts the messages from the response body.
 
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
index f35258cd6..73f70eee2 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
@@ -8,10 +8,7 @@
 from haystack import component, default_from_dict, default_to_dict
 from haystack.dataclasses import ChatMessage, StreamingChunk
 from haystack.utils.auth import Secret, deserialize_secrets_inplace
-from haystack.utils.callable_serialization import (
-    deserialize_callable,
-    serialize_callable,
-)
+from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
 
 from haystack_integrations.common.amazon_bedrock.errors import (
     AmazonBedrockConfigurationError,
@@ -19,12 +16,7 @@
 )
 from haystack_integrations.common.amazon_bedrock.utils import get_aws_session
 
-from .adapters import (
-    AnthropicClaudeChatAdapter,
-    BedrockModelChatAdapter,
-    MetaLlama2ChatAdapter,
-    MistralChatAdapter,
-)
+from .adapters import AnthropicClaudeChatAdapter, BedrockModelChatAdapter, MetaLlama2ChatAdapter, MistralChatAdapter
 
 logger = logging.getLogger(__name__)
 
@@ -65,21 +57,13 @@ class AmazonBedrockChatGenerator:
     def __init__(
         self,
         model: str,
-        aws_access_key_id: Optional[Secret] = Secret.from_env_var(
-            ["AWS_ACCESS_KEY_ID"], strict=False
-        ),  # noqa: B008
+        aws_access_key_id: Optional[Secret] = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False),  # noqa: B008
         aws_secret_access_key: Optional[Secret] = Secret.from_env_var(  # noqa: B008
             ["AWS_SECRET_ACCESS_KEY"], strict=False
         ),
-        aws_session_token: Optional[Secret] = Secret.from_env_var(
-            ["AWS_SESSION_TOKEN"], strict=False
-        ),  # noqa: B008
-        aws_region_name: Optional[Secret] = Secret.from_env_var(
-            ["AWS_DEFAULT_REGION"], strict=False
-        ),  # noqa: B008
-        aws_profile_name: Optional[Secret] = Secret.from_env_var(
-            ["AWS_PROFILE"], strict=False
-        ),  # noqa: B008
+        aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False),  # noqa: B008
+        aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False),  # noqa: B008
+        aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False),  # noqa: B008
         generation_kwargs: Optional[Dict[str, Any]] = None,
         stop_words: Optional[List[str]] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
@@ -179,16 +163,11 @@ def invoke(self, *args, **kwargs):
             msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt."
             raise ValueError(msg)
 
-        body = self.model_adapter.prepare_body(
-            messages=messages, **{"stop_words": self.stop_words, **kwargs}
-        )
+        body = self.model_adapter.prepare_body(messages=messages, **{"stop_words": self.stop_words, **kwargs})
         try:
             if self.streaming_callback:
                 response = self.client.invoke_model_with_response_stream(
-                    body=json.dumps(body),
-                    modelId=self.model,
-                    accept="application/json",
-                    contentType="application/json",
+                    body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json"
                 )
                 response_stream = response["body"]
                 responses = self.model_adapter.get_stream_responses(
@@ -196,15 +175,10 @@ def invoke(self, *args, **kwargs):
                 )
             else:
                 response = self.client.invoke_model(
-                    body=json.dumps(body),
-                    modelId=self.model,
-                    accept="application/json",
-                    contentType="application/json",
+                    body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json"
                 )
                 response_body = json.loads(response.get("body").read().decode("utf-8"))
-                responses = self.model_adapter.get_responses(
-                    response_body=response_body
-                )
+                responses = self.model_adapter.get_responses(response_body=response_body)
         except ClientError as exception:
             msg = f"Could not run inference on Amazon Bedrock model {self.model} due to: {exception}"
             raise AmazonBedrockInferenceError(msg) from exception
@@ -212,11 +186,7 @@ def invoke(self, *args, **kwargs):
         return responses
 
     @component.output_types(replies=List[ChatMessage])
-    def run(
-        self,
-        messages: List[ChatMessage],
-        generation_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
         """
         Generates a list of `ChatMessage` responses to the given messages using the Amazon Bedrock LLM.
 
@@ -247,28 +217,14 @@ def to_dict(self) -> Dict[str, Any]:
 
         :returns: Dictionary with serialized data.
         """
-        callback_name = (
-            serialize_callable(self.streaming_callback)
-            if self.streaming_callback
-            else None
-        )
+        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
         return default_to_dict(
             self,
-            aws_access_key_id=self.aws_access_key_id.to_dict()
-            if self.aws_access_key_id
-            else None,
-            aws_secret_access_key=self.aws_secret_access_key.to_dict()
-            if self.aws_secret_access_key
-            else None,
-            aws_session_token=self.aws_session_token.to_dict()
-            if self.aws_session_token
-            else None,
-            aws_region_name=self.aws_region_name.to_dict()
-            if self.aws_region_name
-            else None,
-            aws_profile_name=self.aws_profile_name.to_dict()
-            if self.aws_profile_name
-            else None,
+            aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None,
+            aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None,
+            aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None,
+            aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None,
+            aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
             model=self.model,
             stop_words=self.stop_words,
             generation_kwargs=self.model_adapter.generation_kwargs,
@@ -288,17 +244,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockChatGenerator":
         init_params = data.get("init_parameters", {})
         serialized_callback_handler = init_params.get("streaming_callback")
         if serialized_callback_handler:
-            data["init_parameters"]["streaming_callback"] = deserialize_callable(
-                serialized_callback_handler
-            )
+            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
         deserialize_secrets_inplace(
             data["init_parameters"],
-            [
-                "aws_access_key_id",
-                "aws_secret_access_key",
-                "aws_session_token",
-                "aws_region_name",
-                "aws_profile_name",
-            ],
+            ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"],
         )
         return default_from_dict(cls, data)
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py
index 2da289aba..32d1de629 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py
@@ -67,21 +67,13 @@ class AmazonBedrockGenerator:
     def __init__(
         self,
         model: str,
-        aws_access_key_id: Optional[Secret] = Secret.from_env_var(
-            "AWS_ACCESS_KEY_ID", strict=False
-        ),  # noqa: B008
+        aws_access_key_id: Optional[Secret] = Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False),  # noqa: B008
         aws_secret_access_key: Optional[Secret] = Secret.from_env_var(  # noqa: B008
             "AWS_SECRET_ACCESS_KEY", strict=False
         ),
-        aws_session_token: Optional[Secret] = Secret.from_env_var(
-            "AWS_SESSION_TOKEN", strict=False
-        ),  # noqa: B008
-        aws_region_name: Optional[Secret] = Secret.from_env_var(
-            "AWS_DEFAULT_REGION", strict=False
-        ),  # noqa: B008
-        aws_profile_name: Optional[Secret] = Secret.from_env_var(
-            "AWS_PROFILE", strict=False
-        ),  # noqa: B008
+        aws_session_token: Optional[Secret] = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False),  # noqa: B008
+        aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False),  # noqa: B008
+        aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False),  # noqa: B008
         max_length: Optional[int] = 100,
         truncate: Optional[bool] = True,
         **kwargs,
@@ -152,13 +144,9 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
         if not model_adapter_cls:
             msg = f"AmazonBedrockGenerator doesn't support the model {model}."
             raise AmazonBedrockConfigurationError(msg)
-        self.model_adapter = model_adapter_cls(
-            model_kwargs=model_input_kwargs, max_length=self.max_length
-        )
+        self.model_adapter = model_adapter_cls(model_kwargs=model_input_kwargs, max_length=self.max_length)
 
-    def _ensure_token_limit(
-        self, prompt: Union[str, List[Dict[str, str]]]
-    ) -> Union[str, List[Dict[str, str]]]:
+    def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
         """
         Ensures that the prompt and answer token lengths together are within the model_max_length specified during
         the initialization of the component.
@@ -197,9 +185,7 @@ def invoke(self, *args, **kwargs):
         """
         kwargs = kwargs.copy()
         prompt: str = kwargs.pop("prompt", None)
-        stream: bool = kwargs.get(
-            "stream", self.model_adapter.model_kwargs.get("stream", False)
-        )
+        stream: bool = kwargs.get("stream", self.model_adapter.model_kwargs.get("stream", False))
 
         if not prompt or not isinstance(prompt, (str, list)):
             msg = (
@@ -223,13 +209,9 @@ def invoke(self, *args, **kwargs):
                 response_stream = response["body"]
                 handler: TokenStreamingHandler = kwargs.get(
                     "stream_handler",
-                    self.model_adapter.model_kwargs.get(
-                        "stream_handler", DefaultTokenStreamingHandler()
-                    ),
-                )
-                responses = self.model_adapter.get_stream_responses(
-                    stream=response_stream, stream_handler=handler
+                    self.model_adapter.model_kwargs.get("stream_handler", DefaultTokenStreamingHandler()),
                 )
+                responses = self.model_adapter.get_stream_responses(stream=response_stream, stream_handler=handler)
             else:
                 response = self.client.invoke_model(
                     body=json.dumps(body),
@@ -238,9 +220,7 @@ def invoke(self, *args, **kwargs):
                     contentType="application/json",
                 )
                 response_body = json.loads(response.get("body").read().decode("utf-8"))
-                responses = self.model_adapter.get_responses(
-                    response_body=response_body
-                )
+                responses = self.model_adapter.get_responses(response_body=response_body)
         except ClientError as exception:
             msg = (
                 f"Could not connect to Amazon Bedrock model {self.model}. "
@@ -287,21 +267,11 @@ def to_dict(self) -> Dict[str, Any]:
         """
         return default_to_dict(
             self,
-            aws_access_key_id=self.aws_access_key_id.to_dict()
-            if self.aws_access_key_id
-            else None,
-            aws_secret_access_key=self.aws_secret_access_key.to_dict()
-            if self.aws_secret_access_key
-            else None,
-            aws_session_token=self.aws_session_token.to_dict()
-            if self.aws_session_token
-            else None,
-            aws_region_name=self.aws_region_name.to_dict()
-            if self.aws_region_name
-            else None,
-            aws_profile_name=self.aws_profile_name.to_dict()
-            if self.aws_profile_name
-            else None,
+            aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None,
+            aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None,
+            aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None,
+            aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None,
+            aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
             model=self.model,
             max_length=self.max_length,
             truncate=self.truncate,
@@ -320,12 +290,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockGenerator":
         """
         deserialize_secrets_inplace(
             data["init_parameters"],
-            [
-                "aws_access_key_id",
-                "aws_secret_access_key",
-                "aws_session_token",
-                "aws_region_name",
-                "aws_profile_name",
-            ],
+            ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"],
        )
         return default_from_dict(cls, data)
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py
index 67cd2061a..f4dc1aa4f 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py
@@ -1,12 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Dict, Union
 
-from transformers import (
-    AutoTokenizer,
-    PreTrainedTokenizer,
-    PreTrainedTokenizerBase,
-    PreTrainedTokenizerFast,
-)
+from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
 
 
 class DefaultPromptHandler:
@@ -15,12 +10,7 @@ class DefaultPromptHandler:
     are within the model_max_length.
     """
 
-    def __init__(
-        self,
-        tokenizer: Union[str, PreTrainedTokenizerBase],
-        model_max_length: int,
-        max_length: int = 100,
-    ):
+    def __init__(self, tokenizer: Union[str, PreTrainedTokenizerBase], model_max_length: int, max_length: int = 100):
         """
         :param tokenizer: The tokenizer to be used to tokenize the prompt.
         :param model_max_length: The maximum length of the prompt and answer tokens combined.
@@ -62,9 +52,7 @@ def __call__(self, prompt: str, **kwargs) -> Dict[str, Union[str, int]]:
             resized_prompt = self.tokenizer.convert_tokens_to_string(
                 tokenized_prompt[: self.model_max_length - self.max_length]
             )
-            new_prompt_length = len(
-                tokenized_prompt[: self.model_max_length - self.max_length]
-            )
+            new_prompt_length = len(tokenized_prompt[: self.model_max_length - self.max_length])
 
         return {
             "resized_prompt": resized_prompt,
diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py
index cd9fef803..3e62b56ea 100644
--- a/integrations/amazon_bedrock/tests/test_chat_generator.py
+++ b/integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -7,9 +7,7 @@
 from haystack.components.generators.utils import print_streaming_chunk
 from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
 
-from haystack_integrations.components.generators.amazon_bedrock import (
-    AmazonBedrockChatGenerator,
-)
+from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator
 from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import (
     AnthropicClaudeChatAdapter,
     BedrockModelChatAdapter,
@@ -18,11 +16,7 @@
 )
 
 KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator"
-MODELS_TO_TEST = [
-    "anthropic.claude-3-sonnet-20240229-v1:0",
-    "anthropic.claude-v2:1",
-    "meta.llama2-13b-chat-v1",
-]
+MODELS_TO_TEST = ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"]
 MISTRAL_MODELS = [
     "mistral.mistral-7b-instruct-v0:2",
     "mistral.mixtral-8x7b-instruct-v0:1",
@@ -42,31 +36,11 @@ def test_to_dict(mock_boto3_session):
     expected_dict = {
         "type": KLASS,
         "init_parameters": {
-            "aws_access_key_id": {
-                "type": "env_var",
-                "env_vars": ["AWS_ACCESS_KEY_ID"],
-                "strict": False,
-            },
-            "aws_secret_access_key": {
-                "type": "env_var",
-                "env_vars": ["AWS_SECRET_ACCESS_KEY"],
-                "strict": False,
-            },
-            "aws_session_token": {
-                "type": "env_var",
-                "env_vars": ["AWS_SESSION_TOKEN"],
-                "strict": False,
-            },
-            "aws_region_name": {
-                "type": "env_var",
-                "env_vars": ["AWS_DEFAULT_REGION"],
-                "strict": False,
-            },
-            "aws_profile_name": {
-                "type": "env_var",
-                "env_vars": ["AWS_PROFILE"],
-                "strict": False,
-            },
+            "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
+            "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
+            "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
+            "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
+            "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
             "model": "anthropic.claude-v2",
             "generation_kwargs": {"temperature": 0.7},
             "stop_words": [],
@@ -85,31 +59,11 @@ def test_from_dict(mock_boto3_session):
         {
             "type": KLASS,
             "init_parameters": {
-                "aws_access_key_id": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_ACCESS_KEY_ID"],
-                    "strict": False,
-                },
-                "aws_secret_access_key": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_SECRET_ACCESS_KEY"],
-                    "strict": False,
-                },
-                "aws_session_token": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_SESSION_TOKEN"],
-                    "strict": False,
-                },
-                "aws_region_name": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_DEFAULT_REGION"],
-                    "strict": False,
-                },
-                "aws_profile_name": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_PROFILE"],
-                    "strict": False,
-                },
+                "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
+                "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
+                "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
+                "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
+                "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
                 "model": "anthropic.claude-v2",
                 "generation_kwargs": {"temperature": 0.7},
                 "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
@@ -154,9 +108,7 @@ def test_constructor_with_generation_kwargs(mock_boto3_session):
     """
     generation_kwargs = {"temperature": 0.7}
 
-    layer = AmazonBedrockChatGenerator(
-        model="anthropic.claude-v2", generation_kwargs=generation_kwargs
-    )
+    layer = AmazonBedrockChatGenerator(model="anthropic.claude-v2", generation_kwargs=generation_kwargs)
     assert "temperature" in layer.model_adapter.generation_kwargs
     assert layer.model_adapter.generation_kwargs["temperature"] == 0.7
 
@@ -191,9 +143,7 @@ def test_invoke_with_no_kwargs(mock_boto3_session):
         ("unknown_model", None),
     ],
 )
-def test_get_model_adapter(
-    model: str, expected_model_adapter: Optional[Type[BedrockModelChatAdapter]]
-):
+def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[BedrockModelChatAdapter]]):
     """
     Test that the correct model adapter is returned for a given model
     """
@@ -208,12 +158,7 @@ def test_prepare_body_with_default_params(self) -> None:
         expected_body = {
             "anthropic_version": "bedrock-2023-05-31",
             "max_tokens": 512,
-            "messages": [
-                {
-                    "content": [{"text": "Hello, how are you?", "type": "text"}],
-                    "role": "user",
-                }
-            ],
+            "messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}], "role": "user"}],
         }
 
         body = layer.prepare_body([ChatMessage.from_user(prompt)])
@@ -221,19 +166,12 @@ def test_prepare_body_with_default_params(self) -> None:
         assert body == expected_body
 
     def test_prepare_body_with_custom_inference_params(self) -> None:
-        layer = AnthropicClaudeChatAdapter(
-            generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4}
-        )
+        layer = AnthropicClaudeChatAdapter(generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4})
         prompt = "Hello, how are you?"
         expected_body = {
             "anthropic_version": "bedrock-2023-05-31",
             "max_tokens": 512,
-            "messages": [
-                {
-                    "content": [{"text": "Hello, how are you?", "type": "text"}],
-                    "role": "user",
-                }
-            ],
+            "messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}], "role": "user"}],
             "stop_sequences": ["CUSTOM_STOP"],
             "temperature": 0.7,
             "top_k": 5,
@@ -241,11 +179,7 @@ def test_prepare_body_with_custom_inference_params(self) -> None:
         }
 
         body = layer.prepare_body(
-            [ChatMessage.from_user(prompt)],
-            top_p=0.8,
-            top_k=5,
-            max_tokens_to_sample=69,
-            stop_sequences=["CUSTOM_STOP"],
+            [ChatMessage.from_user(prompt)], top_p=0.8, top_k=5, max_tokens_to_sample=69, stop_sequences=["CUSTOM_STOP"]
         )
 
         assert body == expected_body
@@ -265,9 +199,7 @@ def test_prepare_body_with_default_params(self) -> None:
         assert body == expected_body
 
     def test_prepare_body_with_custom_inference_params(self) -> None:
-        layer = MistralChatAdapter(
-            generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4}
-        )
+        layer = MistralChatAdapter(generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4})
         prompt = "Hello, how are you?"
         expected_body = {
             "prompt": "[INST] Hello, how are you? [/INST]",
@@ -276,35 +208,19 @@ def test_prepare_body_with_custom_inference_params(self) -> None:
             "max_tokens": 512,
             "top_p": 0.8,
         }
 
-        body = layer.prepare_body(
-            [ChatMessage.from_user(prompt)], top_p=0.8, top_k=5, max_tokens_to_sample=69
-        )
+        body = layer.prepare_body([ChatMessage.from_user(prompt)], top_p=0.8, top_k=5, max_tokens_to_sample=69)
 
         assert body == expected_body
 
     def test_mistral_chat_template_correct_order(self):
         layer = MistralChatAdapter(generation_kwargs={})
-        layer.prepare_body(
-            [
-                ChatMessage.from_user("A"),
-                ChatMessage.from_assistant("B"),
-                ChatMessage.from_user("C"),
-            ]
-        )
-        layer.prepare_body(
-            [
-                ChatMessage.from_system("A"),
-                ChatMessage.from_user("B"),
-                ChatMessage.from_assistant("C"),
-            ]
-        )
+        layer.prepare_body([ChatMessage.from_user("A"), ChatMessage.from_assistant("B"), ChatMessage.from_user("C")])
+        layer.prepare_body([ChatMessage.from_system("A"), ChatMessage.from_user("B"), ChatMessage.from_assistant("C")])
 
     def test_mistral_chat_template_incorrect_order(self):
         layer = MistralChatAdapter(generation_kwargs={})
         try:
-            layer.prepare_body(
-                [ChatMessage.from_assistant("B"), ChatMessage.from_assistant("C")]
-            )
+            layer.prepare_body([ChatMessage.from_assistant("B"), ChatMessage.from_assistant("C")])
             msg = "Expected TemplateError"
             raise AssertionError(msg)
         except Exception as e:
@@ -318,9 +234,7 @@ def test_mistral_chat_template_incorrect_order(self):
             assert "Conversation roles must alternate user/assistant/" in str(e)
 
         try:
-            layer.prepare_body(
-                [ChatMessage.from_system("A"), ChatMessage.from_system("B")]
-            )
+            layer.prepare_body([ChatMessage.from_system("A"), ChatMessage.from_system("B")])
             msg = "Expected TemplateError"
             raise AssertionError(msg)
         except Exception as e:
@@ -330,9 +244,7 @@ def test_use_mistral_adapter_without_hf_token(self, monkeypatch, caplog) -> None
        monkeypatch.delenv("HF_TOKEN", raising=False)
         with (
             patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained,
-            patch(
-                "haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"
-            ),
+            patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"),
             caplog.at_level(logging.WARNING),
         ):
             MistralChatAdapter(generation_kwargs={})
@@ -343,9 +255,7 @@ def test_use_mistral_adapter_with_hf_token(self, monkeypatch) -> None:
         monkeypatch.setenv("HF_TOKEN", "test")
         with (
             patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained,
-            patch(
-                "haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"
-            ),
+            patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"),
         ):
             MistralChatAdapter(generation_kwargs={})
             mock_pretrained.assert_called_with("mistralai/Mistral-7B-Instruct-v0.1")
@@ -369,25 +279,17 @@ def test_default_inference_params(self, model_name, chat_messages):
         assert len(replies) > 0, "No replies received"
 
         first_reply = replies[0]
-        assert isinstance(
-            first_reply, ChatMessage
-        ), "First reply is not a ChatMessage instance"
+        assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
         assert first_reply.content, "First reply has no content"
-        assert ChatMessage.is_from(
-            first_reply, ChatRole.ASSISTANT
-        ), "First reply is not from the assistant"
-        assert (
-            "paris" in first_reply.content.lower()
-        ), "First reply does not contain 'paris'"
+        assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
+        assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'"
         assert first_reply.meta, "First reply has no metadata"
 
 
 @pytest.fixture
 def chat_messages():
     messages = [
-        ChatMessage.from_system(
-            "\\nYou are a helpful assistant, be super brief in your responses."
-        ),
+        ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."),
         ChatMessage.from_user("What's the capital of France?"),
     ]
     return messages
@@ -400,10 +302,7 @@ def test_prepare_body_with_default_params(self) -> None:
         # that way we can ensure prompt chat message formatting
         layer = MetaLlama2ChatAdapter(generation_kwargs={})
         prompt = "Hello, how are you?"
-        expected_body = {
-            "prompt": "[INST] Hello, how are you? [/INST]",
-            "max_gen_len": 512,
-        }
+        expected_body = {"prompt": "[INST] Hello, how are you? [/INST]", "max_gen_len": 512}
 
         body = layer.prepare_body([ChatMessage.from_user(prompt)])
 
@@ -414,12 +313,7 @@ def test_prepare_body_with_custom_inference_params(self) -> None:
         # leave this test as integration because we really need only tokenizer from HF
         # that way we can ensure prompt chat message formatting
         layer = MetaLlama2ChatAdapter(
-            generation_kwargs={
-                "temperature": 0.7,
-                "top_p": 0.8,
-                "top_k": 5,
-                "stop_sequences": ["CUSTOM_STOP"],
-            }
+            generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 5, "stop_sequences": ["CUSTOM_STOP"]}
         )
 
         prompt = "Hello, how are you?"
@@ -457,6 +351,7 @@ def test_get_responses(self) -> None:
     @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
     @pytest.mark.integration
     def test_default_inference_params(self, model_name, chat_messages):
+
         client = AmazonBedrockChatGenerator(model=model_name)
 
         response = client.run(chat_messages)
@@ -466,16 +361,10 @@ def test_default_inference_params(self, model_name, chat_messages):
 
         assert len(replies) > 0, "No replies received"
 
        first_reply = replies[0]
-        assert isinstance(
-            first_reply, ChatMessage
-        ), "First reply is not a ChatMessage instance"
+        assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
         assert first_reply.content, "First reply has no content"
-        assert ChatMessage.is_from(
-            first_reply, ChatRole.ASSISTANT
-        ), "First reply is not from the assistant"
-        assert (
-            "paris" in first_reply.content.lower()
-        ), "First reply does not contain 'paris'"
+        assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
+        assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'"
         assert first_reply.meta, "First reply has no metadata"
 
@@ -492,28 +381,18 @@ def streaming_callback(chunk: StreamingChunk):
             if not paris_found_in_response:
                 paris_found_in_response = "paris" in chunk.content.lower()
 
-        client = AmazonBedrockChatGenerator(
-            model=model_name, streaming_callback=streaming_callback
-        )
+        client = AmazonBedrockChatGenerator(model=model_name, streaming_callback=streaming_callback)
         response = client.run(chat_messages)
 
         assert streaming_callback_called, "Streaming callback was not called"
-        assert (
-            paris_found_in_response
-        ), "The streaming callback response did not contain 'paris'"
+        assert paris_found_in_response, "The streaming callback response did not contain 'paris'"
         replies = response["replies"]
         assert isinstance(replies, list), "Replies is not a list"
         assert len(replies) > 0, "No replies received"
 
         first_reply = replies[0]
-        assert isinstance(
-            first_reply, ChatMessage
-        ), "First reply is not a ChatMessage instance"
+        assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
         assert first_reply.content, "First reply has no content"
-        assert ChatMessage.is_from(
-            first_reply, ChatRole.ASSISTANT
-        ), "First reply is not from the assistant"
-        assert (
-            "paris" in first_reply.content.lower()
-        ), "First reply does not contain 'paris'"
+        assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
+        assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'"
         assert first_reply.meta, "First reply has no metadata"
diff --git a/integrations/amazon_bedrock/tests/test_document_embedder.py b/integrations/amazon_bedrock/tests/test_document_embedder.py
index eff9610ce..9856c97bb 100644
--- a/integrations/amazon_bedrock/tests/test_document_embedder.py
+++ b/integrations/amazon_bedrock/tests/test_document_embedder.py
@@ -9,9 +9,7 @@
     AmazonBedrockConfigurationError,
     AmazonBedrockInferenceError,
 )
-from haystack_integrations.components.embedders.amazon_bedrock import (
-    AmazonBedrockDocumentEmbedder,
-)
+from haystack_integrations.components.embedders.amazon_bedrock import AmazonBedrockDocumentEmbedder
 
 TYPE = "haystack_integrations.components.embedders.amazon_bedrock.document_embedder.AmazonBedrockDocumentEmbedder"
 
@@ -77,31 +75,11 @@ def test_to_dict(self, mock_boto3_session):
         expected_dict = {
             "type": TYPE,
             "init_parameters": {
-                "aws_access_key_id": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_ACCESS_KEY_ID"],
-                    "strict": False,
-                },
-                "aws_secret_access_key": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_SECRET_ACCESS_KEY"],
-                    "strict": False,
-                },
-                "aws_session_token": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_SESSION_TOKEN"],
-                    "strict": False,
-                },
-                "aws_region_name": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_DEFAULT_REGION"],
-                    "strict": False,
-                },
-                "aws_profile_name": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_PROFILE"],
-                    "strict": False,
-                },
+                "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
+                "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
+                "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
+                "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
+                "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
                 "model": "cohere.embed-english-v3",
                 "input_type": "search_document",
                 "batch_size": 32,
@@ -117,31 +95,11 @@ def test_from_dict(self, mock_boto3_session):
         data = {
             "type": TYPE,
             "init_parameters": {
-                "aws_access_key_id": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_ACCESS_KEY_ID"],
-                    "strict": False,
-                },
-                "aws_secret_access_key": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_SECRET_ACCESS_KEY"],
-                    "strict": False,
-                },
-                "aws_session_token": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_SESSION_TOKEN"],
-                    "strict": False,
-                },
-                "aws_region_name": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_DEFAULT_REGION"],
-                    "strict": False,
-                },
-                "aws_profile_name": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_PROFILE"],
-                    "strict": False,
-                },
+                "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
+                "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
+                "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
+                "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
+                "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
                 "model": "cohere.embed-english-v3",
                 "input_type": "search_document",
                 "batch_size": 32,
@@ -177,9 +135,7 @@ def test_run_invocation_error(self, mock_boto3_session):
 
         with patch.object(embedder._client, "invoke_model") as mock_invoke_model:
             mock_invoke_model.side_effect = ClientError(
-                error_response={
-                    "Error": {"Code": "some_code", "Message": "some_message"}
-                },
+                error_response={"Error": {"Code": "some_code", "Message": "some_message"}},
                 operation_name="some_operation",
             )
 
@@ -190,17 +146,11 @@ def test_run_invocation_error(self, mock_boto3_session):
 
     def test_prepare_texts_to_embed_w_metadata(self, mock_boto3_session):
         documents = [
-            Document(
-                content=f"document number {i}: content",
-                meta={"meta_field": f"meta_value {i}"},
-            )
-            for i in range(5)
+            Document(content=f"document number {i}: content", meta={"meta_field": f"meta_value {i}"}) for i in range(5)
         ]
 
         embedder = AmazonBedrockDocumentEmbedder(
-            model="cohere.embed-english-v3",
-            meta_fields_to_embed=["meta_field"],
-            embedding_separator=" | ",
+            model="cohere.embed-english-v3", meta_fields_to_embed=["meta_field"], embedding_separator=" | "
         )
 
         prepared_texts = embedder._prepare_texts_to_embed(documents)
@@ -218,9 +168,7 @@ def test_embed_cohere(self, mock_boto3_session):
 
         with patch.object(embedder, "_client") as mock_client:
             mock_client.invoke_model.return_value = {
-                "body": io.StringIO(
-                    '{"embeddings": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]}'
-                ),
+                "body": io.StringIO('{"embeddings": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]}'),
             }
 
             docs = [Document(content="some text"), Document(content="some other text")]
@@ -240,9 +188,7 @@ def test_embed_cohere(self, mock_boto3_session):
         assert result[1].embedding == [0.4, 0.5, 0.6]
 
     def test_embed_cohere_batching(self, mock_boto3_session):
-        embedder = AmazonBedrockDocumentEmbedder(
-            model="cohere.embed-english-v3", batch_size=2
-        )
+        embedder = AmazonBedrockDocumentEmbedder(model="cohere.embed-english-v3", batch_size=2)
 
         mock_response = {
             "body": io.StringIO('{"embeddings": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]}'),
@@ -290,22 +236,10 @@ def mock_invoke_model(*args, **kwargs):
             result = embedder._embed_titan(documents=docs)
 
         assert mock_client.invoke_model.call_count == 2
-        assert (
-            mock_client.invoke_model.call_args_list[0][1]["modelId"]
-            == "amazon.titan-embed-text-v1"
-        )
-        assert (
-            mock_client.invoke_model.call_args_list[0][1]["body"]
-            == '{"inputText": "some text"}'
-        )
-        assert (
-            mock_client.invoke_model.call_args_list[1][1]["modelId"]
-            == "amazon.titan-embed-text-v1"
-        )
-        assert (
-            mock_client.invoke_model.call_args_list[1][1]["body"]
-            == '{"inputText": "some other text"}'
-        )
+        assert mock_client.invoke_model.call_args_list[0][1]["modelId"] == "amazon.titan-embed-text-v1"
+        assert mock_client.invoke_model.call_args_list[0][1]["body"] == '{"inputText": "some text"}'
+        assert mock_client.invoke_model.call_args_list[1][1]["modelId"] == "amazon.titan-embed-text-v1"
+        assert mock_client.invoke_model.call_args_list[1][1]["body"] == '{"inputText": "some other text"}'
 
         for i, doc in enumerate(result):
             assert doc.content == docs[i].content
diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py
index a9801cad0..65463caae 100644
--- a/integrations/amazon_bedrock/tests/test_generator.py
+++ b/integrations/amazon_bedrock/tests/test_generator.py
@@ -3,9 +3,7 @@
 
 import pytest
 
-from haystack_integrations.components.generators.amazon_bedrock import (
-    AmazonBedrockGenerator,
-)
+from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator from haystack_integrations.components.generators.amazon_bedrock.adapters import ( AI21LabsJurassic2Adapter, AmazonTitanAdapter, @@ -22,38 +20,16 @@ def test_to_dict(mock_boto3_session): """ Test that the to_dict method returns the correct dictionary without aws credentials """ - generator = AmazonBedrockGenerator( - model="anthropic.claude-v2", max_length=99, truncate=False, temperature=10 - ) + generator = AmazonBedrockGenerator(model="anthropic.claude-v2", max_length=99, truncate=False, temperature=10) expected_dict = { "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator", "init_parameters": { - "aws_access_key_id": { - "type": "env_var", - "env_vars": ["AWS_ACCESS_KEY_ID"], - "strict": False, - }, - "aws_secret_access_key": { - "type": "env_var", - "env_vars": ["AWS_SECRET_ACCESS_KEY"], - "strict": False, - }, - "aws_session_token": { - "type": "env_var", - "env_vars": ["AWS_SESSION_TOKEN"], - "strict": False, - }, - "aws_region_name": { - "type": "env_var", - "env_vars": ["AWS_DEFAULT_REGION"], - "strict": False, - }, - "aws_profile_name": { - "type": "env_var", - "env_vars": ["AWS_PROFILE"], - "strict": False, - }, + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, "model": "anthropic.claude-v2", "max_length": 99, "truncate": False, @@ -72,31 +48,11 @@ def test_from_dict(mock_boto3_session): { "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator", "init_parameters": { - "aws_access_key_id": { - "type": "env_var", - "env_vars": ["AWS_ACCESS_KEY_ID"], - "strict": False, - }, - "aws_secret_access_key": { - "type": "env_var", - "env_vars": ["AWS_SECRET_ACCESS_KEY"], - "strict": False, - }, - "aws_session_token": { - "type": "env_var", - "env_vars": ["AWS_SESSION_TOKEN"], - "strict": False, - }, - "aws_region_name": { - "type": "env_var", - "env_vars": ["AWS_DEFAULT_REGION"], - "strict": False, - }, - "aws_profile_name": { - "type": "env_var", - "env_vars": ["AWS_PROFILE"], - "strict": False, - }, + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, "model": "anthropic.claude-v2", "max_length": 99, }, @@ -136,15 +92,11 @@ def test_default_constructor(mock_boto3_session, set_env_variables): ) -def test_constructor_prompt_handler_initialized( - mock_boto3_session, mock_prompt_handler -): +def test_constructor_prompt_handler_initialized(mock_boto3_session, mock_prompt_handler): """ Test that the constructor sets the prompt_handler correctly, with the correct model_max_length for llama-2 """ - layer = AmazonBedrockGenerator( - model="anthropic.claude-v2", 
prompt_handler=mock_prompt_handler - ) + layer = AmazonBedrockGenerator(model="anthropic.claude-v2", prompt_handler=mock_prompt_handler) assert layer.prompt_handler is not None assert layer.prompt_handler.model_max_length == 4096 @@ -173,9 +125,7 @@ def test_invoke_with_no_kwargs(mock_boto3_session): Test invoke raises an error if no prompt is provided """ layer = AmazonBedrockGenerator(model="anthropic.claude-v2") - with pytest.raises( - ValueError, match="The model anthropic.claude-v2 requires a valid prompt." - ): + with pytest.raises(ValueError, match="The model anthropic.claude-v2 requires a valid prompt."): layer.invoke() @@ -197,9 +147,7 @@ def test_short_prompt_is_not_truncated(mock_boto3_session): max_length_generated_text = 3 total_model_max_length = 10 - with patch( - "transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer - ): + with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): layer = AmazonBedrockGenerator( "anthropic.claude-v2", max_length=max_length_generated_text, @@ -233,9 +181,7 @@ def test_long_prompt_is_truncated(mock_boto3_session): max_length_generated_text = 3 total_model_max_length = 10 - with patch( - "transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer - ): + with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): layer = AmazonBedrockGenerator( "anthropic.claude-v2", max_length=max_length_generated_text, @@ -273,11 +219,7 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): generator.model_adapter.prepare_body = MagicMock(return_value={}) generator.client = MagicMock() generator.client.invoke_model = MagicMock( - return_value={ - "body": MagicMock( - read=MagicMock(return_value=b'{"generated_text": "response"}') - ) - } + return_value={"body": MagicMock(read=MagicMock(return_value=b'{"generated_text": "response"}'))} ) generator.model_adapter.get_responses = MagicMock(return_value=["response"]) @@ -324,9 +266,7 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): ("unknown_model", None), ], ) -def test_get_model_adapter( - model: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]] -): +def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]): """ Test that the correct model adapter is returned for a given model """ @@ -340,9 +280,7 @@ def test_default_init(self) -> None: assert adapter.use_messages_api is True def test_use_messages_api_false(self) -> None: - adapter = AnthropicClaudeAdapter( - model_kwargs={"use_messages_api": False}, max_length=100 - ) + adapter = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=100) assert adapter.use_messages_api is False @@ -479,16 +417,11 @@ def test_get_stream_responses(self) -> None: {"chunk": {"bytes": b'{"delta": {"text": " response."}}'}}, ] - stream_handler_mock.side_effect = ( - lambda token_received, **kwargs: token_received - ) + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) expected_responses = ["This is a single response."] - assert ( - adapter.get_stream_responses(stream_mock, stream_handler_mock) - == expected_responses - ) + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses stream_handler_mock.assert_has_calls( [ @@ -506,25 +439,18 @@ def test_get_stream_responses_empty(self) -> None: stream_mock.__iter__.return_value = [] - 
stream_handler_mock.side_effect = ( - lambda token_received, **kwargs: token_received - ) + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) expected_responses = [""] - assert ( - adapter.get_stream_responses(stream_mock, stream_handler_mock) - == expected_responses - ) + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses stream_handler_mock.assert_not_called() class TestAnthropicClaudeAdapterNoMessagesAPI: def test_prepare_body_with_default_params(self) -> None: - layer = AnthropicClaudeAdapter( - model_kwargs={"use_messages_api": False}, max_length=99 - ) + layer = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) prompt = "Hello, how are you?" expected_body = { "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant:", @@ -537,9 +463,7 @@ def test_prepare_body_with_default_params(self) -> None: assert body == expected_body def test_prepare_body_with_custom_inference_params(self) -> None: - layer = AnthropicClaudeAdapter( - model_kwargs={"use_messages_api": False}, max_length=99 - ) + layer = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) prompt = "Hello, how are you?" expected_body = { "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant:", @@ -611,24 +535,18 @@ def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> Non "top_k": 5, } - body = layer.prepare_body( - prompt, temperature=0.7, top_p=0.8, top_k=5, max_tokens_to_sample=50 - ) + body = layer.prepare_body(prompt, temperature=0.7, top_p=0.8, top_k=5, max_tokens_to_sample=50) assert body == expected_body def test_get_responses(self) -> None: - adapter = AnthropicClaudeAdapter( - model_kwargs={"use_messages_api": False}, max_length=99 - ) + adapter = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) response_body = {"completion": "This is a single response."} expected_responses = ["This is a single response."] assert adapter.get_responses(response_body) == expected_responses def test_get_responses_leading_whitespace(self) -> None: - adapter = AnthropicClaudeAdapter( - model_kwargs={"use_messages_api": False}, max_length=99 - ) + adapter = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) response_body = {"completion": "\n\t This is a single response."} expected_responses = ["This is a single response."] assert adapter.get_responses(response_body) == expected_responses @@ -645,18 +563,11 @@ def test_get_stream_responses(self) -> None: {"chunk": {"bytes": b'{"completion": " response."}'}}, ] - stream_handler_mock.side_effect = ( - lambda token_received, **kwargs: token_received - ) + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = AnthropicClaudeAdapter( - model_kwargs={"use_messages_api": False}, max_length=99 - ) + adapter = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) expected_responses = ["This is a single response."] - assert ( - adapter.get_stream_responses(stream_mock, stream_handler_mock) - == expected_responses - ) + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses stream_handler_mock.assert_has_calls( [ @@ -674,18 +585,11 @@ def test_get_stream_responses_empty(self) -> None: stream_mock.__iter__.return_value = [] - stream_handler_mock.side_effect = ( - lambda token_received, **kwargs: token_received - ) + 
stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = AnthropicClaudeAdapter( - model_kwargs={"use_messages_api": False}, max_length=99 - ) + adapter = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) expected_responses = [""] - assert ( - adapter.get_stream_responses(stream_mock, stream_handler_mock) - == expected_responses - ) + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses stream_handler_mock.assert_not_called() @@ -694,11 +598,7 @@ class TestMistralAdapter: def test_prepare_body_with_default_params(self) -> None: layer = MistralAdapter(model_kwargs={}, max_length=99) prompt = "Hello, how are you?" - expected_body = { - "prompt": "[INST] Hello, how are you? [/INST]", - "max_tokens": 99, - "stop": [], - } + expected_body = {"prompt": "[INST] Hello, how are you? [/INST]", "max_tokens": 99, "stop": []} body = layer.prepare_body(prompt) assert body == expected_body @@ -774,9 +674,7 @@ def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> Non "top_k": 5, } - body = layer.prepare_body( - prompt, temperature=0.7, top_p=0.8, top_k=5, max_tokens=50 - ) + body = layer.prepare_body(prompt, temperature=0.7, top_p=0.8, top_k=5, max_tokens=50) assert body == expected_body @@ -798,16 +696,11 @@ def test_get_stream_responses(self) -> None: {"chunk": {"bytes": b'{"outputs": [{"text": " response."}]}'}}, ] - stream_handler_mock.side_effect = ( - lambda token_received, **kwargs: token_received - ) + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received adapter = MistralAdapter(model_kwargs={}, max_length=99) expected_responses = ["This is a single response."] - assert ( - adapter.get_stream_responses(stream_mock, stream_handler_mock) - == expected_responses - ) + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses stream_handler_mock.assert_has_calls( [ @@ -825,16 +718,11 @@ def test_get_stream_responses_empty(self) -> None: stream_mock.__iter__.return_value = [] - stream_handler_mock.side_effect = ( - lambda token_received, **kwargs: token_received - ) + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received adapter = MistralAdapter(model_kwargs={}, max_length=99) expected_responses = [""] - assert ( - adapter.get_stream_responses(stream_mock, stream_handler_mock) - == expected_responses - ) + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses stream_handler_mock.assert_not_called() @@ -1001,23 +889,14 @@ def test_get_stream_responses(self) -> None: {"chunk": {"bytes": b'{"text": " a"}'}}, {"chunk": {"bytes": b'{"text": " single"}'}}, {"chunk": {"bytes": b'{"text": " response."}'}}, - { - "chunk": { - "bytes": b'{"finish_reason": "MAX_TOKENS", "is_finished": true}' - } - }, + {"chunk": {"bytes": b'{"finish_reason": "MAX_TOKENS", "is_finished": true}'}}, ] - stream_handler_mock.side_effect = ( - lambda token_received, **kwargs: token_received - ) + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received adapter = CohereCommandAdapter(model_kwargs={}, max_length=99) expected_responses = ["This is a single response."] - assert ( - adapter.get_stream_responses(stream_mock, stream_handler_mock) - == expected_responses - ) + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses stream_handler_mock.assert_has_calls( [ @@ -1026,9 +905,7 @@ def test_get_stream_responses(self) -> 
None: call(" a", event_data={"text": " a"}), call(" single", event_data={"text": " single"}), call(" response.", event_data={"text": " response."}), - call( - "", event_data={"finish_reason": "MAX_TOKENS", "is_finished": True} - ), + call("", event_data={"finish_reason": "MAX_TOKENS", "is_finished": True}), ] ) @@ -1038,16 +915,11 @@ def test_get_stream_responses_empty(self) -> None: stream_mock.__iter__.return_value = [] - stream_handler_mock.side_effect = ( - lambda token_received, **kwargs: token_received - ) + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received adapter = CohereCommandAdapter(model_kwargs={}, max_length=99) expected_responses = [""] - assert ( - adapter.get_stream_responses(stream_mock, stream_handler_mock) - == expected_responses - ) + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses stream_handler_mock.assert_not_called() @@ -1061,10 +933,7 @@ def test_prepare_body(self) -> None: ], "documents": [ {"title": "France", "snippet": "Paris is the capital of France."}, - { - "title": "Germany", - "snippet": "Berlin is the capital of Germany.", - }, + {"title": "Germany", "snippet": "Berlin is the capital of Germany."}, ], "search_query_only": False, "preamble": "preamble", @@ -1092,15 +961,9 @@ def test_prepare_body(self) -> None: ], "tool_results": [ { - "call": { - "name": "query_daily_sales_report", - "parameters": {"day": "2023-09-29"}, - }, + "call": {"name": "query_daily_sales_report", "parameters": {"day": "2023-09-29"}}, "outputs": [ - { - "date": "2023-09-29", - "summary": "Total Sales Amount: 10000, Total Units Sold: 250", - } + {"date": "2023-09-29", "summary": "Total Sales Amount: 10000, Total Units Sold: 250"} ], } ], @@ -1148,16 +1011,8 @@ def test_prepare_body(self) -> None: ], "tool_results": [ { - "call": { - "name": "query_daily_sales_report", - "parameters": {"day": "2023-09-29"}, - }, - "outputs": [ - { - "date": "2023-09-29", - "summary": "Total Sales Amount: 10000, Total Units Sold: 250", - } - ], + "call": {"name": "query_daily_sales_report", "parameters": {"day": "2023-09-29"}}, + "outputs": [{"date": "2023-09-29", "summary": "Total Sales Amount: 10000, Total Units Sold: 250"}], } ], "stop_sequences": ["\n\n"], @@ -1167,9 +1022,7 @@ def test_prepare_body(self) -> None: def test_extract_completions_from_response(self) -> None: adapter = CohereCommandRAdapter(model_kwargs={}, max_length=100) response_body = {"text": "response"} - completions = adapter._extract_completions_from_response( - response_body=response_body - ) + completions = adapter._extract_completions_from_response(response_body=response_body) assert completions == ["response"] def test_extract_token_from_stream(self) -> None: @@ -1294,17 +1147,13 @@ def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> Non def test_get_responses(self) -> None: adapter = AI21LabsJurassic2Adapter(model_kwargs={}, max_length=99) - response_body = { - "completions": [{"data": {"text": "This is a single response."}}] - } + response_body = {"completions": [{"data": {"text": "This is a single response."}}]} expected_responses = ["This is a single response."] assert adapter.get_responses(response_body) == expected_responses def test_get_responses_leading_whitespace(self) -> None: adapter = AI21LabsJurassic2Adapter(model_kwargs={}, max_length=99) - response_body = { - "completions": [{"data": {"text": "\n\t This is a single response."}}] - } + response_body = {"completions": [{"data": {"text": "\n\t This is a single 
response."}}]} expected_responses = ["This is a single response."] assert adapter.get_responses(response_body) == expected_responses @@ -1449,16 +1298,11 @@ def test_get_stream_responses(self) -> None: {"chunk": {"bytes": b'{"outputText": " response."}'}}, ] - stream_handler_mock.side_effect = ( - lambda token_received, **kwargs: token_received - ) + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received adapter = AmazonTitanAdapter(model_kwargs={}, max_length=99) expected_responses = ["This is a single response."] - assert ( - adapter.get_stream_responses(stream_mock, stream_handler_mock) - == expected_responses - ) + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses stream_handler_mock.assert_has_calls( [ @@ -1476,16 +1320,11 @@ def test_get_stream_responses_empty(self) -> None: stream_mock.__iter__.return_value = [] - stream_handler_mock.side_effect = ( - lambda token_received, **kwargs: token_received - ) + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received adapter = AmazonTitanAdapter(model_kwargs={}, max_length=99) expected_responses = [""] - assert ( - adapter.get_stream_responses(stream_mock, stream_handler_mock) - == expected_responses - ) + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses stream_handler_mock.assert_not_called() @@ -1588,16 +1427,11 @@ def test_get_stream_responses(self) -> None: {"chunk": {"bytes": b'{"generation": " response."}'}}, ] - stream_handler_mock.side_effect = ( - lambda token_received, **kwargs: token_received - ) + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received adapter = MetaLlamaAdapter(model_kwargs={}, max_length=99) expected_responses = ["This is a single response."] - assert ( - adapter.get_stream_responses(stream_mock, stream_handler_mock) - == expected_responses - ) + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses stream_handler_mock.assert_has_calls( [ @@ -1615,15 +1449,10 @@ def test_get_stream_responses_empty(self) -> None: stream_mock.__iter__.return_value = [] - stream_handler_mock.side_effect = ( - lambda token_received, **kwargs: token_received - ) + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received adapter = MetaLlamaAdapter(model_kwargs={}, max_length=99) expected_responses = [""] - assert ( - adapter.get_stream_responses(stream_mock, stream_handler_mock) - == expected_responses - ) + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses stream_handler_mock.assert_not_called() diff --git a/integrations/amazon_bedrock/tests/test_text_embedder.py b/integrations/amazon_bedrock/tests/test_text_embedder.py index 150e691a5..4f4e92448 100644 --- a/integrations/amazon_bedrock/tests/test_text_embedder.py +++ b/integrations/amazon_bedrock/tests/test_text_embedder.py @@ -8,9 +8,7 @@ AmazonBedrockConfigurationError, AmazonBedrockInferenceError, ) -from haystack_integrations.components.embedders.amazon_bedrock import ( - AmazonBedrockTextEmbedder, -) +from haystack_integrations.components.embedders.amazon_bedrock import AmazonBedrockTextEmbedder class TestAmazonBedrockTextEmbedder: @@ -45,6 +43,7 @@ def test_connection_error(self, mock_boto3_session): ) def test_to_dict(self, mock_boto3_session): + embedder = AmazonBedrockTextEmbedder( model="cohere.embed-english-v3", input_type="search_query", @@ -53,31 +52,11 @@ def test_to_dict(self, mock_boto3_session): 
expected_dict = { "type": "haystack_integrations.components.embedders.amazon_bedrock.text_embedder.AmazonBedrockTextEmbedder", "init_parameters": { - "aws_access_key_id": { - "type": "env_var", - "env_vars": ["AWS_ACCESS_KEY_ID"], - "strict": False, - }, - "aws_secret_access_key": { - "type": "env_var", - "env_vars": ["AWS_SECRET_ACCESS_KEY"], - "strict": False, - }, - "aws_session_token": { - "type": "env_var", - "env_vars": ["AWS_SESSION_TOKEN"], - "strict": False, - }, - "aws_region_name": { - "type": "env_var", - "env_vars": ["AWS_DEFAULT_REGION"], - "strict": False, - }, - "aws_profile_name": { - "type": "env_var", - "env_vars": ["AWS_PROFILE"], - "strict": False, - }, + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, "model": "cohere.embed-english-v3", "input_type": "search_query", }, @@ -86,34 +65,15 @@ def test_to_dict(self, mock_boto3_session): assert embedder.to_dict() == expected_dict def test_from_dict(self, mock_boto3_session): + data = { "type": "haystack_integrations.components.embedders.amazon_bedrock.text_embedder.AmazonBedrockTextEmbedder", "init_parameters": { - "aws_access_key_id": { - "type": "env_var", - "env_vars": ["AWS_ACCESS_KEY_ID"], - "strict": False, - }, - "aws_secret_access_key": { - "type": "env_var", - "env_vars": ["AWS_SECRET_ACCESS_KEY"], - "strict": False, - }, - "aws_session_token": { - "type": "env_var", - "env_vars": ["AWS_SESSION_TOKEN"], - "strict": False, - }, - "aws_region_name": { - "type": "env_var", - "env_vars": ["AWS_DEFAULT_REGION"], - "strict": False, - }, - "aws_profile_name": { - "type": "env_var", - "env_vars": ["AWS_PROFILE"], - "strict": False, - }, + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, "model": "cohere.embed-english-v3", "input_type": "search_query", }, @@ -177,9 +137,7 @@ def test_run_invocation_error(self, mock_boto3_session): with patch.object(embedder._client, "invoke_model") as mock_invoke_model: mock_invoke_model.side_effect = ClientError( - error_response={ - "Error": {"Code": "some_code", "Message": "some_message"} - }, + error_response={"Error": {"Code": "some_code", "Message": "some_message"}}, operation_name="some_operation", ) diff --git a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/__init__.py b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/__init__.py index ed0251636..0fe45a8a1 100644 --- a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/__init__.py +++ b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/__init__.py @@ -1,8 +1,6 @@ # SPDX-FileCopyrightText: 2023-present deepset 
GmbH # # SPDX-License-Identifier: Apache-2.0 -from haystack_integrations.components.generators.amazon_sagemaker.sagemaker import ( - SagemakerGenerator, -) +from haystack_integrations.components.generators.amazon_sagemaker.sagemaker import SagemakerGenerator __all__ = ["SagemakerGenerator"] diff --git a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py index 842859cee..2a04d6a2a 100644 --- a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py +++ b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py @@ -50,21 +50,13 @@ class SagemakerGenerator: def __init__( self, model: str, - aws_access_key_id: Optional[Secret] = Secret.from_env_var( - ["AWS_ACCESS_KEY_ID"], strict=False - ), # noqa: B008 + aws_access_key_id: Optional[Secret] = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False), # noqa: B008 aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008 ["AWS_SECRET_ACCESS_KEY"], strict=False ), - aws_session_token: Optional[Secret] = Secret.from_env_var( - ["AWS_SESSION_TOKEN"], strict=False - ), # noqa: B008 - aws_region_name: Optional[Secret] = Secret.from_env_var( - ["AWS_DEFAULT_REGION"], strict=False - ), # noqa: B008 - aws_profile_name: Optional[Secret] = Secret.from_env_var( - ["AWS_PROFILE"], strict=False - ), # noqa: B008 + aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008 + aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008 + aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008 aws_custom_attributes: Optional[Dict[str, Any]] = None, generation_kwargs: Optional[Dict[str, Any]] = None, ): @@ -141,21 +133,11 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, model=self.model, - aws_access_key_id=self.aws_access_key_id.to_dict() - if self.aws_access_key_id - else None, - aws_secret_access_key=self.aws_secret_access_key.to_dict() - if self.aws_secret_access_key - else None, - aws_session_token=self.aws_session_token.to_dict() - if self.aws_session_token - else None, - aws_region_name=self.aws_region_name.to_dict() - if self.aws_region_name - else None, - aws_profile_name=self.aws_profile_name.to_dict() - if self.aws_profile_name - else None, + aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None, + aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None, + aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None, + aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None, + aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None, aws_custom_attributes=self.aws_custom_attributes, generation_kwargs=self.generation_kwargs, ) @@ -172,13 +154,7 @@ def from_dict(cls, data) -> "SagemakerGenerator": """ deserialize_secrets_inplace( data["init_parameters"], - [ - "aws_access_key_id", - "aws_secret_access_key", - "aws_session_token", - "aws_region_name", - "aws_profile_name", - ], + ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"], ) return default_from_dict(cls, data) @@ -213,9 +189,7 @@ def 
_get_aws_session( profile_name=aws_profile_name, ) except BotoCoreError as e: - msg = ( - f"Failed to initialize the session with provided AWS credentials: {e}." - ) + msg = f"Failed to initialize the session with provided AWS credentials: {e}." raise AWSConfigurationError(msg) from e @component.output_types(replies=List[str], meta=List[Dict[str, Any]]) @@ -235,8 +209,7 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): """ generation_kwargs = generation_kwargs or self.generation_kwargs custom_attributes = ";".join( - f"{k}={str(v).lower() if isinstance(v, bool) else str(v)}" - for k, v in self.aws_custom_attributes.items() + f"{k}={str(v).lower() if isinstance(v, bool) else str(v)}" for k, v in self.aws_custom_attributes.items() ) try: body = json.dumps({"inputs": prompt, "parameters": generation_kwargs}) diff --git a/integrations/amazon_sagemaker/tests/test_sagemaker.py b/integrations/amazon_sagemaker/tests/test_sagemaker.py index ad689f110..7e23bb7e7 100644 --- a/integrations/amazon_sagemaker/tests/test_sagemaker.py +++ b/integrations/amazon_sagemaker/tests/test_sagemaker.py @@ -4,12 +4,8 @@ import pytest from botocore.exceptions import BotoCoreError from haystack.utils.auth import EnvVarSecret -from haystack_integrations.components.generators.amazon_sagemaker import ( - SagemakerGenerator, -) -from haystack_integrations.components.generators.amazon_sagemaker.errors import ( - AWSConfigurationError, -) +from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator +from haystack_integrations.components.generators.amazon_sagemaker.errors import AWSConfigurationError def test_to_dict(set_env_variables, mock_boto3_session): # noqa: ARG001 @@ -21,31 +17,11 @@ def test_to_dict(set_env_variables, mock_boto3_session): # noqa: ARG001 "type": "haystack_integrations.components.generators.amazon_sagemaker.sagemaker.SagemakerGenerator", "init_parameters": { "model": "model", - "aws_access_key_id": { - "type": "env_var", - "env_vars": ["AWS_ACCESS_KEY_ID"], - "strict": False, - }, - "aws_secret_access_key": { - "type": "env_var", - "env_vars": ["AWS_SECRET_ACCESS_KEY"], - "strict": False, - }, - "aws_session_token": { - "type": "env_var", - "env_vars": ["AWS_SESSION_TOKEN"], - "strict": False, - }, - "aws_region_name": { - "type": "env_var", - "env_vars": ["AWS_DEFAULT_REGION"], - "strict": False, - }, - "aws_profile_name": { - "type": "env_var", - "env_vars": ["AWS_PROFILE"], - "strict": False, - }, + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, "aws_custom_attributes": {"accept_eula": True}, "generation_kwargs": {"max_new_tokens": 10}, }, @@ -68,31 +44,11 @@ def test_from_dict(set_env_variables, mock_boto3_session): # noqa: ARG001 "type": "haystack_integrations.components.generators.amazon_sagemaker.sagemaker.SagemakerGenerator", "init_parameters": { "model": "model", - "aws_access_key_id": { - "type": "env_var", - "env_vars": ["AWS_ACCESS_KEY_ID"], - "strict": False, - }, - "aws_secret_access_key": { - "type": "env_var", - "env_vars": ["AWS_SECRET_ACCESS_KEY"], - "strict": False, - }, - "aws_session_token": { - 
"type": "env_var", - "env_vars": ["AWS_SESSION_TOKEN"], - "strict": False, - }, - "aws_region_name": { - "type": "env_var", - "env_vars": ["AWS_DEFAULT_REGION"], - "strict": False, - }, - "aws_profile_name": { - "type": "env_var", - "env_vars": ["AWS_PROFILE"], - "strict": False, - }, + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, "aws_custom_attributes": {"accept_eula": True}, "generation_kwargs": {"max_new_tokens": 10}, }, @@ -136,14 +92,10 @@ def test_init_raises_boto_error(set_env_variables, mock_boto3_session): # noqa: SagemakerGenerator(model="test-model") -def test_run_with_list_of_dictionaries( - set_env_variables, mock_boto3_session -): # noqa: ARG001 +def test_run_with_list_of_dictionaries(set_env_variables, mock_boto3_session): # noqa: ARG001 client_mock = Mock() client_mock.invoke_endpoint.return_value = { - "Body": Mock( - read=lambda: b'[{"generated_text": "test-reply", "other": "metadata"}]' - ) + "Body": Mock(read=lambda: b'[{"generated_text": "test-reply", "other": "metadata"}]') } component = SagemakerGenerator(model="test-model") component.client = client_mock @@ -164,9 +116,7 @@ def test_run_with_list_of_dictionaries( assert response["meta"][0]["other"] == "metadata" -def test_run_with_single_dictionary( - set_env_variables, mock_boto3_session -): # noqa: ARG001 +def test_run_with_single_dictionary(set_env_variables, mock_boto3_session): # noqa: ARG001 client_mock = Mock() client_mock.invoke_endpoint.return_value = { "Body": Mock(read=lambda: b'{"generation": "test-reply", "other": "metadata"}') @@ -192,10 +142,7 @@ def test_run_with_single_dictionary( @pytest.mark.skipif( - ( - not os.environ.get("AWS_ACCESS_KEY_ID", None) - or not os.environ.get("AWS_SECRET_ACCESS_KEY", None) - ), + (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", ) @pytest.mark.integration @@ -208,9 +155,7 @@ def test_run_with_single_dictionary( ], ) def test_run(model: str): - component = SagemakerGenerator( - model=model, generation_kwargs={"max_new_tokens": 10} - ) + component = SagemakerGenerator(model=model, generation_kwargs={"max_new_tokens": 10}) response = component.run("What's Natural Language Processing?") # check that the component returns the correct ChatMessage response diff --git a/integrations/anthropic/example/documentation_rag_with_claude.py b/integrations/anthropic/example/documentation_rag_with_claude.py index 24bf3ccb1..eb7ec2ad0 100644 --- a/integrations/anthropic/example/documentation_rag_with_claude.py +++ b/integrations/anthropic/example/documentation_rag_with_claude.py @@ -11,12 +11,8 @@ from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator messages = [ - ChatMessage.from_system( - "You are a prompt expert who answers questions based on the given documents." 
- ), - ChatMessage.from_user( - "Here are the documents: {{documents}} \\n Answer: {{query}}" - ), + ChatMessage.from_system("You are a prompt expert who answers questions based on the given documents."), + ChatMessage.from_user("Here are the documents: {{documents}} \\n Answer: {{query}}"), ] rag_pipeline = Pipeline() @@ -39,12 +35,7 @@ question = "What are the best practices in prompt engineering?" rag_pipeline.run( data={ - "fetcher": { - "urls": ["https://docs.anthropic.com/claude/docs/prompt-engineering"] - }, - "prompt_builder": { - "template_variables": {"query": question}, - "template": messages, - }, + "fetcher": {"urls": ["https://docs.anthropic.com/claude/docs/prompt-engineering"]}, + "prompt_builder": {"template_variables": {"query": question}, "template": messages}, } ) diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py index 60c3540eb..06b3dc353 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py @@ -5,12 +5,7 @@ from haystack import component, default_from_dict, default_to_dict, logging from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk -from haystack.utils import ( - Secret, - deserialize_callable, - deserialize_secrets_inplace, - serialize_callable, -) +from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable from anthropic import Anthropic, Stream from anthropic.types import ( @@ -139,11 +134,7 @@ def to_dict(self) -> Dict[str, Any]: :returns: The serialized component as a dictionary. """ - callback_name = ( - serialize_callable(self.streaming_callback) - if self.streaming_callback - else None - ) + callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None return default_to_dict( self, model=self.model, @@ -166,17 +157,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "AnthropicChatGenerator": init_params = data.get("init_parameters", {}) serialized_callback_handler = init_params.get("streaming_callback") if serialized_callback_handler: - data["init_parameters"]["streaming_callback"] = deserialize_callable( - serialized_callback_handler - ) + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) @component.output_types(replies=List[ChatMessage]) - def run( - self, - messages: List[ChatMessage], - generation_kwargs: Optional[Dict[str, Any]] = None, - ): + def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): """ Invoke the text generation inference based on the provided messages and generation parameters. 
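Note on the run() hunks just above and below: after the reformat, the logic is unchanged — per-call generation_kwargs are merged over the constructor defaults, then the result is filtered against the adapter's ALLOWED_PARAMS, with a warning for anything dropped. A minimal sketch of that merge-and-filter pattern; the allow-list contents here are illustrative assumptions, not the integration's actual list:

    from typing import Dict, Optional

    # Sketch of the kwargs flow in run(); ALLOWED_PARAMS contents are assumed.
    ALLOWED_PARAMS = ["system", "max_tokens", "metadata", "stop_sequences", "temperature", "top_p", "top_k"]

    def merge_and_filter(defaults: Dict, per_call: Optional[Dict]) -> Dict:
        # Per-call kwargs override the defaults set in the constructor...
        merged = {**defaults, **(per_call or {})}
        # ...and anything outside the allow-list is dropped (the real code also logs a warning).
        return {k: v for k, v in merged.items() if k in ALLOWED_PARAMS}

    assert merge_and_filter({"temperature": 0.5}, {"top_p": 0.9, "bogus": 1}) == {"temperature": 0.5, "top_p": 0.9}
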
@@ -192,9 +177,7 @@ def run( # update generation kwargs by merging with the generation kwargs passed to the run method generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} - filtered_generation_kwargs = { - k: v for k, v in generation_kwargs.items() if k in self.ALLOWED_PARAMS - } + filtered_generation_kwargs = {k: v for k, v in generation_kwargs.items() if k in self.ALLOWED_PARAMS} disallowed_params = set(generation_kwargs) - set(self.ALLOWED_PARAMS) if disallowed_params: logger.warning( @@ -206,17 +189,11 @@ def run( anthropic_formatted_messages = self._convert_to_anthropic_format(messages) # system message provided by the user overrides the system message from the self.generation_kwargs - system = ( - messages[0].content - if messages and messages[0].is_from(ChatRole.SYSTEM) - else None - ) + system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None if system: anthropic_formatted_messages = anthropic_formatted_messages[1:] - response: Union[ - Message, Stream[MessageStreamEvent] - ] = self.client.messages.create( + response: Union[Message, Stream[MessageStreamEvent]] = self.client.messages.create( max_tokens=filtered_generation_kwargs.pop("max_tokens", 512), system=system if system else filtered_generation_kwargs.pop("system", ""), model=self.model, @@ -238,35 +215,21 @@ def run( chunk_delta: StreamingChunk = self._build_chunk(stream_event.delta) chunks.append(chunk_delta) if self.streaming_callback: - self.streaming_callback( - chunk_delta - ) # invoke callback with the chunk_delta + self.streaming_callback(chunk_delta) # invoke callback with the chunk_delta if isinstance(stream_event, MessageDeltaEvent): # capture stop reason and stop sequence delta = stream_event completions = [self._connect_chunks(chunks, start_event, delta)] # if streaming is disabled, the response is an Anthropic Message elif isinstance(response, Message): - has_tools_msgs = any( - isinstance(content_block, ToolUseBlock) - for content_block in response.content - ) + has_tools_msgs = any(isinstance(content_block, ToolUseBlock) for content_block in response.content) if has_tools_msgs and self.ignore_tools_thinking_messages: - response.content = [ - block - for block in response.content - if isinstance(block, ToolUseBlock) - ] - completions = [ - self._build_message(content_block, response) - for content_block in response.content - ] + response.content = [block for block in response.content if isinstance(block, ToolUseBlock)] + completions = [self._build_message(content_block, response) for content_block in response.content] return {"replies": completions} - def _build_message( - self, content_block: Union[TextBlock, ToolUseBlock], message: Message - ) -> ChatMessage: + def _build_message(self, content_block: Union[TextBlock, ToolUseBlock], message: Message) -> ChatMessage: """ Converts the non-streaming Anthropic Message to a ChatMessage. :param content_block: The content block of the message. 
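Note on the streaming branch above: each ContentBlockDeltaEvent becomes a StreamingChunk, chunks are handed to the user's callback as they arrive, and only at the end are they joined into a single assistant reply (the _connect_chunks hunks follow below). A minimal sketch of that accumulation, using only types this file already imports:

    from haystack.dataclasses import ChatMessage, StreamingChunk

    collected = [StreamingChunk(content="Hello"), StreamingChunk(content=", world!")]
    # Join the streamed deltas into the complete reply, as _connect_chunks does.
    reply = ChatMessage.from_assistant("".join(chunk.content for chunk in collected))
    assert reply.content == "Hello, world!"
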
@@ -276,9 +239,7 @@ def _build_message( if isinstance(content_block, TextBlock): chat_message = ChatMessage.from_assistant(content_block.text) else: - chat_message = ChatMessage.from_assistant( - json.dumps(content_block.model_dump(mode="json")) - ) + chat_message = ChatMessage.from_assistant(json.dumps(content_block.model_dump(mode="json"))) chat_message.meta.update( { "model": message.model, @@ -289,9 +250,7 @@ def _build_message( ) return chat_message - def _convert_to_anthropic_format( - self, messages: List[ChatMessage] - ) -> List[Dict[str, Any]]: + def _convert_to_anthropic_format(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]: """ Converts the list of ChatMessage to the list of messages in the format expected by the Anthropic API. :param messages: The list of ChatMessage. @@ -300,17 +259,12 @@ def _convert_to_anthropic_format( anthropic_formatted_messages = [] for m in messages: message_dict = dataclasses.asdict(m) - filtered_message = { - k: v for k, v in message_dict.items() if k in {"role", "content"} and v - } + filtered_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v} anthropic_formatted_messages.append(filtered_message) return anthropic_formatted_messages def _connect_chunks( - self, - chunks: List[StreamingChunk], - message_start: MessageStartEvent, - delta: MessageDeltaEvent, + self, chunks: List[StreamingChunk], message_start: MessageStartEvent, delta: MessageDeltaEvent ) -> ChatMessage: """ Connects the streaming chunks into a single ChatMessage. @@ -319,17 +273,13 @@ def _connect_chunks( :param delta: The MessageDeltaEvent. :returns: The complete ChatMessage. """ - complete_response = ChatMessage.from_assistant( - "".join([chunk.content for chunk in chunks]) - ) + complete_response = ChatMessage.from_assistant("".join([chunk.content for chunk in chunks])) complete_response.meta.update( { "model": self.model, "index": 0, "finish_reason": delta.delta.stop_reason if delta else "end_turn", - "usage": {**dict(message_start.message.usage, **dict(delta.usage))} - if delta and message_start - else {}, + "usage": {**dict(message_start.message.usage, **dict(delta.usage))} if delta and message_start else {}, } ) return complete_response diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/generator.py index e85612856..4cb8fd3e6 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/generator.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/generator.py @@ -2,12 +2,7 @@ from haystack import component, default_from_dict, default_to_dict, logging from haystack.dataclasses import StreamingChunk -from haystack.utils import ( - Secret, - deserialize_callable, - deserialize_secrets_inplace, - serialize_callable, -) +from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable from anthropic import Anthropic, Stream from anthropic.types import ( @@ -94,11 +89,7 @@ def to_dict(self) -> Dict[str, Any]: :returns: The serialized component as a dictionary. 
""" - callback_name = ( - serialize_callable(self.streaming_callback) - if self.streaming_callback - else None - ) + callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None return default_to_dict( self, model=self.model, @@ -121,9 +112,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "AnthropicGenerator": init_params = data.get("init_parameters", {}) serialized_callback_handler = init_params.get("streaming_callback") if serialized_callback_handler: - data["init_parameters"]["streaming_callback"] = deserialize_callable( - serialized_callback_handler - ) + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) @component.output_types(replies=List[str], meta=List[Dict[str, Any]]) @@ -139,9 +128,7 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): """ # update generation kwargs by merging with the generation kwargs passed to the run method generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} - filtered_generation_kwargs = { - k: v for k, v in generation_kwargs.items() if k in self.ALLOWED_PARAMS - } + filtered_generation_kwargs = {k: v for k, v in generation_kwargs.items() if k in self.ALLOWED_PARAMS} disallowed_params = set(generation_kwargs) - set(self.ALLOWED_PARAMS) if disallowed_params: logger.warning( @@ -149,13 +136,9 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): f"Allowed parameters are {self.ALLOWED_PARAMS}." ) - response: Union[ - Message, Stream[MessageStreamEvent] - ] = self.client.messages.create( + response: Union[Message, Stream[MessageStreamEvent]] = self.client.messages.create( max_tokens=filtered_generation_kwargs.pop("max_tokens", 512), - system=self.system_prompt - if self.system_prompt - else filtered_generation_kwargs.pop("system", ""), + system=self.system_prompt if self.system_prompt else filtered_generation_kwargs.pop("system", ""), model=self.model, messages=[MessageParam(content=prompt, role="user")], stream=self.streaming_callback is not None, @@ -173,14 +156,10 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): # capture start message to count input tokens start_event = stream_event if isinstance(stream_event, ContentBlockDeltaEvent): - chunk_delta: StreamingChunk = StreamingChunk( - content=stream_event.delta.text - ) + chunk_delta: StreamingChunk = StreamingChunk(content=stream_event.delta.text) chunks.append(chunk_delta) if self.streaming_callback: - self.streaming_callback( - chunk_delta - ) # invoke callback with the chunk_delta + self.streaming_callback(chunk_delta) # invoke callback with the chunk_delta if isinstance(stream_event, MessageDeltaEvent): # capture stop reason and stop sequence delta = stream_event @@ -190,9 +169,7 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): "model": self.model, "index": 0, "finish_reason": delta.delta.stop_reason if delta else "end_turn", - "usage": {**dict(start_event.message.usage, **dict(delta.usage))} - if delta and start_event - else {}, + "usage": {**dict(start_event.message.usage, **dict(delta.usage))} if delta and start_event else {}, } ) # if streaming is disabled, the response is an Anthropic Message diff --git a/integrations/anthropic/tests/conftest.py b/integrations/anthropic/tests/conftest.py index 43b2a7609..e70223143 100644 --- a/integrations/anthropic/tests/conftest.py +++ b/integrations/anthropic/tests/conftest.py @@ -9,9 +9,7 @@ def 
mock_chat_completion():
     """
     Mock the OpenAI API completion response and reuse it for tests
     """
-    with patch(
-        "anthropic.resources.messages.Messages.create"
-    ) as mock_chat_completion_create:
+    with patch("anthropic.resources.messages.Messages.create") as mock_chat_completion_create:
         completion = Message(
             id="foo",
             content=[{"type": "text", "text": "Hello, world!"}],
diff --git a/integrations/anthropic/tests/test_chat_generator.py b/integrations/anthropic/tests/test_chat_generator.py
index 9a41af1c2..3ffa24c94 100644
--- a/integrations/anthropic/tests/test_chat_generator.py
+++ b/integrations/anthropic/tests/test_chat_generator.py
@@ -13,9 +13,7 @@
 @pytest.fixture
 def chat_messages():
     return [
-        ChatMessage.from_system(
-            "\\nYou are a helpful assistant, be super brief in your responses."
-        ),
+        ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."),
         ChatMessage.from_user("What's the capital of France?"),
     ]
@@ -32,9 +30,7 @@ def test_init_default(self, monkeypatch):
     def test_init_fail_wo_api_key(self, monkeypatch):
         monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
-        with pytest.raises(
-            ValueError, match="None of the .* environment variables are set"
-        ):
+        with pytest.raises(ValueError, match="None of the .* environment variables are set"):
             AnthropicChatGenerator()
     def test_init_with_parameters(self):
@@ -48,10 +44,7 @@ def test_init_with_parameters(self):
         assert component.client.api_key == "test-api-key"
         assert component.model == "claude-3-5-sonnet-20240620"
         assert component.streaming_callback is print_streaming_chunk
-        assert component.generation_kwargs == {
-            "max_tokens": 10,
-            "some_test_param": "test-params",
-        }
+        assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
         assert component.ignore_tools_thinking_messages is False
     def test_to_dict_default(self, monkeypatch):
@@ -61,11 +54,7 @@ def test_to_dict_default(self, monkeypatch):
         assert data == {
             "type": "haystack_integrations.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator",
             "init_parameters": {
-                "api_key": {
-                    "env_vars": ["ANTHROPIC_API_KEY"],
-                    "strict": True,
-                    "type": "env_var",
-                },
+                "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "strict": True, "type": "env_var"},
                 "model": "claude-3-5-sonnet-20240620",
                 "streaming_callback": None,
                 "generation_kwargs": {},
@@ -87,10 +76,7 @@ def test_to_dict_with_parameters(self, monkeypatch):
                 "api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"},
                 "model": "claude-3-5-sonnet-20240620",
                 "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
-                "generation_kwargs": {
-                    "max_tokens": 10,
-                    "some_test_param": "test-params",
-                },
+                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
                 "ignore_tools_thinking_messages": True,
             },
         }
@@ -106,17 +92,10 @@ def test_to_dict_with_lambda_streaming_callback(self, monkeypatch):
         assert data == {
             "type": "haystack_integrations.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator",
             "init_parameters": {
-                "api_key": {
-                    "env_vars": ["ANTHROPIC_API_KEY"],
-                    "strict": True,
-                    "type": "env_var",
-                },
+                "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "strict": True, "type": "env_var"},
                 "model": "claude-3-5-sonnet-20240620",
                 "streaming_callback": "tests.test_chat_generator.<lambda>",
-                "generation_kwargs": {
-                    "max_tokens": 10,
-                    "some_test_param": "test-params",
-                },
+                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
                 "ignore_tools_thinking_messages": True,
             },
         }
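Note on the serialization tests above and below: to_dict() stores secrets as env-var references (never the key material) and callbacks by their import path, which is why the expected dictionaries collapse so cleanly onto single lines. A minimal round-trip sketch, assuming a dummy ANTHROPIC_API_KEY is set purely for illustration:

    import os
    from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator

    os.environ["ANTHROPIC_API_KEY"] = "dummy-key"  # assumed for the sketch; never serialized
    generator = AnthropicChatGenerator(generation_kwargs={"max_tokens": 10})
    data = generator.to_dict()
    # Only the env-var reference is stored, not the key itself.
    assert data["init_parameters"]["api_key"] == {"env_vars": ["ANTHROPIC_API_KEY"], "strict": True, "type": "env_var"}
    restored = AnthropicChatGenerator.from_dict(data)
    assert restored.generation_kwargs == {"max_tokens": 10}
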
True, }, } @@ -126,27 +105,17 @@ def test_from_dict(self, monkeypatch): data = { "type": "haystack_integrations.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator", "init_parameters": { - "api_key": { - "env_vars": ["ANTHROPIC_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "strict": True, "type": "env_var"}, "model": "claude-3-5-sonnet-20240620", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "generation_kwargs": { - "max_tokens": 10, - "some_test_param": "test-params", - }, + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, "ignore_tools_thinking_messages": True, }, } component = AnthropicChatGenerator.from_dict(data) assert component.model == "claude-3-5-sonnet-20240620" assert component.streaming_callback is print_streaming_chunk - assert component.generation_kwargs == { - "max_tokens": 10, - "some_test_param": "test-params", - } + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} assert component.api_key == Secret.from_env_var("ANTHROPIC_API_KEY") def test_from_dict_fail_wo_env_var(self, monkeypatch): @@ -154,23 +123,14 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): data = { "type": "haystack_integrations.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator", "init_parameters": { - "api_key": { - "env_vars": ["ANTHROPIC_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "strict": True, "type": "env_var"}, "model": "claude-3-5-sonnet-20240620", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "generation_kwargs": { - "max_tokens": 10, - "some_test_param": "test-params", - }, + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, "ignore_tools_thinking_messages": True, }, } - with pytest.raises( - ValueError, match="None of the .* environment variables are set" - ): + with pytest.raises(ValueError, match="None of the .* environment variables are set"): AnthropicChatGenerator.from_dict(data) def test_run(self, chat_messages, mock_chat_completion): @@ -186,8 +146,7 @@ def test_run(self, chat_messages, mock_chat_completion): def test_run_with_params(self, chat_messages, mock_chat_completion): component = AnthropicChatGenerator( - api_key=Secret.from_token("test-api-key"), - generation_kwargs={"max_tokens": 10, "temperature": 0.5}, + api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5} ) response = component.run(chat_messages) @@ -228,16 +187,10 @@ def test_default_inference_params(self, chat_messages): assert len(replies) > 0, "No replies received" first_reply = replies[0] - assert isinstance( - first_reply, ChatMessage - ), "First reply is not a ChatMessage instance" + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" assert first_reply.content, "First reply has no content" - assert ChatMessage.is_from( - first_reply, ChatRole.ASSISTANT - ), "First reply is not from the assistant" - assert ( - "paris" in first_reply.content.lower() - ), "First reply does not contain 'paris'" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" assert first_reply.meta, "First reply has no metadata" @pytest.mark.skipif( @@ -261,24 +214,16 @@ def 
streaming_callback(chunk: StreamingChunk): response = client.run(chat_messages) assert streaming_callback_called, "Streaming callback was not called" - assert ( - paris_found_in_response - ), "The streaming callback response did not contain 'paris'" + assert paris_found_in_response, "The streaming callback response did not contain 'paris'" replies = response["replies"] assert isinstance(replies, list), "Replies is not a list" assert len(replies) > 0, "No replies received" first_reply = replies[0] - assert isinstance( - first_reply, ChatMessage - ), "First reply is not a ChatMessage instance" + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" assert first_reply.content, "First reply has no content" - assert ChatMessage.is_from( - first_reply, ChatRole.ASSISTANT - ), "First reply is not from the assistant" - assert ( - "paris" in first_reply.content.lower() - ), "First reply does not contain 'paris'" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" assert first_reply.meta, "First reply has no metadata" @pytest.mark.skipif( @@ -294,10 +239,7 @@ def test_tools_use(self): "input_schema": { "type": "object", "properties": { - "ticker": { - "type": "string", - "description": "The stock ticker symbol, e.g. AAPL for Apple Inc.", - } + "ticker": {"type": "string", "description": "The stock ticker symbol, e.g. AAPL for Apple Inc."} }, "required": ["ticker"], }, @@ -312,16 +254,10 @@ def test_tools_use(self): assert len(replies) > 0, "No replies received" first_reply = replies[0] - assert isinstance( - first_reply, ChatMessage - ), "First reply is not a ChatMessage instance" + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" assert first_reply.content, "First reply has no content" - assert ChatMessage.is_from( - first_reply, ChatRole.ASSISTANT - ), "First reply is not from the assistant" - assert ( - "get_stock_price" in first_reply.content.lower() - ), "First reply does not contain get_stock_price" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert "get_stock_price" in first_reply.content.lower(), "First reply does not contain get_stock_price" assert first_reply.meta, "First reply has no metadata" fc_response = json.loads(first_reply.content) assert "name" in fc_response, "First reply does not contain name of the tool" diff --git a/integrations/anthropic/tests/test_generator.py b/integrations/anthropic/tests/test_generator.py index b55ef9ef0..029cd3920 100644 --- a/integrations/anthropic/tests/test_generator.py +++ b/integrations/anthropic/tests/test_generator.py @@ -20,9 +20,7 @@ def test_init_default(self, monkeypatch): def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) - with pytest.raises( - ValueError, match="None of the .* environment variables are set" - ): + with pytest.raises(ValueError, match="None of the .* environment variables are set"): AnthropicGenerator() def test_init_with_parameters(self): @@ -35,10 +33,7 @@ def test_init_with_parameters(self): assert component.client.api_key == "test-api-key" assert component.model == "claude-3-sonnet-20240229" assert component.streaming_callback is print_streaming_chunk - assert component.generation_kwargs == { - "max_tokens": 10, - "some_test_param": "test-params", - } + assert component.generation_kwargs == 
{"max_tokens": 10, "some_test_param": "test-params"} def test_to_dict_default(self, monkeypatch): monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") @@ -47,11 +42,7 @@ def test_to_dict_default(self, monkeypatch): assert data == { "type": "haystack_integrations.components.generators.anthropic.generator.AnthropicGenerator", "init_parameters": { - "api_key": { - "env_vars": ["ANTHROPIC_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "strict": True, "type": "env_var"}, "model": "claude-3-sonnet-20240229", "streaming_callback": None, "system_prompt": None, @@ -75,10 +66,7 @@ def test_to_dict_with_parameters(self, monkeypatch): "model": "claude-3-sonnet-20240229", "system_prompt": "test-prompt", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "generation_kwargs": { - "max_tokens": 10, - "some_test_param": "test-params", - }, + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, } @@ -93,18 +81,11 @@ def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): assert data == { "type": "haystack_integrations.components.generators.anthropic.generator.AnthropicGenerator", "init_parameters": { - "api_key": { - "env_vars": ["ANTHROPIC_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "strict": True, "type": "env_var"}, "model": "claude-3-sonnet-20240229", "streaming_callback": "tests.test_generator.", "system_prompt": None, - "generation_kwargs": { - "max_tokens": 10, - "some_test_param": "test-params", - }, + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, } @@ -113,28 +94,18 @@ def test_from_dict(self, monkeypatch): data = { "type": "haystack_integrations.components.generators.anthropic.generator.AnthropicGenerator", "init_parameters": { - "api_key": { - "env_vars": ["ANTHROPIC_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "strict": True, "type": "env_var"}, "model": "claude-3-sonnet-20240229", "system_prompt": "test-prompt", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "generation_kwargs": { - "max_tokens": 10, - "some_test_param": "test-params", - }, + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, } component = AnthropicGenerator.from_dict(data) assert component.model == "claude-3-sonnet-20240229" assert component.streaming_callback is print_streaming_chunk assert component.system_prompt == "test-prompt" - assert component.generation_kwargs == { - "max_tokens": 10, - "some_test_param": "test-params", - } + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} assert component.api_key == Secret.from_env_var("ANTHROPIC_API_KEY") def test_from_dict_fail_wo_env_var(self, monkeypatch): @@ -142,23 +113,14 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): data = { "type": "haystack_integrations.components.generators.anthropic.generator.AnthropicGenerator", "init_parameters": { - "api_key": { - "env_vars": ["ANTHROPIC_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "strict": True, "type": "env_var"}, "model": "claude-3-sonnet-20240229", "system_prompt": "test-prompt", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "generation_kwargs": { - "max_tokens": 10, - "some_test_param": "test-params", - }, + "generation_kwargs": 
{"max_tokens": 10, "some_test_param": "test-params"}, }, } - with pytest.raises( - ValueError, match="None of the .* environment variables are set" - ): + with pytest.raises(ValueError, match="None of the .* environment variables are set"): AnthropicGenerator.from_dict(data) def test_run(self, mock_chat_completion): @@ -178,8 +140,7 @@ def test_run(self, mock_chat_completion): def test_run_with_params(self, mock_chat_completion): component = AnthropicGenerator( - api_key=Secret.from_token("test-api-key"), - generation_kwargs={"max_tokens": 10, "temperature": 0.5}, + api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5} ) response = component.run("What is the capital of France?") @@ -255,9 +216,7 @@ def streaming_callback(chunk: StreamingChunk): response = client.run("What is the capital of France?") assert streaming_callback_called, "Streaming callback was not called" - assert ( - paris_found_in_response - ), "The streaming callback response did not contain 'paris'" + assert paris_found_in_response, "The streaming callback response did not contain 'paris'" replies = response["replies"] assert isinstance(replies, list), "Replies is not a list" assert len(replies) > 0, "No replies received" diff --git a/integrations/astra/examples/example.py b/integrations/astra/examples/example.py index 41832936c..00da589bb 100644 --- a/integrations/astra/examples/example.py +++ b/integrations/astra/examples/example.py @@ -4,10 +4,7 @@ from haystack import Pipeline from haystack.components.converters import TextFileToDocument -from haystack.components.embedders import ( - SentenceTransformersDocumentEmbedder, - SentenceTransformersTextEmbedder, -) +from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter from haystack.components.routers import FileTypeRouter from haystack.components.writers import DocumentWriter @@ -21,10 +18,7 @@ HERE = Path(__file__).resolve().parent -file_paths = [ - HERE / "data" / Path(name) - for name in os.listdir("integrations/astra/examples/data") -] +file_paths = [HERE / "data" / Path(name) for name in os.listdir("integrations/astra/examples/data")] logger.info(file_paths) collection_name = os.getenv("COLLECTION_NAME", "haystack_vector_search") @@ -41,26 +35,15 @@ # Create components and an indexing pipeline that converts txt files to documents, # cleans and splits them, and indexes them p = Pipeline() -p.add_component( - instance=FileTypeRouter(mime_types=["text/plain", "application/pdf"]), - name="file_type_router", -) +p.add_component(instance=FileTypeRouter(mime_types=["text/plain", "application/pdf"]), name="file_type_router") p.add_component(instance=TextFileToDocument(), name="text_file_converter") p.add_component(instance=DocumentCleaner(), name="cleaner") +p.add_component(instance=DocumentSplitter(split_by="word", split_length=150, split_overlap=30), name="splitter") p.add_component( - instance=DocumentSplitter(split_by="word", split_length=150, split_overlap=30), - name="splitter", -) -p.add_component( - instance=SentenceTransformersDocumentEmbedder( - model="sentence-transformers/all-MiniLM-L6-v2" - ), + instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="embedder", ) -p.add_component( - instance=DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP), - name="writer", -) 
+p.add_component(instance=DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP), name="writer") p.connect("file_type_router.text/plain", "text_file_converter.sources") p.connect("text_file_converter.documents", "cleaner.documents") @@ -73,9 +56,7 @@ # Create a querying pipeline on the indexed data q = Pipeline() q.add_component( - instance=SentenceTransformersTextEmbedder( - model="sentence-transformers/all-MiniLM-L6-v2" - ), + instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="embedder", ) q.add_component("retriever", AstraEmbeddingRetriever(document_store)) @@ -122,9 +103,7 @@ )}""" ) -document_store.delete_documents( - ["92ef055fbae55b2b0fc79d34cbf8a80b0ad7700ca526053223b0cc6d1351df10"] -) +document_store.delete_documents(["92ef055fbae55b2b0fc79d34cbf8a80b0ad7700ca526053223b0cc6d1351df10"]) documents_count = document_store.count_documents() logger.info(f"count: {document_store.count_documents()}") diff --git a/integrations/astra/examples/pipeline_example.py b/integrations/astra/examples/pipeline_example.py index 12e8071d1..826d85956 100644 --- a/integrations/astra/examples/pipeline_example.py +++ b/integrations/astra/examples/pipeline_example.py @@ -4,10 +4,7 @@ from haystack import Document, Pipeline from haystack.components.builders.answer_builder import AnswerBuilder from haystack.components.builders.prompt_builder import PromptBuilder -from haystack.components.embedders import ( - SentenceTransformersDocumentEmbedder, - SentenceTransformersTextEmbedder, -) +from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder from haystack.components.generators import OpenAIGenerator from haystack.components.writers import DocumentWriter from haystack.document_stores.types import DuplicatePolicy @@ -58,15 +55,10 @@ ] p = Pipeline() p.add_component( - instance=SentenceTransformersDocumentEmbedder( - model="sentence-transformers/all-MiniLM-L6-v2" - ), + instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="embedder", ) -p.add_component( - instance=DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP), - name="writer", -) +p.add_component(instance=DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP), name="writer") p.connect("embedder.documents", "writer.documents") p.run({"embedder": {"documents": documents}}) @@ -75,17 +67,11 @@ # Construct rag pipeline rag_pipeline = Pipeline() rag_pipeline.add_component( - instance=SentenceTransformersTextEmbedder( - model="sentence-transformers/all-MiniLM-L6-v2" - ), + instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"), name="embedder", ) -rag_pipeline.add_component( - instance=AstraEmbeddingRetriever(document_store=document_store), name="retriever" -) -rag_pipeline.add_component( - instance=PromptBuilder(template=prompt_template), name="prompt_builder" -) +rag_pipeline.add_component(instance=AstraEmbeddingRetriever(document_store=document_store), name="retriever") +rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder") rag_pipeline.add_component(instance=OpenAIGenerator(), name="llm") rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder") rag_pipeline.connect("embedder", "retriever") diff --git a/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py 
b/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py index cdf4f81ab..cfa45e81f 100644 --- a/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py +++ b/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py @@ -50,9 +50,7 @@ def __init__( self.top_k = top_k self.document_store = document_store self.filter_policy = ( - filter_policy - if isinstance(filter_policy, FilterPolicy) - else FilterPolicy.from_str(filter_policy) + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) ) if not isinstance(document_store, AstraDocumentStore): @@ -60,12 +58,7 @@ def __init__( raise Exception(message) @component.output_types(documents=List[Document]) - def run( - self, - query_embedding: List[float], - filters: Optional[Dict[str, Any]] = None, - top_k: Optional[int] = None, - ): + def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): """Retrieve documents from the AstraDocumentStore. :param query_embedding: floats representing the query embedding @@ -79,11 +72,7 @@ def run( filters = apply_filter_policy(self.filter_policy, self.filters, filters) top_k = top_k or self.top_k - return { - "documents": self.document_store.search( - query_embedding, top_k, filters=filters - ) - } + return {"documents": self.document_store.search(query_embedding, top_k, filters=filters)} def to_dict(self) -> Dict[str, Any]: """ @@ -110,14 +99,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "AstraEmbeddingRetriever": :returns: Deserialized component. """ - document_store = AstraDocumentStore.from_dict( - data["init_parameters"]["document_store"] - ) + document_store = AstraDocumentStore.from_dict(data["init_parameters"]["document_store"]) data["init_parameters"]["document_store"] = document_store # Pipelines serialized with old versions of the component might not # have the filter_policy field. 
if filter_policy := data["init_parameters"].get("filter_policy"): - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str( - filter_policy - ) + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py index 9de2b2d16..5a88a0fe9 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py @@ -85,17 +85,11 @@ def __init__( except APIRequestError: # possibly the collection is preexisting and has legacy # indexing settings: verify - get_coll_response = self._astra_db.get_collections( - options={"explain": True} - ) + get_coll_response = self._astra_db.get_collections(options={"explain": True}) collections = (get_coll_response["status"] or {}).get("collections") or [] - preexisting = [ - collection - for collection in collections - if collection["name"] == collection_name - ] + preexisting = [collection for collection in collections if collection["name"] == collection_name] if preexisting: pre_collection = preexisting[0] @@ -155,9 +149,7 @@ def __init__( def query( self, vector: Optional[List[float]] = None, - query_filter: Optional[ - Dict[str, Union[str, float, int, bool, List, dict]] - ] = None, + query_filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, top_k: Optional[int] = None, include_metadata: Optional[bool] = None, include_values: Optional[bool] = None, @@ -183,9 +175,7 @@ def query( # include_metadata means return all columns in the table (including text that got embedded) # include_values means return the vector of the embedding for the searched items - formatted_response = self._format_query_response( - responses, include_metadata, include_values - ) + formatted_response = self._format_query_response(responses, include_metadata, include_values) return formatted_response @@ -206,19 +196,14 @@ def _format_query_response(responses, include_metadata, include_values): score = response.pop("$similarity", None) text = response.pop("content", None) values = response.pop("$vector", None) if include_values else [] - metadata = ( - response if include_metadata else {} - ) # Add all remaining fields to the metadata + metadata = response if include_metadata else {} # Add all remaining fields to the metadata rsp = Response(_id, text, values, metadata, score) final_res.append(rsp) return QueryResponse(final_res) def _query(self, vector, top_k, filters=None): - query = { - "sort": {"$vector": vector}, - "options": {"limit": top_k, "includeSimilarity": True}, - } + query = {"sort": {"$vector": vector}, "options": {"limit": top_k, "includeSimilarity": True}} if filters is not None: query["filter"] = filters @@ -267,9 +252,7 @@ def batch_generator(chunks, batch_size): if docs: document_batch.extend(docs) - formatted_docs = self._format_query_response( - document_batch, include_metadata=True, include_values=True - ) + formatted_docs = self._format_query_response(document_batch, include_metadata=True, include_values=True) return formatted_docs @@ -312,14 +295,8 @@ def update_document(self, document: Dict, id_key: str): document[id_key] = document_id if "status" in response_dict and "errors" not in response_dict: - if ( - "matchedCount" in response_dict["status"] - and "modifiedCount" in 
response_dict["status"] - ): - if ( - response_dict["status"]["matchedCount"] == 1 - and response_dict["status"]["modifiedCount"] == 1 - ): + if "matchedCount" in response_dict["status"] and "modifiedCount" in response_dict["status"]: + if response_dict["status"]["matchedCount"] == 1 and response_dict["status"]["modifiedCount"] == 1: return True logger.warning(f"Documents {document_id} not updated in Astra DB.") diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py index 95fa7ffb1..1dea6e08b 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py @@ -49,9 +49,7 @@ class AstraDocumentStore: def __init__( self, - api_endpoint: Secret = Secret.from_env_var( - "ASTRA_DB_API_ENDPOINT" - ), # noqa: B008 + api_endpoint: Secret = Secret.from_env_var("ASTRA_DB_API_ENDPOINT"), # noqa: B008 token: Secret = Secret.from_env_var("ASTRA_DB_APPLICATION_TOKEN"), # noqa: B008 collection_name: str = "documents", embedding_dimension: int = 768, @@ -130,9 +128,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "AstraDocumentStore": :returns: Deserialized component. """ - deserialize_secrets_inplace( - data["init_parameters"], keys=["api_endpoint", "token"] - ) + deserialize_secrets_inplace(data["init_parameters"], keys=["api_endpoint", "token"]) return default_from_dict(cls, data) def to_dict(self) -> Dict[str, Any]: @@ -177,10 +173,7 @@ def write_documents( :raises Exception: if the document ID is not a string or if `id` and `_id` are both present in the document. """ if policy is None or policy == DuplicatePolicy.NONE: - if ( - self.duplicates_policy is not None - and self.duplicates_policy != DuplicatePolicy.NONE - ): + if self.duplicates_policy is not None and self.duplicates_policy != DuplicatePolicy.NONE: policy = self.duplicates_policy else: policy = DuplicatePolicy.SKIP @@ -252,9 +245,7 @@ def _convert_input_document(document: Union[dict, Document]): for batch in _batches(new_documents, batch_size): inserted_ids = self.index.insert(batch) # type: ignore insertion_counter += len(inserted_ids) - logger.info( - f"write_documents inserted documents with id {inserted_ids}" - ) + logger.info(f"write_documents inserted documents with id {inserted_ids}") else: logger.warning("No documents written. Argument policy set to SKIP") @@ -263,9 +254,7 @@ def _convert_input_document(document: Union[dict, Document]): for batch in _batches(new_documents, batch_size): inserted_ids = self.index.insert(batch) # type: ignore insertion_counter += len(inserted_ids) - logger.info( - f"write_documents inserted documents with id {inserted_ids}" - ) + logger.info(f"write_documents inserted documents with id {inserted_ids}") else: logger.warning("No documents written. Argument policy set to OVERWRITE") @@ -285,9 +274,7 @@ def _convert_input_document(document: Union[dict, Document]): for batch in _batches(new_documents, batch_size): inserted_ids = self.index.insert(batch) # type: ignore insertion_counter = insertion_counter + len(inserted_ids) - logger.info( - f"write_documents inserted documents with id {inserted_ids}" - ) + logger.info(f"write_documents inserted documents with id {inserted_ids}") else: logger.warning("No documents written. 
Argument policy set to FAIL") @@ -301,9 +288,7 @@ def count_documents(self) -> int: """ return self.index.count_documents() - def filter_documents( - self, filters: Optional[Dict[str, Any]] = None - ) -> List[Document]: + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: """ Returns at most 1000 documents that match the filter. @@ -340,11 +325,7 @@ def filter_documents( else: converted_filters = _convert_filters(filters) results = self.index.query( - vector=vector, - query_filter=converted_filters, - top_k=1000, - include_values=True, - include_metadata=True, + vector=vector, query_filter=converted_filters, top_k=1000, include_values=True, include_metadata=True ) documents = self._get_result_to_documents(results) return documents @@ -397,10 +378,7 @@ def get_document_by_id(self, document_id: str) -> Document: return ret[0] def search( - self, - query_embedding: List[float], - top_k: int, - filters: Optional[Dict[str, Any]] = None, + self, query_embedding: List[float], top_k: int, filters: Optional[Dict[str, Any]] = None ) -> List[Document]: """ Perform a search for a list of queries. @@ -425,11 +403,7 @@ def search( return result - def delete_documents( - self, - document_ids: Optional[List[str]] = None, - delete_all: Optional[bool] = None, - ) -> None: + def delete_documents(self, document_ids: Optional[List[str]] = None, delete_all: Optional[bool] = None) -> None: """ Deletes documents from the document store. diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py b/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py index 0fed13f14..61f3e5402 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py @@ -17,9 +17,7 @@ def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]: return _parse_logical_condition(filters) -def _convert_filters( - filters: Optional[Dict[str, Any]] = None -) -> Optional[Dict[str, Any]]: +def _convert_filters(filters: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]: """ Convert haystack filters to astra filter string capturing all boolean operators """ @@ -121,9 +119,7 @@ def _normalize_ranges(conditions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: ] ``` """ - range_conditions = [ - next(iter(c["range"].items())) for c in conditions if "range" in c - ] + range_conditions = [next(iter(c["range"].items())) for c in conditions if "range" in c] if range_conditions: conditions = [c for c in conditions if "range" not in c] range_conditions_dict: Dict[str, Any] = {} diff --git a/integrations/astra/tests/test_document_store.py b/integrations/astra/tests/test_document_store.py index a7449e49b..df181ad8c 100644 --- a/integrations/astra/tests/test_document_store.py +++ b/integrations/astra/tests/test_document_store.py @@ -27,9 +27,7 @@ def test_init_is_lazy(_mock_client, mock_auth): # noqa def test_namespace_init(mock_auth): # noqa - with mock.patch( - "haystack_integrations.document_stores.astra.astra_client.AstraDB" - ) as client: + with mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB") as client: _ = AstraDocumentStore().index assert "namespace" in client.call_args.kwargs assert client.call_args.kwargs["namespace"] is None @@ -43,10 +41,7 @@ def test_to_dict(mock_auth): # noqa with mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB"): ds = AstraDocumentStore() result = 
ds.to_dict() - assert ( - result["type"] - == "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore" - ) + assert result["type"] == "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore" assert set(result["init_parameters"]) == { "api_endpoint", "token", @@ -60,13 +55,9 @@ def test_to_dict(mock_auth): # noqa @pytest.mark.integration @pytest.mark.skipif( - os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", - reason="ASTRA_DB_APPLICATION_TOKEN env var not set", -) -@pytest.mark.skipif( - os.environ.get("ASTRA_DB_API_ENDPOINT", "") == "", - reason="ASTRA_DB_API_ENDPOINT env var not set", + os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN env var not set" ) +@pytest.mark.skipif(os.environ.get("ASTRA_DB_API_ENDPOINT", "") == "", reason="ASTRA_DB_API_ENDPOINT env var not set") class TestDocumentStore(DocumentStoreBaseTests): """ Common test cases will be provided by `DocumentStoreBaseTests` but @@ -89,9 +80,7 @@ def run_before_and_after_tests(self, document_store: AstraDocumentStore): document_store.delete_documents(delete_all=True) assert document_store.count_documents() == 0 - def assert_documents_are_equal( - self, received: List[Document], expected: List[Document] - ): + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): """ Assert that two lists of Documents are equal. This is used in every test, if a Document Store implementation has a different behaviour @@ -108,9 +97,7 @@ def assert_documents_are_equal( def test_comparison_equal_with_none(self, document_store, filterable_docs): document_store.write_documents(filterable_docs) - result = document_store.filter_documents( - filters={"field": "meta.number", "operator": "==", "value": None} - ) + result = document_store.filter_documents(filters={"field": "meta.number", "operator": "==", "value": None}) # Astra does not support filtering on None, it returns empty list self.assert_documents_are_equal(result, []) @@ -122,17 +109,9 @@ def test_write_documents(self, document_store: AstraDocumentStore): doc1 = Document(id="1", content="test doc 1") doc2 = Document(id="1", content="test doc 2") - assert ( - document_store.write_documents([doc2], policy=DuplicatePolicy.OVERWRITE) - == 1 - ) + assert document_store.write_documents([doc2], policy=DuplicatePolicy.OVERWRITE) == 1 self.assert_documents_are_equal(document_store.filter_documents(), [doc2]) - assert ( - document_store.write_documents( - documents=[doc1], policy=DuplicatePolicy.OVERWRITE - ) - == 1 - ) + assert document_store.write_documents(documents=[doc1], policy=DuplicatePolicy.OVERWRITE) == 1 self.assert_documents_are_equal(document_store.filter_documents(), [doc1]) def test_write_documents_skip_duplicates(self, document_store: AstraDocumentStore): @@ -142,9 +121,7 @@ def test_write_documents_skip_duplicates(self, document_store: AstraDocumentStor ] assert document_store.write_documents(docs, policy=DuplicatePolicy.SKIP) == 1 - def test_delete_documents_non_existing_document( - self, document_store: AstraDocumentStore - ): + def test_delete_documents_non_existing_document(self, document_store: AstraDocumentStore): """ Test delete_documents() doesn't delete any Document when called with non existing id. 
""" @@ -158,9 +135,7 @@ def test_delete_documents_non_existing_document( # No Document has been deleted assert document_store.count_documents() == 1 - def test_delete_documents_more_than_twenty_delete_all( - self, document_store: AstraDocumentStore - ): + def test_delete_documents_more_than_twenty_delete_all(self, document_store: AstraDocumentStore): """ Test delete_documents() deletes all documents when called on an Astra DB with more than 20 documents. Twenty documents is the maximum number of deleted @@ -177,9 +152,7 @@ def test_delete_documents_more_than_twenty_delete_all( assert document_store.count_documents() == 0 - def test_delete_documents_more_than_twenty_delete_ids( - self, document_store: AstraDocumentStore - ): + def test_delete_documents_more_than_twenty_delete_ids(self, document_store: AstraDocumentStore): """ Test delete_documents() deletes all documents when called on an Astra DB with more than 20 documents. Twenty documents is the maximum number of deleted @@ -207,11 +180,7 @@ def test_filter_documents_nested_filters(self, document_store, filterable_docs): { "operator": "OR", "conditions": [ - { - "field": "meta.chapter", - "operator": "==", - "value": "abstract", - }, + {"field": "meta.chapter", "operator": "==", "value": "abstract"}, {"field": "meta.chapter", "operator": "==", "value": "intro"}, ], }, @@ -227,10 +196,7 @@ def test_filter_documents_nested_filters(self, document_store, filterable_docs): d for d in filterable_docs if d.meta.get("page") == "100" - and ( - d.meta.get("chapter") == "abstract" - or d.meta.get("chapter") == "intro" - ) + and (d.meta.get("chapter") == "abstract" or d.meta.get("chapter") == "intro") ], ) @@ -255,21 +221,15 @@ def test_comparison_not_in(self, document_store, filterable_docs): pass @pytest.mark.skip(reason="Unsupported filter operator $nin.") - def test_comparison_not_in_with_with_non_list( - self, document_store, filterable_docs - ): + def test_comparison_not_in_with_with_non_list(self, document_store, filterable_docs): pass @pytest.mark.skip(reason="Unsupported filter operator $nin.") - def test_comparison_not_in_with_with_non_list_iterable( - self, document_store, filterable_docs - ): + def test_comparison_not_in_with_with_non_list_iterable(self, document_store, filterable_docs): pass @pytest.mark.skip(reason="Unsupported filter operator $gt.") - def test_comparison_greater_than_with_iso_date( - self, document_store, filterable_docs - ): + def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs): pass @pytest.mark.skip(reason="Unsupported filter operator $gt.") @@ -277,9 +237,7 @@ def test_comparison_greater_than_with_string(self, document_store, filterable_do pass @pytest.mark.skip(reason="Unsupported filter operator $gt.") - def test_comparison_greater_than_with_dataframe( - self, document_store, filterable_docs - ): + def test_comparison_greater_than_with_dataframe(self, document_store, filterable_docs): pass @pytest.mark.skip(reason="Unsupported filter operator $gt.") @@ -299,33 +257,23 @@ def test_comparison_greater_than_equal(self, document_store, filterable_docs): pass @pytest.mark.skip(reason="Unsupported filter operator $gte.") - def test_comparison_greater_than_equal_with_none( - self, document_store, filterable_docs - ): + def test_comparison_greater_than_equal_with_none(self, document_store, filterable_docs): pass @pytest.mark.skip(reason="Unsupported filter operator $gte.") - def test_comparison_greater_than_equal_with_list( - self, document_store, filterable_docs - ): + def 
test_comparison_greater_than_equal_with_list(self, document_store, filterable_docs): pass @pytest.mark.skip(reason="Unsupported filter operator $gte.") - def test_comparison_greater_than_equal_with_dataframe( - self, document_store, filterable_docs - ): + def test_comparison_greater_than_equal_with_dataframe(self, document_store, filterable_docs): pass @pytest.mark.skip(reason="Unsupported filter operator $gte.") - def test_comparison_greater_than_equal_with_string( - self, document_store, filterable_docs - ): + def test_comparison_greater_than_equal_with_string(self, document_store, filterable_docs): pass @pytest.mark.skip(reason="Unsupported filter operator $gte.") - def test_comparison_greater_than_equal_with_iso_date( - self, document_store, filterable_docs - ): + def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs): pass @pytest.mark.skip(reason="Unsupported filter operator $lte.") @@ -333,33 +281,23 @@ def test_comparison_less_than_equal(self, document_store, filterable_docs): pass @pytest.mark.skip(reason="Unsupported filter operator $lte.") - def test_comparison_less_than_equal_with_string( - self, document_store, filterable_docs - ): + def test_comparison_less_than_equal_with_string(self, document_store, filterable_docs): pass @pytest.mark.skip(reason="Unsupported filter operator $lte.") - def test_comparison_less_than_equal_with_dataframe( - self, document_store, filterable_docs - ): + def test_comparison_less_than_equal_with_dataframe(self, document_store, filterable_docs): pass @pytest.mark.skip(reason="Unsupported filter operator $lte.") - def test_comparison_less_than_equal_with_list( - self, document_store, filterable_docs - ): + def test_comparison_less_than_equal_with_list(self, document_store, filterable_docs): pass @pytest.mark.skip(reason="Unsupported filter operator $lte.") - def test_comparison_less_than_equal_with_iso_date( - self, document_store, filterable_docs - ): + def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs): pass @pytest.mark.skip(reason="Unsupported filter operator $lte.") - def test_comparison_less_than_equal_with_none( - self, document_store, filterable_docs - ): + def test_comparison_less_than_equal_with_none(self, document_store, filterable_docs): pass @pytest.mark.skip(reason="Unsupported filter operator $lt.") diff --git a/integrations/astra/tests/test_retriever.py b/integrations/astra/tests/test_retriever.py index b2ed04294..4ffe30919 100644 --- a/integrations/astra/tests/test_retriever.py +++ b/integrations/astra/tests/test_retriever.py @@ -12,44 +12,30 @@ @patch.dict( "os.environ", - { - "ASTRA_DB_APPLICATION_TOKEN": "fake-token", - "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com", - }, + {"ASTRA_DB_APPLICATION_TOKEN": "fake-token", "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"}, ) @patch("haystack_integrations.document_stores.astra.document_store.AstraClient") def test_retriever_init(*_): ds = AstraDocumentStore() - retriever = AstraEmbeddingRetriever( - ds, filters={"foo": "bar"}, top_k=99, filter_policy="replace" - ) + retriever = AstraEmbeddingRetriever(ds, filters={"foo": "bar"}, top_k=99, filter_policy="replace") assert retriever.filters == {"foo": "bar"} assert retriever.top_k == 99 assert retriever.document_store == ds assert retriever.filter_policy == FilterPolicy.REPLACE - retriever = AstraEmbeddingRetriever( - ds, filters={"foo": "bar"}, top_k=99, filter_policy=FilterPolicy.MERGE - ) + retriever = 
AstraEmbeddingRetriever(ds, filters={"foo": "bar"}, top_k=99, filter_policy=FilterPolicy.MERGE) assert retriever.filter_policy == FilterPolicy.MERGE with pytest.raises(ValueError): - AstraEmbeddingRetriever( - ds, filters={"foo": "bar"}, top_k=99, filter_policy="unknown" - ) + AstraEmbeddingRetriever(ds, filters={"foo": "bar"}, top_k=99, filter_policy="unknown") with pytest.raises(ValueError): - AstraEmbeddingRetriever( - ds, filters={"foo": "bar"}, top_k=99, filter_policy=None - ) + AstraEmbeddingRetriever(ds, filters={"foo": "bar"}, top_k=99, filter_policy=None) @patch.dict( "os.environ", - { - "ASTRA_DB_APPLICATION_TOKEN": "fake-token", - "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com", - }, + {"ASTRA_DB_APPLICATION_TOKEN": "fake-token", "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"}, ) @patch("haystack_integrations.document_stores.astra.document_store.AstraClient") def test_retriever_to_json(*_): @@ -65,16 +51,8 @@ def test_retriever_to_json(*_): "document_store": { "type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore", "init_parameters": { - "api_endpoint": { - "type": "env_var", - "env_vars": ["ASTRA_DB_API_ENDPOINT"], - "strict": True, - }, - "token": { - "type": "env_var", - "env_vars": ["ASTRA_DB_APPLICATION_TOKEN"], - "strict": True, - }, + "api_endpoint": {"type": "env_var", "env_vars": ["ASTRA_DB_API_ENDPOINT"], "strict": True}, + "token": {"type": "env_var", "env_vars": ["ASTRA_DB_APPLICATION_TOKEN"], "strict": True}, "collection_name": "documents", "embedding_dimension": 768, "duplicates_policy": "NONE", @@ -88,10 +66,7 @@ def test_retriever_to_json(*_): @patch.dict( "os.environ", - { - "ASTRA_DB_APPLICATION_TOKEN": "fake-token", - "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com", - }, + {"ASTRA_DB_APPLICATION_TOKEN": "fake-token", "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"}, ) @patch("haystack_integrations.document_stores.astra.document_store.AstraClient") def test_retriever_from_json(*_): @@ -104,16 +79,8 @@ def test_retriever_from_json(*_): "document_store": { "type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore", "init_parameters": { - "api_endpoint": { - "type": "env_var", - "env_vars": ["ASTRA_DB_API_ENDPOINT"], - "strict": True, - }, - "token": { - "type": "env_var", - "env_vars": ["ASTRA_DB_APPLICATION_TOKEN"], - "strict": True, - }, + "api_endpoint": {"type": "env_var", "env_vars": ["ASTRA_DB_API_ENDPOINT"], "strict": True}, + "token": {"type": "env_var", "env_vars": ["ASTRA_DB_APPLICATION_TOKEN"], "strict": True}, "collection_name": "documents", "embedding_dimension": 768, "duplicates_policy": "NONE", @@ -129,10 +96,7 @@ def test_retriever_from_json(*_): @patch.dict( "os.environ", - { - "ASTRA_DB_APPLICATION_TOKEN": "fake-token", - "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com", - }, + {"ASTRA_DB_APPLICATION_TOKEN": "fake-token", "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"}, ) @patch("haystack_integrations.document_stores.astra.document_store.AstraClient") def test_retriever_from_json_no_filter_policy(*_): @@ -144,16 +108,8 @@ def test_retriever_from_json_no_filter_policy(*_): "document_store": { "type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore", "init_parameters": { - "api_endpoint": { - "type": "env_var", - "env_vars": ["ASTRA_DB_API_ENDPOINT"], - "strict": True, - }, - "token": { - "type": "env_var", - "env_vars": 
["ASTRA_DB_APPLICATION_TOKEN"], - "strict": True, - }, + "api_endpoint": {"type": "env_var", "env_vars": ["ASTRA_DB_API_ENDPOINT"], "strict": True}, + "token": {"type": "env_var", "env_vars": ["ASTRA_DB_APPLICATION_TOKEN"], "strict": True}, "collection_name": "documents", "embedding_dimension": 768, "duplicates_policy": "NONE", diff --git a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py index 777321ff8..71ac3457e 100644 --- a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py +++ b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py @@ -60,9 +60,7 @@ def __init__( self.top_k = top_k self.document_store = document_store self.filter_policy = ( - filter_policy - if isinstance(filter_policy, FilterPolicy) - else FilterPolicy.from_str(filter_policy) + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) ) @component.output_types(documents=List[Document]) @@ -100,16 +98,12 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryTextRetriever": :returns: Deserialized component. """ - document_store = ChromaDocumentStore.from_dict( - data["init_parameters"]["document_store"] - ) + document_store = ChromaDocumentStore.from_dict(data["init_parameters"]["document_store"]) data["init_parameters"]["document_store"] = document_store # Pipelines serialized with old versions of the component might not # have the filter_policy field. if filter_policy := data["init_parameters"].get("filter_policy"): - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str( - filter_policy - ) + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @@ -160,8 +154,4 @@ def run( top_k = top_k or self.top_k query_embeddings = [query_embedding] - return { - "documents": self.document_store.search_embeddings( - query_embeddings, top_k, filters - )[0] - } + return {"documents": self.document_store.search_embeddings(query_embeddings, top_k, filters)[0]} diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index 5324921c5..3ea84780f 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -7,12 +7,7 @@ import chromadb import numpy as np -from chromadb.api.types import ( - GetResult, - QueryResult, - validate_where, - validate_where_document, -) +from chromadb.api.types import GetResult, QueryResult, validate_where, validate_where_document from haystack import default_from_dict, default_to_dict from haystack.dataclasses import Document from haystack.document_stores.types import DuplicatePolicy @@ -90,18 +85,14 @@ def __init__( else: self._chroma_client = chromadb.PersistentClient(path=persist_path) - embedding_func = get_embedding_function( - embedding_function, **embedding_function_params - ) + embedding_func = get_embedding_function(embedding_function, **embedding_function_params) metadata = metadata or {} if "hnsw:space" not in metadata: metadata["hnsw:space"] = distance_function if collection_name in [c.name for c in self._chroma_client.list_collections()]: - self._collection = self._chroma_client.get_collection( - collection_name, 
embedding_function=embedding_func - ) + self._collection = self._chroma_client.get_collection(collection_name, embedding_function=embedding_func) if metadata != self._collection.metadata: logger.warning( @@ -122,9 +113,7 @@ def count_documents(self) -> int: """ return self._collection.count() - def filter_documents( - self, filters: Optional[Dict[str, Any]] = None - ) -> List[Document]: + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: """ Returns the documents that match the filters provided. @@ -210,9 +199,7 @@ def filter_documents( return self._get_result_to_documents(result) - def write_documents( - self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL - ) -> int: + def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> int: """ Writes (or overwrites) documents into the store. @@ -229,9 +216,7 @@ def write_documents( """ for doc in documents: if not isinstance(doc, Document): - msg = ( - "param 'documents' must contain a list of objects of type Document" - ) + msg = "param 'documents' must contain a list of objects of type Document" raise ValueError(msg) if doc.content is None: @@ -257,9 +242,7 @@ def write_documents( "These items will be discarded. Supported types are: %s.", doc.id, ", ".join(discarded_keys), - ", ".join( - [t.__name__ for t in SUPPORTED_TYPES_FOR_METADATA_VALUES] - ), + ", ".join([t.__name__ for t in SUPPORTED_TYPES_FOR_METADATA_VALUES]), ) if valid_meta: @@ -288,9 +271,7 @@ def delete_documents(self, document_ids: List[str]) -> None: """ self._collection.delete(ids=document_ids) - def search( - self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any]] = None - ) -> List[List[Document]]: + def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any]] = None) -> List[List[Document]]: """Search the documents in the store using the provided text queries. :param queries: the list of queries to search for. @@ -317,10 +298,7 @@ def search( return self._query_result_to_documents(results) def search_embeddings( - self, - query_embeddings: List[List[float]], - top_k: int, - filters: Optional[Dict[str, Any]] = None, + self, query_embeddings: List[List[float]], top_k: int, filters: Optional[Dict[str, Any]] = None ) -> List[List[Document]]: """ Perform vector search on the stored document, pass the embeddings of the queries instead of their text. @@ -379,9 +357,7 @@ def to_dict(self) -> Dict[str, Any]: ) @staticmethod - def _normalize_filters( - filters: Dict[str, Any] - ) -> Tuple[List[str], Dict[str, Any], Dict[str, Any]]: + def _normalize_filters(filters: Dict[str, Any]) -> Tuple[List[str], Dict[str, Any], Dict[str, Any]]: """ Translate Haystack filters to Chroma filters. It returns three dictionaries, to be passed to `ids`, `where` and `where_document` respectively. diff --git a/integrations/chroma/tests/test_document_store.py b/integrations/chroma/tests/test_document_store.py index 2e353520d..b05c9ccfc 100644 --- a/integrations/chroma/tests/test_document_store.py +++ b/integrations/chroma/tests/test_document_store.py @@ -26,16 +26,12 @@ class _TestEmbeddingFunction(EmbeddingFunction): vectors in unit tests. 
""" - def __call__( - self, input: Documents - ) -> Embeddings: # noqa - chroma will inspect the signature, it must match + def __call__(self, input: Documents) -> Embeddings: # noqa - chroma will inspect the signature, it must match # embed the documents somehow return [np.random.default_rng().uniform(-1, 1, 768).tolist()] -class TestDocumentStore( - CountDocumentsTest, DeleteDocumentsTest, LegacyFilterDocumentsTest -): +class TestDocumentStore(CountDocumentsTest, DeleteDocumentsTest, LegacyFilterDocumentsTest): """ Common test cases will be provided by `DocumentStoreBaseTests` but you can add more to this class. @@ -51,13 +47,9 @@ def document_store(self) -> ChromaDocumentStore: "haystack_integrations.document_stores.chroma.document_store.get_embedding_function" ) as get_func: get_func.return_value = _TestEmbeddingFunction() - return ChromaDocumentStore( - embedding_function="test_function", collection_name=str(uuid.uuid1()) - ) + return ChromaDocumentStore(embedding_function="test_function", collection_name=str(uuid.uuid1())) - def assert_documents_are_equal( - self, received: List[Document], expected: List[Document] - ): + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): """ Assert that two lists of Documents are equal. This is used in every test, if a Document Store implementation has a different behaviour @@ -73,9 +65,7 @@ def assert_documents_are_equal( assert doc_received.content == doc_expected.content assert doc_received.meta == doc_expected.meta - def test_ne_filter( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): + def test_ne_filter(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): """ We customize this test because Chroma consider "not equal" true when a field is missing @@ -83,8 +73,7 @@ def test_ne_filter( document_store.write_documents(filterable_docs) result = document_store.filter_documents(filters={"page": {"$ne": "100"}}) self.assert_documents_are_equal( - result, - [doc for doc in filterable_docs if doc.meta.get("page", "100") != "100"], + result, [doc for doc in filterable_docs if doc.meta.get("page", "100") != "100"] ) def test_delete_empty(self, document_store: ChromaDocumentStore): @@ -118,9 +107,7 @@ def test_search(self): assert len(result) == 1 assert result[0][0].content == "Third document" - def test_write_documents_unsupported_meta_values( - self, document_store: ChromaDocumentStore - ): + def test_write_documents_unsupported_meta_values(self, document_store: ChromaDocumentStore): """ Unsupported meta values should be removed from the documents before writing them to the database """ @@ -145,9 +132,7 @@ def test_write_documents_unsupported_meta_values( @pytest.mark.integration def test_to_json(self, request): ds = ChromaDocumentStore( - collection_name=request.node.name, - embedding_function="HuggingFaceEmbeddingFunction", - api_key="1234567890", + collection_name=request.node.name, embedding_function="HuggingFaceEmbeddingFunction", api_key="1234567890" ) ds_dict = ds.to_dict() assert ds_dict == { @@ -244,69 +229,43 @@ def test_metadata_initialization(self, caplog): assert new_store._collection.metadata["hnsw:space"] == "ip" @pytest.mark.skip(reason="Filter on dataframe contents is not supported.") - def test_filter_document_dataframe( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): + def test_filter_document_dataframe(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass 
@pytest.mark.skip(reason="Filter on table contents is not supported.") - def test_eq_filter_table( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): + def test_eq_filter_table(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter on embedding value is not supported.") - def test_eq_filter_embedding( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): + def test_eq_filter_embedding(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass - @pytest.mark.skip( - reason="$in operator is not supported. Filter on table contents is not supported." - ) - def test_in_filter_table( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): + @pytest.mark.skip(reason="$in operator is not supported. Filter on table contents is not supported.") + def test_in_filter_table(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="$in operator is not supported.") - def test_in_filter_embedding( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): + def test_in_filter_embedding(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter on table contents is not supported.") - def test_ne_filter_table( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): + def test_ne_filter_table(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter on embedding value is not supported.") - def test_ne_filter_embedding( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): + def test_ne_filter_embedding(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass - @pytest.mark.skip( - reason="$nin operator is not supported. Filter on table contents is not supported." - ) - def test_nin_filter_table( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): + @pytest.mark.skip(reason="$nin operator is not supported. Filter on table contents is not supported.") + def test_nin_filter_table(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass - @pytest.mark.skip( - reason="$nin operator is not supported. Filter on embedding value is not supported." - ) - def test_nin_filter_embedding( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): + @pytest.mark.skip(reason="$nin operator is not supported. 
Filter on embedding value is not supported.") + def test_nin_filter_embedding(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="$nin operator is not supported.") - def test_nin_filter( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): + def test_nin_filter(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter syntax not supported.") @@ -322,45 +281,31 @@ def test_filter_simple_explicit_and_with_list( pass @pytest.mark.skip(reason="Filter syntax not supported.") - def test_filter_simple_implicit_and( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): + def test_filter_simple_implicit_and(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter syntax not supported.") - def test_filter_nested_implicit_and( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): + def test_filter_nested_implicit_and(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter syntax not supported.") - def test_filter_simple_or( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): + def test_filter_simple_or(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter syntax not supported.") - def test_filter_nested_or( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): + def test_filter_nested_or(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter on table contents is not supported.") - def test_filter_nested_and_or_explicit( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): + def test_filter_nested_and_or_explicit(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter syntax not supported.") - def test_filter_nested_and_or_implicit( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): + def test_filter_nested_and_or_implicit(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter syntax not supported.") - def test_filter_nested_or_and( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): + def test_filter_nested_or_and(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @pytest.mark.skip(reason="Filter syntax not supported.") diff --git a/integrations/chroma/tests/test_retriever.py b/integrations/chroma/tests/test_retriever.py index 1a4e73cd5..f0e71828d 100644 --- a/integrations/chroma/tests/test_retriever.py +++ b/integrations/chroma/tests/test_retriever.py @@ -10,27 +10,19 @@ @pytest.mark.integration def test_retriever_init(request): ds = ChromaDocumentStore( - collection_name=request.node.name, - embedding_function="HuggingFaceEmbeddingFunction", - api_key="1234567890", - ) - retriever = ChromaQueryTextRetriever( - ds, filters={"foo": "bar"}, top_k=99, filter_policy="replace" + collection_name=request.node.name, embedding_function="HuggingFaceEmbeddingFunction", api_key="1234567890" ) + retriever = ChromaQueryTextRetriever(ds, filters={"foo": "bar"}, top_k=99, filter_policy="replace") assert retriever.filter_policy == FilterPolicy.REPLACE with pytest.raises(ValueError): - 
ChromaQueryTextRetriever( - ds, filters={"foo": "bar"}, top_k=99, filter_policy="unknown" - ) + ChromaQueryTextRetriever(ds, filters={"foo": "bar"}, top_k=99, filter_policy="unknown") @pytest.mark.integration def test_retriever_to_json(request): ds = ChromaDocumentStore( - collection_name=request.node.name, - embedding_function="HuggingFaceEmbeddingFunction", - api_key="1234567890", + collection_name=request.node.name, embedding_function="HuggingFaceEmbeddingFunction", api_key="1234567890" ) retriever = ChromaQueryTextRetriever(ds, filters={"foo": "bar"}, top_k=99) assert retriever.to_dict() == { @@ -75,12 +67,8 @@ def test_retriever_from_json(request): } retriever = ChromaQueryTextRetriever.from_dict(data) assert retriever.document_store._collection_name == request.node.name - assert ( - retriever.document_store._embedding_function == "HuggingFaceEmbeddingFunction" - ) - assert retriever.document_store._embedding_function_params == { - "api_key": "1234567890" - } + assert retriever.document_store._embedding_function == "HuggingFaceEmbeddingFunction" + assert retriever.document_store._embedding_function_params == {"api_key": "1234567890"} assert retriever.document_store._persist_path == "." assert retriever.filters == {"bar": "baz"} assert retriever.top_k == 42 @@ -108,15 +96,9 @@ def test_retriever_from_json_no_filter_policy(request): } retriever = ChromaQueryTextRetriever.from_dict(data) assert retriever.document_store._collection_name == request.node.name - assert ( - retriever.document_store._embedding_function == "HuggingFaceEmbeddingFunction" - ) - assert retriever.document_store._embedding_function_params == { - "api_key": "1234567890" - } + assert retriever.document_store._embedding_function == "HuggingFaceEmbeddingFunction" + assert retriever.document_store._embedding_function_params == {"api_key": "1234567890"} assert retriever.document_store._persist_path == "." assert retriever.filters == {"bar": "baz"} assert retriever.top_k == 42 - assert ( - retriever.filter_policy == FilterPolicy.REPLACE - ) # default even if not specified + assert retriever.filter_policy == FilterPolicy.REPLACE # default even if not specified diff --git a/integrations/cohere/examples/cohere_embedding.py b/integrations/cohere/examples/cohere_embedding.py index 2b6c03016..e6fe3cc35 100644 --- a/integrations/cohere/examples/cohere_embedding.py +++ b/integrations/cohere/examples/cohere_embedding.py @@ -1,12 +1,8 @@ from haystack import Document, Pipeline from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever from haystack.document_stores.in_memory import InMemoryDocumentStore -from haystack_integrations.components.embedders.cohere.document_embedder import ( - CohereDocumentEmbedder, -) -from haystack_integrations.components.embedders.cohere.text_embedder import ( - CohereTextEmbedder, -) +from haystack_integrations.components.embedders.cohere.document_embedder import CohereDocumentEmbedder +from haystack_integrations.components.embedders.cohere.text_embedder import CohereTextEmbedder document_store = InMemoryDocumentStore(embedding_similarity_function="cosine") @@ -22,9 +18,7 @@ query_pipeline = Pipeline() query_pipeline.add_component("text_embedder", CohereTextEmbedder()) -query_pipeline.add_component( - "retriever", InMemoryEmbeddingRetriever(document_store=document_store) -) +query_pipeline.add_component("retriever", InMemoryEmbeddingRetriever(document_store=document_store)) query_pipeline.connect("text_embedder.embedding", "retriever.query_embedding") query = "Who lives in Berlin?" 
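# --- Editor's sketch (illustrative only, not part of the patch) ---
# The diff context for cohere_embedding.py stops at the query definition, so the
# lines that actually execute the pipeline are elided. A typical way to drive
# the query pipeline built above would look like this (an assumed usage
# pattern, not the file's real ending):
result = query_pipeline.run({"text_embedder": {"text": query}})
for doc in result["retriever"]["documents"]:
    print(doc.content, doc.score)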
diff --git a/integrations/cohere/examples/cohere_generation.py b/integrations/cohere/examples/cohere_generation.py index 44508e109..cd79e37d3 100644 --- a/integrations/cohere/examples/cohere_generation.py +++ b/integrations/cohere/examples/cohere_generation.py @@ -29,10 +29,7 @@ "properties": { "first_name": {"type": "string", "pattern": "^[A-Z][a-z]+$"}, "last_name": {"type": "string", "pattern": "^[A-Z][a-z]+$"}, - "nationality": { - "type": "string", - "enum": ["Italian", "Portuguese", "American"], - }, + "nationality": {"type": "string", "enum": ["Italian", "Portuguese", "American"]}, }, "required": ["first_name", "last_name", "nationality"], } @@ -51,12 +48,6 @@ pipe.connect("fc_llm.replies", "validator.messages") pipe.connect("validator.validation_error", "joiner") -result = pipe.run( - data={ - "adapter": { - "chat_message": [ChatMessage.from_user("Create json from Peter Parker")] - } - } -) +result = pipe.run(data={"adapter": {"chat_message": [ChatMessage.from_user("Create json from Peter Parker")]}}) print(result["validator"]["validated"]) # noqa: T201 diff --git a/integrations/cohere/examples/cohere_ranker.py b/integrations/cohere/examples/cohere_ranker.py index 0c1c42356..79a3d346d 100644 --- a/integrations/cohere/examples/cohere_ranker.py +++ b/integrations/cohere/examples/cohere_ranker.py @@ -24,7 +24,5 @@ document_ranker_pipeline.connect("retriever.documents", "ranker.documents") query = "Cities in France" -res = document_ranker_pipeline.run( - data={"retriever": {"query": query}, "ranker": {"query": query, "top_k": 2}} -) +res = document_ranker_pipeline.run(data={"retriever": {"query": query}, "ranker": {"query": query, "top_k": 2}}) print(res["ranker"]["documents"]) # noqa: T201 diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py index e0e3621ba..59a04cf3c 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py @@ -6,10 +6,7 @@ from haystack import Document, component, default_from_dict, default_to_dict from haystack.utils import Secret, deserialize_secrets_inplace -from haystack_integrations.components.embedders.cohere.utils import ( - get_async_response, - get_response, -) +from haystack_integrations.components.embedders.cohere.utils import get_async_response, get_response from cohere import AsyncClient, Client @@ -132,14 +129,10 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: texts_to_embed: List[str] = [] for doc in documents: meta_values_to_embed = [ - str(doc.meta[key]) - for key in self.meta_fields_to_embed - if doc.meta.get(key) is not None + str(doc.meta[key]) for key in self.meta_fields_to_embed if doc.meta.get(key) is not None ] - text_to_embed = self.embedding_separator.join( - meta_values_to_embed + [doc.content or ""] - ) # noqa: RUF005 + text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) # noqa: RUF005 texts_to_embed.append(text_to_embed) return texts_to_embed @@ -153,11 +146,7 @@ def run(self, documents: List[Document]): - `meta`: metadata about the embedding process. :raises TypeError: if the input is not a list of `Documents`. 
""" - if ( - not isinstance(documents, list) - or documents - and not isinstance(documents[0], Document) - ): + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): msg = ( "CohereDocumentEmbedder expects a list of Documents as input." "In case you want to embed a string, please use the CohereTextEmbedder." @@ -181,13 +170,7 @@ def run(self, documents: List[Document]): client_name="haystack", ) all_embeddings, metadata = asyncio.run( - get_async_response( - cohere_client, - texts_to_embed, - self.model, - self.input_type, - self.truncate, - ) + get_async_response(cohere_client, texts_to_embed, self.model, self.input_type, self.truncate) ) else: cohere_client = Client( diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py index e8ce5c40e..80ede51bf 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py @@ -6,10 +6,7 @@ from haystack import component, default_from_dict, default_to_dict from haystack.utils import Secret, deserialize_secrets_inplace -from haystack_integrations.components.embedders.cohere.utils import ( - get_async_response, - get_response, -) +from haystack_integrations.components.embedders.cohere.utils import get_async_response, get_response from cohere import AsyncClient, Client @@ -135,9 +132,7 @@ def run(self, text: str): client_name="haystack", ) embedding, metadata = asyncio.run( - get_async_response( - cohere_client, [text], self.model, self.input_type, self.truncate - ) + get_async_response(cohere_client, [text], self.model, self.input_type, self.truncate) ) else: cohere_client = Client( @@ -146,8 +141,6 @@ def run(self, text: str): timeout=self.timeout, client_name="haystack", ) - embedding, metadata = get_response( - cohere_client, [text], self.model, self.input_type, self.truncate - ) + embedding, metadata = get_response(cohere_client, [text], self.model, self.input_type, self.truncate) return {"embedding": embedding[0], "meta": metadata} diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py index cdbe1d246..a5c20cb35 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py @@ -8,9 +8,7 @@ from cohere import AsyncClient, Client -async def get_async_response( - cohere_async_client: AsyncClient, texts: List[str], model_name, input_type, truncate -): +async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], model_name, input_type, truncate): """Embeds a list of texts asynchronously using the Cohere API. 
:param cohere_async_client: the Cohere `AsyncClient` @@ -27,9 +25,7 @@ async def get_async_response( all_embeddings: List[List[float]] = [] metadata: Dict[str, Any] = {} - response = await cohere_async_client.embed( - texts=texts, model=model_name, input_type=input_type, truncate=truncate - ) + response = await cohere_async_client.embed(texts=texts, model=model_name, input_type=input_type, truncate=truncate) if response.meta is not None: metadata = response.meta for emb in response.embeddings: @@ -39,13 +35,7 @@ async def get_async_response( def get_response( - cohere_client: Client, - texts: List[str], - model_name, - input_type, - truncate, - batch_size=32, - progress_bar=False, + cohere_client: Client, texts: List[str], model_name, input_type, truncate, batch_size=32, progress_bar=False ) -> Tuple[List[List[float]], Dict[str, Any]]: """Embeds a list of texts using the Cohere API. @@ -72,9 +62,7 @@ def get_response( desc="Calculating embeddings", ): batch = texts[i : i + batch_size] - response = cohere_client.embed( - texts=batch, model=model_name, input_type=input_type, truncate=truncate - ) + response = cohere_client.embed(texts=batch, model=model_name, input_type=input_type, truncate=truncate) for emb in response.embeddings: all_embeddings.append(emb) if response.meta is not None: diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index a3dd40861..4ac59bf44 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -5,10 +5,7 @@ from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk from haystack.lazy_imports import LazyImport from haystack.utils import Secret, deserialize_secrets_inplace -from haystack.utils.callable_serialization import ( - deserialize_callable, - serialize_callable, -) +from haystack.utils.callable_serialization import deserialize_callable, serialize_callable with LazyImport(message="Run 'pip install cohere'") as cohere_import: import cohere @@ -94,9 +91,7 @@ def __init__( self.generation_kwargs = generation_kwargs self.model_parameters = kwargs self.client = cohere.Client( - api_key=self.api_key.resolve_value(), - base_url=self.api_base_url, - client_name="haystack", + api_key=self.api_key.resolve_value(), base_url=self.api_base_url, client_name="haystack" ) def _get_telemetry_data(self) -> Dict[str, Any]: @@ -112,11 +107,7 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. 
""" - callback_name = ( - serialize_callable(self.streaming_callback) - if self.streaming_callback - else None - ) + callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None return default_to_dict( self, model=self.model, @@ -140,9 +131,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereChatGenerator": deserialize_secrets_inplace(init_params, ["api_key"]) serialized_callback_handler = init_params.get("streaming_callback") if serialized_callback_handler: - data["init_parameters"]["streaming_callback"] = deserialize_callable( - serialized_callback_handler - ) + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]: @@ -151,11 +140,7 @@ def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]: return chat_message @component.output_types(replies=List[ChatMessage]) - def run( - self, - messages: List[ChatMessage], - generation_kwargs: Optional[Dict[str, Any]] = None, - ): + def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): """ Invoke the text generation inference based on the provided messages and generation parameters. @@ -220,9 +205,7 @@ def _build_chunk(self, chunk) -> StreamingChunk: :param choice: The choice returned by the OpenAI API. :returns: The StreamingChunk. """ - chat_message = StreamingChunk( - content=chunk.text, meta={"event_type": chunk.event_type} - ) + chat_message = StreamingChunk(content=chunk.text, meta={"event_type": chunk.event_type}) return chat_message def _build_message(self, cohere_response): @@ -237,10 +220,7 @@ def _build_message(self, cohere_response): message = ChatMessage.from_assistant(cohere_response.tool_calls[0].json()) elif cohere_response.text: message = ChatMessage.from_assistant(content=cohere_response.text) - total_tokens = ( - cohere_response.meta.billed_units.input_tokens - + cohere_response.meta.billed_units.output_tokens - ) + total_tokens = cohere_response.meta.billed_units.input_tokens + cohere_response.meta.billed_units.output_tokens message.meta.update( { "model": self.model, diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py index 1a5785b78..3cf4f8124 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py @@ -43,9 +43,7 @@ def __init__( """ # Note we have to call super() like this because of the way components are dynamically built with the decorator - super(CohereGenerator, self).__init__( - api_key, model, streaming_callback, api_base_url, None, **kwargs - ) # noqa + super(CohereGenerator, self).__init__(api_key, model, streaming_callback, api_base_url, None, **kwargs) # noqa @component.output_types(replies=List[str], meta=List[Dict[str, Any]]) def run(self, prompt: str): @@ -60,7 +58,4 @@ def run(self, prompt: str): chat_message = ChatMessage(content=prompt, role=ChatRole.USER, name="", meta={}) # Note we have to call super() like this because of the way components are dynamically built with the decorator results = super(CohereGenerator, self).run([chat_message]) # noqa - return { - "replies": [results["replies"][0].content], - "meta": [results["replies"][0].meta], - } + return {"replies": [results["replies"][0].content], 
"meta": [results["replies"][0].meta]} diff --git a/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py b/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py index 10ae622e7..7da823bbc 100644 --- a/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py +++ b/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py @@ -66,9 +66,7 @@ def __init__( self.meta_fields_to_embed = meta_fields_to_embed or [] self.meta_data_separator = meta_data_separator self._cohere_client = cohere.Client( - api_key=self.api_key.resolve_value(), - base_url=self.api_base_url, - client_name="haystack", + api_key=self.api_key.resolve_value(), base_url=self.api_base_url, client_name="haystack" ) def to_dict(self) -> Dict[str, Any]: @@ -112,13 +110,9 @@ def _prepare_cohere_input_docs(self, documents: List[Document]) -> List[str]: concatenated_input_list = [] for doc in documents: meta_values_to_embed = [ - str(doc.meta[key]) - for key in self.meta_fields_to_embed - if key in doc.meta and doc.meta.get(key) + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta.get(key) ] - concatenated_input = self.meta_data_separator.join( - [*meta_values_to_embed, doc.content or ""] - ) + concatenated_input = self.meta_data_separator.join([*meta_values_to_embed, doc.content or ""]) concatenated_input_list.append(concatenated_input) return concatenated_input_list diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index e4cb5a8bb..6521503f2 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -26,11 +26,7 @@ def streaming_chunk(text: str): @pytest.fixture def chat_messages(): - return [ - ChatMessage( - content="What's the capital of France", role=ChatRole.ASSISTANT, name=None - ) - ] + return [ChatMessage(content="What's the capital of France", role=ChatRole.ASSISTANT, name=None)] class TestCohereChatGenerator: @@ -38,9 +34,7 @@ def test_init_default(self, monkeypatch): monkeypatch.setenv("COHERE_API_KEY", "test-api-key") component = CohereChatGenerator() - assert component.api_key == Secret.from_env_var( - ["COHERE_API_KEY", "CO_API_KEY"] - ) + assert component.api_key == Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]) assert component.model == "command-r" assert component.streaming_callback is None assert component.api_base_url == "https://api.cohere.com" @@ -64,10 +58,7 @@ def test_init_with_parameters(self): assert component.model == "command-nightly" assert component.streaming_callback is print_streaming_chunk assert component.api_base_url == "test-base-url" - assert component.generation_kwargs == { - "max_tokens": 10, - "some_test_param": "test-params", - } + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} def test_to_dict_default(self, monkeypatch): monkeypatch.setenv("COHERE_API_KEY", "test-api-key") @@ -78,11 +69,7 @@ def test_to_dict_default(self, monkeypatch): "init_parameters": { "model": "command-r", "streaming_callback": None, - "api_key": { - "env_vars": ["COHERE_API_KEY", "CO_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "api_base_url": "https://api.cohere.com", "generation_kwargs": {}, }, @@ -103,17 +90,10 @@ def test_to_dict_with_parameters(self, monkeypatch): "type": 
"haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", "init_parameters": { "model": "command-nightly", - "api_key": { - "env_vars": ["ENV_VAR"], - "strict": False, - "type": "env_var", - }, + "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "api_base_url": "test-base-url", - "generation_kwargs": { - "max_tokens": 10, - "some_test_param": "test-params", - }, + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, } @@ -131,16 +111,9 @@ def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): "init_parameters": { "model": "command-r", "api_base_url": "test-base-url", - "api_key": { - "env_vars": ["COHERE_API_KEY", "CO_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "streaming_callback": "tests.test_cohere_chat_generator.", - "generation_kwargs": { - "max_tokens": 10, - "some_test_param": "test-params", - }, + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, } @@ -152,26 +125,16 @@ def test_from_dict(self, monkeypatch): "init_parameters": { "model": "command-r", "api_base_url": "test-base-url", - "api_key": { - "env_vars": ["ENV_VAR"], - "strict": False, - "type": "env_var", - }, + "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "generation_kwargs": { - "max_tokens": 10, - "some_test_param": "test-params", - }, + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, } component = CohereChatGenerator.from_dict(data) assert component.model == "command-r" assert component.streaming_callback is print_streaming_chunk assert component.api_base_url == "test-base-url" - assert component.generation_kwargs == { - "max_tokens": 10, - "some_test_param": "test-params", - } + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} def test_from_dict_fail_wo_env_var(self, monkeypatch): monkeypatch.delenv("COHERE_API_KEY", raising=False) @@ -181,16 +144,9 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): "init_parameters": { "model": "command-r", "api_base_url": "test-base-url", - "api_key": { - "env_vars": ["COHERE_API_KEY", "CO_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "generation_kwargs": { - "max_tokens": 10, - "some_test_param": "test-params", - }, + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, } with pytest.raises(ValueError): @@ -199,25 +155,15 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): def test_message_to_dict(self, chat_messages): obj = CohereChatGenerator(api_key=Secret.from_token("test-api-key")) dictionary = [obj._message_to_dict(message) for message in chat_messages] - assert dictionary == [ - {"user_name": "Chatbot", "text": "What's the capital of France"} - ] + assert dictionary == [{"user_name": "Chatbot", "text": "What's the capital of France"}] @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None) - and not os.environ.get("CO_API_KEY", None), + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var 
called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration def test_live_run(self): - chat_messages = [ - ChatMessage( - content="What's the capital of France", - role=ChatRole.USER, - name="", - meta={}, - ) - ] + chat_messages = [ChatMessage(content="What's the capital of France", role=ChatRole.USER, name="", meta={})] component = CohereChatGenerator(generation_kwargs={"temperature": 0.8}) results = component.run(chat_messages) assert len(results["replies"]) == 1 @@ -225,8 +171,7 @@ def test_live_run(self): assert "Paris" in message.content @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None) - and not os.environ.get("CO_API_KEY", None), + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration @@ -236,8 +181,7 @@ def test_live_run_wrong_model(self, chat_messages): component.run(chat_messages) @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None) - and not os.environ.get("CO_API_KEY", None), + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration @@ -254,13 +198,7 @@ def __call__(self, chunk: StreamingChunk) -> None: callback = Callback() component = CohereChatGenerator(streaming_callback=callback) results = component.run( - [ - ChatMessage( - content="What's the capital of France? answer in a word", - role=ChatRole.USER, - name=None, - ) - ] + [ChatMessage(content="What's the capital of France? answer in a word", role=ChatRole.USER, name=None)] ) assert len(results["replies"]) == 1 @@ -273,24 +211,14 @@ def __call__(self, chunk: StreamingChunk) -> None: assert "Paris" in callback.responses @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None) - and not os.environ.get("CO_API_KEY", None), + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration def test_live_run_with_connector(self): - chat_messages = [ - ChatMessage( - content="What's the capital of France", - role=ChatRole.USER, - name="", - meta={}, - ) - ] + chat_messages = [ChatMessage(content="What's the capital of France", role=ChatRole.USER, name="", meta={})] component = CohereChatGenerator(generation_kwargs={"temperature": 0.8}) - results = component.run( - chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]} - ) + results = component.run(chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]}) assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] assert "Paris" in message.content @@ -298,8 +226,7 @@ def test_live_run_with_connector(self): assert "citations" in message.meta # Citations might be None @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None) - and not os.environ.get("CO_API_KEY", None), + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration @@ -314,17 +241,9 @@ def __call__(self, chunk: StreamingChunk) -> None: self.responses += chunk.content if chunk.content else "" callback = Callback() - 
chat_messages = [ - ChatMessage( - content="What's the capital of France? answer in a word", - role=None, - name=None, - ) - ] + chat_messages = [ChatMessage(content="What's the capital of France? answer in a word", role=None, name=None)] component = CohereChatGenerator(streaming_callback=callback) - results = component.run( - chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]} - ) + results = component.run(chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]}) assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] @@ -338,8 +257,7 @@ def __call__(self, chunk: StreamingChunk) -> None: assert message.meta["citations"] is not None @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None) - and not os.environ.get("CO_API_KEY", None), + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration @@ -366,19 +284,11 @@ def test_tools_use(self): assert len(replies) > 0, "No replies received" first_reply = replies[0] - assert isinstance( - first_reply, ChatMessage - ), "First reply is not a ChatMessage instance" + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" assert first_reply.content, "First reply has no content" - assert ChatMessage.is_from( - first_reply, ChatRole.ASSISTANT - ), "First reply is not from the assistant" - assert ( - "get_stock_price" in first_reply.content.lower() - ), "First reply does not contain get_stock_price" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert "get_stock_price" in first_reply.content.lower(), "First reply does not contain get_stock_price" assert first_reply.meta, "First reply has no metadata" fc_response = json.loads(first_reply.content) assert "name" in fc_response, "First reply does not contain name of the tool" - assert ( - "parameters" in fc_response - ), "First reply does not contain parameters of the tool" + assert "parameters" in fc_response, "First reply does not contain parameters of the tool" diff --git a/integrations/cohere/tests/test_cohere_generator.py b/integrations/cohere/tests/test_cohere_generator.py index d6b5e969a..736b6bfbf 100644 --- a/integrations/cohere/tests/test_cohere_generator.py +++ b/integrations/cohere/tests/test_cohere_generator.py @@ -17,9 +17,7 @@ class TestCohereGenerator: def test_init_default(self, monkeypatch): monkeypatch.setenv("COHERE_API_KEY", "foo") component = CohereGenerator() - assert component.api_key == Secret.from_env_var( - ["COHERE_API_KEY", "CO_API_KEY"] - ) + assert component.api_key == Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]) assert component.model == "command-r" assert component.streaming_callback is None assert component.api_base_url == COHERE_API_URL @@ -39,10 +37,7 @@ def test_init_with_parameters(self): assert component.model == "command-light" assert component.streaming_callback == callback assert component.api_base_url == "test-base-url" - assert component.model_parameters == { - "max_tokens": 10, - "some_test_param": "test-params", - } + assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"} def test_to_dict_default(self, monkeypatch): monkeypatch.setenv("COHERE_API_KEY", "test-api-key") @@ -52,11 +47,7 @@ def test_to_dict_default(self, monkeypatch): "type": 
"haystack_integrations.components.generators.cohere.generator.CohereGenerator", "init_parameters": { "model": "command-r", - "api_key": { - "env_vars": ["COHERE_API_KEY", "CO_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "streaming_callback": None, "api_base_url": COHERE_API_URL, "generation_kwargs": {}, @@ -80,11 +71,7 @@ def test_to_dict_with_parameters(self, monkeypatch): "init_parameters": { "model": "command-light", "api_base_url": "test-base-url", - "api_key": { - "env_vars": ["ENV_VAR"], - "strict": False, - "type": "env_var", - }, + "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {}, }, @@ -106,11 +93,7 @@ def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): "model": "command-r", "streaming_callback": "tests.test_cohere_generator.", "api_base_url": "test-base-url", - "api_key": { - "type": "env_var", - "env_vars": ["COHERE_API_KEY", "CO_API_KEY"], - "strict": True, - }, + "api_key": {"type": "env_var", "env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True}, "generation_kwargs": {}, }, } @@ -123,11 +106,7 @@ def test_from_dict(self, monkeypatch): "init_parameters": { "model": "command-r", "max_tokens": 10, - "api_key": { - "env_vars": ["ENV_VAR"], - "strict": False, - "type": "env_var", - }, + "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "some_test_param": "test-params", "api_base_url": "test-base-url", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", @@ -138,14 +117,10 @@ def test_from_dict(self, monkeypatch): assert component.model == "command-r" assert component.streaming_callback == print_streaming_chunk assert component.api_base_url == "test-base-url" - assert component.model_parameters == { - "max_tokens": 10, - "some_test_param": "test-params", - } + assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"} @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None) - and not os.environ.get("CO_API_KEY", None), + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration @@ -158,8 +133,7 @@ def test_cohere_generator_run(self): assert results["meta"][0]["finish_reason"] == "COMPLETE" @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None) - and not os.environ.get("CO_API_KEY", None), + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration @@ -169,8 +143,7 @@ def test_cohere_generator_run_wrong_model(self): component.run(prompt="What's the capital of France?") @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None) - and not os.environ.get("CO_API_KEY", None), + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration diff --git a/integrations/cohere/tests/test_cohere_ranker.py b/integrations/cohere/tests/test_cohere_ranker.py index 849ae1554..670e662d4 100644 --- a/integrations/cohere/tests/test_cohere_ranker.py 
+++ b/integrations/cohere/tests/test_cohere_ranker.py @@ -20,6 +20,7 @@ def mock_ranker_response(): RerankResult] """ with patch("cohere.Client.rerank", autospec=True) as mock_ranker_response: + mock_response = Mock() mock_ranker_res_obj1 = Mock() @@ -41,9 +42,7 @@ def test_init_default(self, monkeypatch): component = CohereRanker() assert component.model_name == "rerank-english-v2.0" assert component.top_k == 10 - assert component.api_key == Secret.from_env_var( - ["COHERE_API_KEY", "CO_API_KEY"] - ) + assert component.api_key == Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]) assert component.api_base_url == COHERE_API_URL assert component.max_chunks_per_doc is None assert component.meta_fields_to_embed == [] @@ -52,10 +51,7 @@ def test_init_default(self, monkeypatch): def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("CO_API_KEY", raising=False) monkeypatch.delenv("COHERE_API_KEY", raising=False) - with pytest.raises( - ValueError, - match="None of the following authentication environment variables are set: *", - ): + with pytest.raises(ValueError, match="None of the following authentication environment variables are set: *"): CohereRanker() def test_init_with_parameters(self, monkeypatch): @@ -71,9 +67,7 @@ def test_init_with_parameters(self, monkeypatch): ) assert component.model_name == "rerank-multilingual-v2.0" assert component.top_k == 5 - assert component.api_key == Secret.from_env_var( - ["COHERE_API_KEY", "CO_API_KEY"] - ) + assert component.api_key == Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]) assert component.api_base_url == "test-base-url" assert component.max_chunks_per_doc == 40 assert component.meta_fields_to_embed == ["meta_field_1", "meta_field_2"] @@ -87,11 +81,7 @@ def test_to_dict_default(self, monkeypatch): "type": "haystack_integrations.components.rankers.cohere.ranker.CohereRanker", "init_parameters": { "model": "rerank-english-v2.0", - "api_key": { - "env_vars": ["COHERE_API_KEY", "CO_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "api_base_url": COHERE_API_URL, "top_k": 10, "max_chunks_per_doc": None, @@ -116,11 +106,7 @@ def test_to_dict_with_parameters(self, monkeypatch): "type": "haystack_integrations.components.rankers.cohere.ranker.CohereRanker", "init_parameters": { "model": "rerank-multilingual-v2.0", - "api_key": { - "env_vars": ["COHERE_API_KEY", "CO_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "api_base_url": "test-base-url", "top_k": 2, "max_chunks_per_doc": 50, @@ -135,11 +121,7 @@ def test_from_dict(self, monkeypatch): "type": "haystack_integrations.components.rankers.cohere.ranker.CohereRanker", "init_parameters": { "model": "rerank-multilingual-v2.0", - "api_key": { - "env_vars": ["COHERE_API_KEY", "CO_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "api_base_url": "test-base-url", "top_k": 2, "max_chunks_per_doc": 50, @@ -150,9 +132,7 @@ def test_from_dict(self, monkeypatch): component = CohereRanker.from_dict(data) assert component.model_name == "rerank-multilingual-v2.0" assert component.top_k == 2 - assert component.api_key == Secret.from_env_var( - ["COHERE_API_KEY", "CO_API_KEY"] - ) + assert component.api_key == Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]) assert component.api_base_url == 
"test-base-url" assert component.max_chunks_per_doc == 50 assert component.meta_fields_to_embed == ["meta_field_1", "meta_field_2"] @@ -165,19 +145,12 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): "type": "haystack_integrations.components.rankers.cohere.ranker.CohereRanker", "init_parameters": { "model": "rerank-multilingual-v2.0", - "api_key": { - "env_vars": ["COHERE_API_KEY", "CO_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "top_k": 2, "max_chunks_per_doc": 50, }, } - with pytest.raises( - ValueError, - match="None of the following authentication environment variables are set: *", - ): + with pytest.raises(ValueError, match="None of the following authentication environment variables are set: *"): CohereRanker.from_dict(data) def test_prepare_cohere_input_docs_default_separator(self, monkeypatch): @@ -207,10 +180,7 @@ def test_prepare_cohere_input_docs_default_separator(self, monkeypatch): def test_prepare_cohere_input_docs_custom_separator(self, monkeypatch): monkeypatch.setenv("CO_API_KEY", "test-api-key") - component = CohereRanker( - meta_fields_to_embed=["meta_field_1", "meta_field_2"], - meta_data_separator=" ", - ) + component = CohereRanker(meta_fields_to_embed=["meta_field_1", "meta_field_2"], meta_data_separator=" ") documents = [ Document( content=f"document number {i}", @@ -235,10 +205,7 @@ def test_prepare_cohere_input_docs_custom_separator(self, monkeypatch): def test_prepare_cohere_input_docs_no_meta_data(self, monkeypatch): monkeypatch.setenv("CO_API_KEY", "test-api-key") - component = CohereRanker( - meta_fields_to_embed=["meta_field_1", "meta_field_2"], - meta_data_separator=" ", - ) + component = CohereRanker(meta_fields_to_embed=["meta_field_1", "meta_field_2"], meta_data_separator=" ") documents = [Document(content=f"document number {i}") for i in range(5)] texts = component._prepare_cohere_input_docs(documents=documents) @@ -253,10 +220,7 @@ def test_prepare_cohere_input_docs_no_meta_data(self, monkeypatch): def test_prepare_cohere_input_docs_no_docs(self, monkeypatch): monkeypatch.setenv("CO_API_KEY", "test-api-key") - component = CohereRanker( - meta_fields_to_embed=["meta_field_1", "meta_field_2"], - meta_data_separator=" ", - ) + component = CohereRanker(meta_fields_to_embed=["meta_field_1", "meta_field_2"], meta_data_separator=" ") documents = [] texts = component._prepare_cohere_input_docs(documents=documents) @@ -267,11 +231,7 @@ def test_run_negative_topk_in_init(self, monkeypatch): monkeypatch.setenv("CO_API_KEY", "test-api-key") ranker = CohereRanker(top_k=-2) query = "test" - documents = [ - Document(content="doc1"), - Document(content="doc2"), - Document(content="doc3"), - ] + documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] with pytest.raises(ValueError, match="top_k must be > 0, but got *"): ranker.run(query, documents) @@ -279,11 +239,7 @@ def test_run_zero_topk_in_init(self, monkeypatch): monkeypatch.setenv("CO_API_KEY", "test-api-key") ranker = CohereRanker(top_k=0) query = "test" - documents = [ - Document(content="doc1"), - Document(content="doc2"), - Document(content="doc3"), - ] + documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] with pytest.raises(ValueError, match="top_k must be > 0, but got *"): ranker.run(query, documents) @@ -291,11 +247,7 @@ def test_run_negative_topk_in_run(self, monkeypatch): monkeypatch.setenv("CO_API_KEY", "test-api-key") 
ranker = CohereRanker() query = "test" - documents = [ - Document(content="doc1"), - Document(content="doc2"), - Document(content="doc3"), - ] + documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] with pytest.raises(ValueError, match="top_k must be > 0, but got *"): ranker.run(query, documents, -3) @@ -303,17 +255,11 @@ def test_run_zero_topk_in_run_and_init(self, monkeypatch): monkeypatch.setenv("CO_API_KEY", "test-api-key") ranker = CohereRanker(top_k=0) query = "test" - documents = [ - Document(content="doc1"), - Document(content="doc2"), - Document(content="doc3"), - ] + documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] with pytest.raises(ValueError, match="top_k must be > 0, but got *"): ranker.run(query, documents, 0) - def test_run_documents_provided( - self, monkeypatch, mock_ranker_response - ): # noqa: ARG002 + def test_run_documents_provided(self, monkeypatch, mock_ranker_response): # noqa: ARG002 monkeypatch.setenv("CO_API_KEY", "test-api-key") ranker = CohereRanker() query = "test" @@ -327,23 +273,11 @@ def test_run_documents_provided( assert isinstance(ranker_results, dict) reranked_docs = ranker_results["documents"] assert reranked_docs == [ - Document( - id="ijkl", - content="doc3", - meta={"meta_field": "meta_value_3"}, - score=0.98, - ), - Document( - id="efgh", - content="doc2", - meta={"meta_field": "meta_value_2"}, - score=0.95, - ), + Document(id="ijkl", content="doc3", meta={"meta_field": "meta_value_3"}, score=0.98), + Document(id="efgh", content="doc2", meta={"meta_field": "meta_value_2"}, score=0.95), ] - def test_run_topk_set_in_init( - self, monkeypatch, mock_ranker_response - ): # noqa: ARG002 + def test_run_topk_set_in_init(self, monkeypatch, mock_ranker_response): # noqa: ARG002 monkeypatch.setenv("CO_API_KEY", "test-api-key") ranker = CohereRanker(top_k=2) query = "test" @@ -363,8 +297,7 @@ def test_run_topk_set_in_init( ] @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None) - and not os.environ.get("CO_API_KEY", None), + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration @@ -388,8 +321,7 @@ def test_live_run(self): assert set(result_documents_contents) == set(expected_documents_content) @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None) - and not os.environ.get("CO_API_KEY", None), + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration diff --git a/integrations/cohere/tests/test_document_embedder.py b/integrations/cohere/tests/test_document_embedder.py index a6f699d02..ffbf280e9 100644 --- a/integrations/cohere/tests/test_document_embedder.py +++ b/integrations/cohere/tests/test_document_embedder.py @@ -59,11 +59,7 @@ def test_to_dict(self): assert component_dict == { "type": "haystack_integrations.components.embedders.cohere.document_embedder.CohereDocumentEmbedder", "init_parameters": { - "api_key": { - "env_vars": ["COHERE_API_KEY", "CO_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "model": "embed-english-v2.0", "input_type": "search_document", "api_base_url": COHERE_API_URL, @@ -95,11 +91,7 @@ def 
test_to_dict_with_custom_init_parameters(self): assert component_dict == { "type": "haystack_integrations.components.embedders.cohere.document_embedder.CohereDocumentEmbedder", "init_parameters": { - "api_key": { - "env_vars": ["ENV_VAR"], - "strict": False, - "type": "env_var", - }, + "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "model": "embed-multilingual-v2.0", "input_type": "search_query", "api_base_url": "https://custom-api-base-url.com", @@ -114,8 +106,7 @@ def test_to_dict_with_custom_init_parameters(self): } @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None) - and not os.environ.get("CO_API_KEY", None), + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration @@ -124,10 +115,7 @@ def test_run(self): docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), - Document( - content="A transformer is a deep learning architecture", - meta={"topic": "ML"}, - ), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), ] result = embedder.run(docs) diff --git a/integrations/cohere/tests/test_text_embedder.py b/integrations/cohere/tests/test_text_embedder.py index ded1a8c6f..b4f3e234c 100644 --- a/integrations/cohere/tests/test_text_embedder.py +++ b/integrations/cohere/tests/test_text_embedder.py @@ -56,11 +56,7 @@ def test_to_dict(self): assert component_dict == { "type": "haystack_integrations.components.embedders.cohere.text_embedder.CohereTextEmbedder", "init_parameters": { - "api_key": { - "env_vars": ["COHERE_API_KEY", "CO_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "model": "embed-english-v2.0", "input_type": "search_query", "api_base_url": COHERE_API_URL, @@ -87,11 +83,7 @@ def test_to_dict_with_custom_init_parameters(self): assert component_dict == { "type": "haystack_integrations.components.embedders.cohere.text_embedder.CohereTextEmbedder", "init_parameters": { - "api_key": { - "env_vars": ["ENV_VAR"], - "strict": False, - "type": "env_var", - }, + "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "model": "embed-multilingual-v2.0", "input_type": "classification", "api_base_url": "https://custom-api-base-url.com", @@ -112,8 +104,7 @@ def test_run_wrong_input_format(self): embedder.run(text=list_integers_input) @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None) - and not os.environ.get("CO_API_KEY", None), + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration diff --git a/integrations/deepeval/example/example.py b/integrations/deepeval/example/example.py index df136014f..e1265a739 100644 --- a/integrations/deepeval/example/example.py +++ b/integrations/deepeval/example/example.py @@ -1,10 +1,7 @@ # A valid OpenAI API key is required to run this example. 
from haystack import Pipeline -from haystack_integrations.components.evaluators.deepeval import ( - DeepEvalEvaluator, - DeepEvalMetric, -) +from haystack_integrations.components.evaluators.deepeval import DeepEvalEvaluator, DeepEvalMetric QUESTIONS = [ "Which is the most popular global sport?", @@ -34,15 +31,7 @@ # Each metric expects a specific set of parameters as input. Refer to the # DeepEvalMetric class' documentation for more details. -results = pipeline.run( - { - "evaluator": { - "questions": QUESTIONS, - "contexts": CONTEXTS, - "responses": RESPONSES, - } - } -) +results = pipeline.run({"evaluator": {"questions": QUESTIONS, "contexts": CONTEXTS, "responses": RESPONSES}}) for output in results["evaluator"]["results"]: print(output) diff --git a/integrations/deepeval/src/haystack_integrations/components/evaluators/deepeval/evaluator.py b/integrations/deepeval/src/haystack_integrations/components/evaluators/deepeval/evaluator.py index bfd0259b9..082ae15fd 100644 --- a/integrations/deepeval/src/haystack_integrations/components/evaluators/deepeval/evaluator.py +++ b/integrations/deepeval/src/haystack_integrations/components/evaluators/deepeval/evaluator.py @@ -63,11 +63,7 @@ def __init__( Refer to the `RagasMetric` class for more details on required parameters. """ - self.metric = ( - metric - if isinstance(metric, DeepEvalMetric) - else DeepEvalMetric.from_str(metric) - ) + self.metric = metric if isinstance(metric, DeepEvalMetric) else DeepEvalMetric.from_str(metric) self.metric_params = metric_params self.descriptor = METRIC_DESCRIPTORS[self.metric] @@ -93,16 +89,11 @@ def run(self, **inputs) -> Dict[str, Any]: - `score` - The score of the metric. - `explanation` - An optional explanation of the score. """ - InputConverters.validate_input_parameters( - self.metric, self.descriptor.input_parameters, inputs - ) + InputConverters.validate_input_parameters(self.metric, self.descriptor.input_parameters, inputs) converted_inputs: List[LLMTestCase] = list(self.descriptor.input_converter(**inputs)) # type: ignore results = self._backend_callable(converted_inputs, self._backend_metric) - converted_results = [ - [result.to_dict() for result in self.descriptor.output_converter(x)] - for x in results - ] + converted_results = [[result.to_dict() for result in self.descriptor.output_converter(x)] for x in results] return {"results": converted_results} @@ -146,9 +137,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "DeepEvalEvaluator": return default_from_dict(cls, data) @staticmethod - def _invoke_deepeval( - test_cases: List[LLMTestCase], metric: BaseMetric - ) -> List[TestResult]: + def _invoke_deepeval(test_cases: List[LLMTestCase], metric: BaseMetric) -> List[TestResult]: return evaluate(test_cases, [metric]) def _init_backend(self): @@ -159,18 +148,14 @@ def _init_backend(self): if self.metric_params is None: msg = f"DeepEval metric '{self.metric}' expected init parameters but got none" raise ValueError(msg) - elif not all( - k in self.descriptor.init_parameters for k in self.metric_params.keys() - ): + elif not all(k in self.descriptor.init_parameters for k in self.metric_params.keys()): msg = ( f"Invalid init parameters for DeepEval metric '{self.metric}'. 
" f"Expected: {list(self.descriptor.init_parameters.keys())}" ) raise ValueError(msg) - backend_metric_params = ( - dict(self.metric_params) if self.metric_params is not None else {} - ) + backend_metric_params = dict(self.metric_params) if self.metric_params is not None else {} # This shouldn't matter at all as we aren't asserting the outputs, but just in case... backend_metric_params["threshold"] = 0.0 diff --git a/integrations/deepeval/src/haystack_integrations/components/evaluators/deepeval/metrics.py b/integrations/deepeval/src/haystack_integrations/components/evaluators/deepeval/metrics.py index 629420021..7fb5db5b0 100644 --- a/integrations/deepeval/src/haystack_integrations/components/evaluators/deepeval/metrics.py +++ b/integrations/deepeval/src/haystack_integrations/components/evaluators/deepeval/metrics.py @@ -133,10 +133,7 @@ def new( for name, param in input_converter_signature.parameters.items(): if name in ("cls", "self"): continue - elif param.kind not in ( - inspect.Parameter.KEYWORD_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - ): + elif param.kind not in (inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD): continue input_parameters[name] = param.annotation @@ -145,9 +142,7 @@ def new( backend=backend, input_parameters=input_parameters, input_converter=input_converter, - output_converter=output_converter - if output_converter is not None - else OutputConverters.default(metric), + output_converter=output_converter if output_converter is not None else OutputConverters.default(metric), init_parameters=init_parameters, ) @@ -170,9 +165,7 @@ def _validate_input_elements(**kwargs): f"got '{type(collection).__name__}' instead" ) raise ValueError(msg) - elif not all(isinstance(x, str) for x in collection) and not all( - isinstance(x, list) for x in collection - ): + elif not all(isinstance(x, str) for x in collection) and not all(isinstance(x, list) for x in collection): msg = f"DeepEval evaluator expects inputs to be of type 'str' or 'list' in '{k}'" raise ValueError(msg) @@ -182,9 +175,7 @@ def _validate_input_elements(**kwargs): raise ValueError(msg) @staticmethod - def validate_input_parameters( - metric: DeepEvalMetric, expected: Dict[str, Any], received: Dict[str, Any] - ): + def validate_input_parameters(metric: DeepEvalMetric, expected: Dict[str, Any], received: Dict[str, Any]): for param, _ in expected.items(): if param not in received: msg = f"DeepEval evaluator expected input parameter '{param}' for metric '{metric}'" @@ -194,27 +185,18 @@ def validate_input_parameters( def question_context_response( questions: List[str], contexts: List[List[str]], responses: List[str] ) -> Iterable[LLMTestCase]: - InputConverters._validate_input_elements( - questions=questions, contexts=contexts, responses=responses - ) + InputConverters._validate_input_elements(questions=questions, contexts=contexts, responses=responses) for q, c, r in zip(questions, contexts, responses): # type: ignore test_case = LLMTestCase(input=q, actual_output=r, retrieval_context=c) yield test_case @staticmethod def question_context_response_ground_truth( - questions: List[str], - contexts: List[List[str]], - responses: List[str], - ground_truths: List[str], + questions: List[str], contexts: List[List[str]], responses: List[str], ground_truths: List[str] ) -> Iterable[LLMTestCase]: - InputConverters._validate_input_elements( - questions=questions, contexts=contexts, responses=responses - ) + InputConverters._validate_input_elements(questions=questions, contexts=contexts, 
responses=responses) for q, c, r, gt in zip(questions, contexts, responses, ground_truths): # type: ignore - test_case = LLMTestCase( - input=q, actual_output=r, retrieval_context=c, expected_output=gt - ) + test_case = LLMTestCase(input=q, actual_output=r, retrieval_context=c, expected_output=gt) yield test_case @@ -233,13 +215,7 @@ def inner(output: TestResult, metric: DeepEvalMetric) -> List[MetricResult]: metric_name = str(metric) assert len(output.metrics) == 1 metric_result = output.metrics[0] - out = [ - MetricResult( - name=metric_name, - score=metric_result.score, - explanation=metric_result.reason, - ) - ] + out = [MetricResult(name=metric_name, score=metric_result.score, explanation=metric_result.reason)] if metric_result.score_breakdown is not None: for k, v in metric_result.score_breakdown.items(): out.append(MetricResult(name=f"{metric_name}_{k}", score=v)) diff --git a/integrations/deepeval/tests/test_evaluator.py b/integrations/deepeval/tests/test_evaluator.py index f58036603..7d1946185 100644 --- a/integrations/deepeval/tests/test_evaluator.py +++ b/integrations/deepeval/tests/test_evaluator.py @@ -7,10 +7,7 @@ import pytest from haystack import DeserializationError -from haystack_integrations.components.evaluators.deepeval import ( - DeepEvalEvaluator, - DeepEvalMetric, -) +from haystack_integrations.components.evaluators.deepeval import DeepEvalEvaluator, DeepEvalMetric from deepeval.evaluate import TestResult, BaseMetric DEFAULT_QUESTIONS = [ @@ -65,15 +62,7 @@ def eval(self, test_cases, metric): out = [] for x in test_cases: - r = TestResult( - False, - [], - x.input, - x.actual_output, - x.expected_output, - x.context, - x.retrieval_context, - ) + r = TestResult(False, [], x.input, x.actual_output, x.expected_output, x.context, x.retrieval_context) r.metrics = copy.deepcopy(output_map[self.metric]) out.append(r) return out @@ -82,15 +71,11 @@ def eval(self, test_cases, metric): def test_evaluator_metric_init_params(monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") - eval = DeepEvalEvaluator( - DeepEvalMetric.ANSWER_RELEVANCY, metric_params={"model": "gpt-4-32k"} - ) + eval = DeepEvalEvaluator(DeepEvalMetric.ANSWER_RELEVANCY, metric_params={"model": "gpt-4-32k"}) assert eval._backend_metric.evaluation_model == "gpt-4-32k" with pytest.raises(ValueError, match="Invalid init parameters"): - DeepEvalEvaluator( - DeepEvalMetric.FAITHFULNESS, metric_params={"role": "village idiot"} - ) + DeepEvalEvaluator(DeepEvalMetric.FAITHFULNESS, metric_params={"role": "village idiot"}) with pytest.raises(ValueError, match="expected init parameters"): DeepEvalEvaluator(DeepEvalMetric.CONTEXTUAL_RECALL) @@ -111,9 +96,7 @@ def test_evaluator_serde(monkeypatch): assert eval.metric_params == new_eval.metric_params assert type(new_eval._backend_metric) == type(eval._backend_metric) - with pytest.raises( - DeserializationError, match=r"cannot serialize the metric parameters" - ): + with pytest.raises(DeserializationError, match=r"cannot serialize the metric parameters"): eval.metric_params["model"] = Unserializable("") eval.to_dict() @@ -221,21 +204,13 @@ def test_evaluator_invalid_inputs(metric, inputs, error_string, params, monkeypa [ ( DeepEvalMetric.ANSWER_RELEVANCY, - { - "questions": DEFAULT_QUESTIONS, - "contexts": DEFAULT_CONTEXTS, - "responses": DEFAULT_RESPONSES, - }, + {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES}, [[(None, 0.5, "1")]] * 2, {"model": "gpt-4"}, ), ( DeepEvalMetric.FAITHFULNESS, - { - "questions": 
DEFAULT_QUESTIONS, - "contexts": DEFAULT_CONTEXTS, - "responses": DEFAULT_RESPONSES, - }, + {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES}, [[(None, 0.1, "2")]] * 2, {"model": "gpt-4"}, ), @@ -263,19 +238,13 @@ def test_evaluator_invalid_inputs(metric, inputs, error_string, params, monkeypa ), ( DeepEvalMetric.CONTEXTUAL_RELEVANCE, - { - "questions": DEFAULT_QUESTIONS, - "contexts": DEFAULT_CONTEXTS, - "responses": DEFAULT_RESPONSES, - }, + {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES}, [[(None, 1.5, "5")]] * 2, {"model": "gpt-4"}, ), ], ) -def test_evaluator_outputs( - metric, inputs, expected_outputs, metric_params, monkeypatch -): +def test_evaluator_outputs(metric, inputs, expected_outputs, metric_params, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") init_params = { @@ -283,9 +252,7 @@ def test_evaluator_outputs( "metric_params": metric_params, } eval = DeepEvalEvaluator(**init_params) - eval._backend_callable = lambda testcases, metrics: MockBackend(metric).eval( - testcases, metrics - ) + eval._backend_callable = lambda testcases, metrics: MockBackend(metric).eval(testcases, metrics) results = eval.run(**inputs)["results"] assert type(results) == type(expected_outputs) @@ -294,10 +261,7 @@ def test_evaluator_outputs( for r, o in zip(results, expected_outputs): assert len(r) == len(o) - expected = { - (name if name is not None else str(metric), score, exp) - for name, score, exp in o - } + expected = {(name if name is not None else str(metric), score, exp) for name, score, exp in o} got = {(x["name"], x["score"], x["explanation"]) for x in r} assert got == expected @@ -312,20 +276,12 @@ def test_evaluator_outputs( [ ( DeepEvalMetric.ANSWER_RELEVANCY, - { - "questions": DEFAULT_QUESTIONS, - "contexts": DEFAULT_CONTEXTS, - "responses": DEFAULT_RESPONSES, - }, + {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES}, {"model": "gpt-4"}, ), ( DeepEvalMetric.FAITHFULNESS, - { - "questions": DEFAULT_QUESTIONS, - "contexts": DEFAULT_CONTEXTS, - "responses": DEFAULT_RESPONSES, - }, + {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES}, {"model": "gpt-4"}, ), ( @@ -350,11 +306,7 @@ def test_evaluator_outputs( ), ( DeepEvalMetric.CONTEXTUAL_RELEVANCE, - { - "questions": DEFAULT_QUESTIONS, - "contexts": DEFAULT_CONTEXTS, - "responses": DEFAULT_RESPONSES, - }, + {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES}, {"model": "gpt-4"}, ), ], diff --git a/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/bm25_retriever.py b/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/bm25_retriever.py index a35d16a9a..f273c955b 100644 --- a/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/bm25_retriever.py +++ b/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/bm25_retriever.py @@ -7,9 +7,7 @@ from haystack.dataclasses import Document from haystack.document_stores.types import FilterPolicy from haystack.document_stores.types.filter_policy import apply_filter_policy -from haystack_integrations.document_stores.elasticsearch.document_store import ( - ElasticsearchDocumentStore, -) +from haystack_integrations.document_stores.elasticsearch.document_store import ElasticsearchDocumentStore @component @@ -78,11 +76,7 @@ def 
__init__( self._fuzziness = fuzziness self._top_k = top_k self._scale_score = scale_score - self._filter_policy = ( - FilterPolicy.from_str(filter_policy) - if isinstance(filter_policy, str) - else filter_policy - ) + self._filter_policy = FilterPolicy.from_str(filter_policy) if isinstance(filter_policy, str) else filter_policy def to_dict(self) -> Dict[str, Any]: """ @@ -111,26 +105,17 @@ def from_dict(cls, data: Dict[str, Any]) -> "ElasticsearchBM25Retriever": :returns: Deserialized component. """ - data["init_parameters"][ - "document_store" - ] = ElasticsearchDocumentStore.from_dict( + data["init_parameters"]["document_store"] = ElasticsearchDocumentStore.from_dict( data["init_parameters"]["document_store"] ) # Pipelines serialized with old versions of the component might not # have the filter_policy field. if filter_policy := data["init_parameters"].get("filter_policy"): - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str( - filter_policy - ) + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) - def run( - self, - query: str, - filters: Optional[Dict[str, Any]] = None, - top_k: Optional[int] = None, - ): + def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): """ Retrieve documents using the BM25 keyword-based algorithm. diff --git a/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/embedding_retriever.py b/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/embedding_retriever.py index c66e7c2f2..10e860ea4 100644 --- a/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/embedding_retriever.py +++ b/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/embedding_retriever.py @@ -7,9 +7,7 @@ from haystack.dataclasses import Document from haystack.document_stores.types import FilterPolicy from haystack.document_stores.types.filter_policy import apply_filter_policy -from haystack_integrations.document_stores.elasticsearch.document_store import ( - ElasticsearchDocumentStore, -) +from haystack_integrations.document_stores.elasticsearch.document_store import ElasticsearchDocumentStore @component @@ -77,11 +75,7 @@ def __init__( self._filters = filters or {} self._top_k = top_k self._num_candidates = num_candidates - self._filter_policy = ( - FilterPolicy.from_str(filter_policy) - if isinstance(filter_policy, str) - else filter_policy - ) + self._filter_policy = FilterPolicy.from_str(filter_policy) if isinstance(filter_policy, str) else filter_policy def to_dict(self) -> Dict[str, Any]: """ @@ -109,26 +103,17 @@ def from_dict(cls, data: Dict[str, Any]) -> "ElasticsearchEmbeddingRetriever": :returns: Deserialized component. """ - data["init_parameters"][ - "document_store" - ] = ElasticsearchDocumentStore.from_dict( + data["init_parameters"]["document_store"] = ElasticsearchDocumentStore.from_dict( data["init_parameters"]["document_store"] ) # Pipelines serialized with old versions of the component might not # have the filter_policy field. 
if filter_policy := data["init_parameters"].get("filter_policy"): - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str( - filter_policy - ) + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) - def run( - self, - query_embedding: List[float], - filters: Optional[Dict[str, Any]] = None, - top_k: Optional[int] = None, - ): + def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): """ Retrieve documents using a vector similarity metric. diff --git a/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py b/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py index 65266b0e2..11016e3fc 100644 --- a/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py +++ b/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py @@ -65,9 +65,7 @@ def __init__( hosts: Optional[Hosts] = None, custom_mapping: Optional[Dict[str, Any]] = None, index: str = "default", - embedding_similarity_function: Literal[ - "cosine", "dot_product", "l2_norm", "max_inner_product" - ] = "cosine", + embedding_similarity_function: Literal["cosine", "dot_product", "l2_norm", "max_inner_product"] = "cosine", **kwargs, ): """ @@ -207,9 +205,7 @@ def _search_documents(self, **kwargs) -> List[Document]: **kwargs, ) - documents.extend( - self._deserialize_document(hit) for hit in res["hits"]["hits"] - ) + documents.extend(self._deserialize_document(hit) for hit in res["hits"]["hits"]) from_ = len(documents) if top_k is not None and from_ >= top_k: @@ -218,9 +214,7 @@ def _search_documents(self, **kwargs) -> List[Document]: break return documents - def filter_documents( - self, filters: Optional[Dict[str, Any]] = None - ) -> List[Document]: + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: """ The main query method for the document store. It retrieves all documents that match the filters. @@ -236,9 +230,7 @@ def filter_documents( documents = self._search_documents(query=query) return documents - def write_documents( - self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE - ) -> int: + def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int: """ Writes `Document`s to Elasticsearch. 
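
The write_documents hunks that follow only reflow the duplicate-handling branches; the behavior they encode is easiest to see in use. A minimal sketch, assuming a local Elasticsearch reachable at http://localhost:9200 (the hosts value and index name are illustrative):

from haystack.dataclasses import Document
from haystack.document_stores.types import DuplicatePolicy
from haystack_integrations.document_stores.elasticsearch import ElasticsearchDocumentStore

store = ElasticsearchDocumentStore(hosts="http://localhost:9200", index="default")
docs = [Document(id="1", content="first version")]

# OVERWRITE replaces any stored document with the same id.
store.write_documents(docs, policy=DuplicatePolicy.OVERWRITE)
# FAIL surfaces a version_conflict_engine_exception as a DuplicateDocumentError,
# while SKIP treats the same conflict as benign and keeps the stored document.
store.write_documents(docs, policy=DuplicatePolicy.SKIP)
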
@@ -252,9 +244,7 @@ def write_documents( """ if len(documents) > 0: if not isinstance(documents[0], Document): - msg = ( - "param 'documents' must contain a list of objects of type Document" - ) + msg = "param 'documents' must contain a list of objects of type Document" raise ValueError(msg) if policy == DuplicatePolicy.NONE: @@ -295,15 +285,9 @@ def write_documents( other_errors = [] for e in errors: error_type = e["create"]["error"]["type"] - if ( - policy == DuplicatePolicy.FAIL - and error_type == "version_conflict_engine_exception" - ): + if policy == DuplicatePolicy.FAIL and error_type == "version_conflict_engine_exception": duplicate_errors_ids.append(e["create"]["_id"]) - elif ( - policy == DuplicatePolicy.SKIP - and error_type == "version_conflict_engine_exception" - ): + elif policy == DuplicatePolicy.SKIP and error_type == "version_conflict_engine_exception": # when the policy is skip, duplication errors are OK and we should not raise an exception continue else: @@ -412,9 +396,7 @@ def _bm25_retrieval( if scale_score: for doc in documents: - doc.score = float( - 1 / (1 + np.exp(-np.asarray(doc.score / BM25_SCALING_FACTOR))) - ) + doc.score = float(1 / (1 + np.exp(-np.asarray(doc.score / BM25_SCALING_FACTOR)))) return documents diff --git a/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/filters.py b/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/filters.py index 80776dd65..b5adc37db 100644 --- a/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/filters.py +++ b/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/filters.py @@ -53,9 +53,7 @@ def _equal(field: str, value: Any) -> Dict[str, Any]: "terms_set": { field: { "terms": value, - "minimum_should_match_script": { - "source": f"Math.max(params.num_terms, doc['{field}'].size())" - }, + "minimum_should_match_script": {"source": f"Math.max(params.num_terms, doc['{field}'].size())"}, } } } @@ -73,13 +71,7 @@ def _not_equal(field: str, value: Any) -> Dict[str, Any]: return {"bool": {"must_not": {"terms": {field: value}}}} if field in ["text", "dataframe"]: # We want to fully match the text field. - return { - "bool": { - "must_not": { - "match": {field: {"query": value, "minimum_should_match": "100%"}} - } - } - } + return {"bool": {"must_not": {"match": {field: {"query": value, "minimum_should_match": "100%"}}}}} return {"bool": {"must_not": {"term": {field: value}}}} @@ -90,14 +82,7 @@ def _greater_than(field: str, value: Any) -> Dict[str, Any]: # if it has a field set and not set at the same time. # This will cause the filter to match no Document. # This way we keep the behavior consistent with other Document Stores. - return { - "bool": { - "must": [ - {"exists": {"field": field}}, - {"bool": {"must_not": {"exists": {"field": field}}}}, - ] - } - } + return {"bool": {"must": [{"exists": {"field": field}}, {"bool": {"must_not": {"exists": {"field": field}}}}]}} if isinstance(value, str): try: datetime.fromisoformat(value) @@ -119,14 +104,7 @@ def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]: # if it has a field set and not set at the same time. # This will cause the filter to match no Document. # This way we keep the behavior consistent with other Document Stores. 
- return { - "bool": { - "must": [ - {"exists": {"field": field}}, - {"bool": {"must_not": {"exists": {"field": field}}}}, - ] - } - } + return {"bool": {"must": [{"exists": {"field": field}}, {"bool": {"must_not": {"exists": {"field": field}}}}]}} if isinstance(value, str): try: datetime.fromisoformat(value) @@ -148,14 +126,7 @@ def _less_than(field: str, value: Any) -> Dict[str, Any]: # if it has a field set and not set at the same time. # This will cause the filter to match no Document. # This way we keep the behavior consistent with other Document Stores. - return { - "bool": { - "must": [ - {"exists": {"field": field}}, - {"bool": {"must_not": {"exists": {"field": field}}}}, - ] - } - } + return {"bool": {"must": [{"exists": {"field": field}}, {"bool": {"must_not": {"exists": {"field": field}}}}]}} if isinstance(value, str): try: datetime.fromisoformat(value) @@ -177,14 +148,7 @@ def _less_than_equal(field: str, value: Any) -> Dict[str, Any]: # if it has a field set and not set at the same time. # This will cause the filter to match no Document. # This way we keep the behavior consistent with other Document Stores. - return { - "bool": { - "must": [ - {"exists": {"field": field}}, - {"bool": {"must_not": {"exists": {"field": field}}}}, - ] - } - } + return {"bool": {"must": [{"exists": {"field": field}}, {"bool": {"must_not": {"exists": {"field": field}}}}]}} if isinstance(value, str): try: datetime.fromisoformat(value) @@ -271,9 +235,7 @@ def _normalize_ranges(conditions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: ] ``` """ - range_conditions = [ - next(iter(c["range"].items())) for c in conditions if "range" in c - ] + range_conditions = [next(iter(c["range"].items())) for c in conditions if "range" in c] if range_conditions: conditions = [c for c in conditions if "range" not in c] range_conditions_dict: Dict[str, Any] = {} diff --git a/integrations/elasticsearch/tests/test_bm25_retriever.py b/integrations/elasticsearch/tests/test_bm25_retriever.py index 467b34faa..3e9ebc9b8 100644 --- a/integrations/elasticsearch/tests/test_bm25_retriever.py +++ b/integrations/elasticsearch/tests/test_bm25_retriever.py @@ -6,12 +6,8 @@ import pytest from haystack.dataclasses import Document from haystack.document_stores.types import FilterPolicy -from haystack_integrations.components.retrievers.elasticsearch import ( - ElasticsearchBM25Retriever, -) -from haystack_integrations.document_stores.elasticsearch import ( - ElasticsearchDocumentStore, -) +from haystack_integrations.components.retrievers.elasticsearch import ElasticsearchBM25Retriever +from haystack_integrations.document_stores.elasticsearch import ElasticsearchDocumentStore def test_init_default(): @@ -23,18 +19,14 @@ def test_init_default(): assert retriever._filter_policy == FilterPolicy.REPLACE assert not retriever._scale_score - retriever = ElasticsearchBM25Retriever( - document_store=mock_store, filter_policy="replace" - ) + retriever = ElasticsearchBM25Retriever(document_store=mock_store, filter_policy="replace") assert retriever._filter_policy == FilterPolicy.REPLACE with pytest.raises(ValueError): ElasticsearchBM25Retriever(document_store=mock_store, filter_policy="keep") -@patch( - "haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch" -) +@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") def test_to_dict(_mock_elasticsearch_client): document_store = ElasticsearchDocumentStore(hosts="some fake host") retriever = 
ElasticsearchBM25Retriever(document_store=document_store) @@ -60,9 +52,7 @@ def test_to_dict(_mock_elasticsearch_client): } -@patch( - "haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch" -) +@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") def test_from_dict(_mock_elasticsearch_client): data = { "type": "haystack_integrations.components.retrievers.elasticsearch.bm25_retriever.ElasticsearchBM25Retriever", @@ -87,9 +77,7 @@ def test_from_dict(_mock_elasticsearch_client): assert retriever._filter_policy == FilterPolicy.REPLACE -@patch( - "haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch" -) +@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") def test_from_dict_no_filter_policy(_mock_elasticsearch_client): data = { "type": "haystack_integrations.components.retrievers.elasticsearch.bm25_retriever.ElasticsearchBM25Retriever", diff --git a/integrations/elasticsearch/tests/test_document_store.py b/integrations/elasticsearch/tests/test_document_store.py index 421c690fb..20b68f126 100644 --- a/integrations/elasticsearch/tests/test_document_store.py +++ b/integrations/elasticsearch/tests/test_document_store.py @@ -12,22 +12,16 @@ from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy from haystack.testing.document_store import DocumentStoreBaseTests -from haystack_integrations.document_stores.elasticsearch import ( - ElasticsearchDocumentStore, -) +from haystack_integrations.document_stores.elasticsearch import ElasticsearchDocumentStore -@patch( - "haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch" -) +@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") def test_init_is_lazy(_mock_es_client): ElasticsearchDocumentStore(hosts="testhost") _mock_es_client.assert_not_called() -@patch( - "haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch" -) +@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") def test_to_dict(_mock_elasticsearch_client): document_store = ElasticsearchDocumentStore(hosts="some hosts") res = document_store.to_dict() @@ -42,9 +36,7 @@ def test_to_dict(_mock_elasticsearch_client): } -@patch( - "haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch" -) +@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") def test_from_dict(_mock_elasticsearch_client): data = { "type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore", @@ -84,16 +76,12 @@ def document_store(self, request): embedding_similarity_function = "max_inner_product" store = ElasticsearchDocumentStore( - hosts=hosts, - index=index, - embedding_similarity_function=embedding_similarity_function, + hosts=hosts, index=index, embedding_similarity_function=embedding_similarity_function ) yield store store.client.options(ignore_status=[400, 404]).indices.delete(index=index) - def assert_documents_are_equal( - self, received: List[Document], expected: List[Document] - ): + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): """ The ElasticsearchDocumentStore.filter_documents() method returns Documents with their score set. We don't want to compare the score, so we set it to None before comparing the documents. 
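
The assert_documents_are_equal override above exists because the scores Elasticsearch assigns vary between runs and backends, so they carry no signal for equality checks. A condensed, hypothetical version of the idea (the helper name is ours, not the test suite's):

from typing import List
from haystack.dataclasses import Document

def strip_scores(received: List[Document]) -> List[Document]:
    # Null the backend-assigned scores so the base-class comparison only
    # sees real content and metadata differences.
    for doc in received:
        doc.score = None
    return received
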
@@ -119,9 +107,7 @@ def assert_documents_are_equal( super().assert_documents_are_equal(received, expected) def test_user_agent_header(self, document_store: ElasticsearchDocumentStore): - assert document_store.client._headers["user-agent"].startswith( - "haystack-py-ds/" - ) + assert document_store.client._headers["user-agent"].startswith("haystack-py-ds/") def test_write_documents(self, document_store: ElasticsearchDocumentStore): docs = [Document(id="1")] @@ -152,9 +138,7 @@ def test_bm25_retrieval(self, document_store: ElasticsearchDocumentStore): assert "functional" in res[1].content assert "functional" in res[2].content - def test_bm25_retrieval_pagination( - self, document_store: ElasticsearchDocumentStore - ): + def test_bm25_retrieval_pagination(self, document_store: ElasticsearchDocumentStore): """ Test that handling of pagination works as expected, when the matching documents are > 10. """ @@ -182,9 +166,7 @@ def test_bm25_retrieval_pagination( assert len(res) == 11 assert all("programming" in doc.content for doc in res) - def test_bm25_retrieval_with_fuzziness( - self, document_store: ElasticsearchDocumentStore - ): + def test_bm25_retrieval_with_fuzziness(self, document_store: ElasticsearchDocumentStore): document_store.write_documents( [ Document(content="Haskell is a functional programming language"), @@ -214,17 +196,12 @@ def test_bm25_retrieval_with_fuzziness( assert "functional" in res[1].content assert "functional" in res[2].content - def test_bm25_not_all_terms_must_match( - self, document_store: ElasticsearchDocumentStore - ): + def test_bm25_not_all_terms_must_match(self, document_store: ElasticsearchDocumentStore): """ Test that not all terms must mandatorily match for BM25 retrieval to return a result. """ documents = [ - Document( - id=1, - content="There are over 7,000 languages spoken around the world today.", - ), + Document(id=1, content="There are over 7,000 languages spoken around the world today."), Document( id=2, content=( @@ -242,9 +219,7 @@ def test_bm25_not_all_terms_must_match( ] document_store.write_documents(documents) - res = document_store._bm25_retrieval( - "How much self awareness do elephants have?", top_k=3 - ) + res = document_store._bm25_retrieval("How much self awareness do elephants have?", top_k=3) assert len(res) == 1 assert res[0].id == 2 @@ -252,21 +227,15 @@ def test_embedding_retrieval(self, document_store: ElasticsearchDocumentStore): docs = [ Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]), - Document( - content="Not very similar document", embedding=[0.0, 0.8, 0.3, 0.9] - ), + Document(content="Not very similar document", embedding=[0.0, 0.8, 0.3, 0.9]), ] document_store.write_documents(docs) - results = document_store._embedding_retrieval( - query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=2, filters={} - ) + results = document_store._embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=2, filters={}) assert len(results) == 2 assert results[0].content == "Most similar document" assert results[1].content == "2nd best document" - def test_embedding_retrieval_with_filters( - self, document_store: ElasticsearchDocumentStore - ): + def test_embedding_retrieval_with_filters(self, document_store: ElasticsearchDocumentStore): docs = [ Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]), @@ -279,30 +248,22 @@ def 
test_embedding_retrieval_with_filters( document_store.write_documents(docs) filters = {"field": "meta_field", "operator": "==", "value": "custom_value"} - results = document_store._embedding_retrieval( - query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=2, filters=filters - ) + results = document_store._embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=2, filters=filters) assert len(results) == 1 assert results[0].content == "Not very similar document with meta field" - def test_embedding_retrieval_pagination( - self, document_store: ElasticsearchDocumentStore - ): + def test_embedding_retrieval_pagination(self, document_store: ElasticsearchDocumentStore): """ Test that handling of pagination works as expected, when the matching documents are > 10. """ docs = [ - Document( - content=f"Document {i}", embedding=[random.random() for _ in range(4)] - ) # noqa: S311 + Document(content=f"Document {i}", embedding=[random.random() for _ in range(4)]) # noqa: S311 for i in range(20) ] document_store.write_documents(docs) - results = document_store._embedding_retrieval( - query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=11, filters={} - ) + results = document_store._embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=11, filters={}) assert len(results) == 11 def test_embedding_retrieval_query_documents_different_embedding_sizes( @@ -317,9 +278,7 @@ def test_embedding_retrieval_query_documents_different_embedding_sizes( with pytest.raises(BadRequestError): document_store._embedding_retrieval(query_embedding=[0.1, 0.1]) - def test_write_documents_different_embedding_sizes_fail( - self, document_store: ElasticsearchDocumentStore - ): + def test_write_documents_different_embedding_sizes_fail(self, document_store: ElasticsearchDocumentStore): """ Test that write_documents fails if the documents have different embedding sizes. 
""" @@ -331,17 +290,11 @@ def test_write_documents_different_embedding_sizes_fail( with pytest.raises(DocumentStoreError): document_store.write_documents(docs) - @patch( - "haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch" - ) + @patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") def test_init_with_custom_mapping(self, mock_elasticsearch): custom_mapping = { "properties": { - "embedding": { - "type": "dense_vector", - "index": True, - "similarity": "dot_product", - }, + "embedding": {"type": "dense_vector", "index": True, "similarity": "dot_product"}, "content": {"type": "text"}, }, "dynamic_templates": [ @@ -361,9 +314,7 @@ def test_init_with_custom_mapping(self, mock_elasticsearch): ) mock_elasticsearch.return_value = mock_client - _ = ElasticsearchDocumentStore( - hosts="some hosts", custom_mapping=custom_mapping - ).client + _ = ElasticsearchDocumentStore(hosts="some hosts", custom_mapping=custom_mapping).client mock_client.indices.create.assert_called_once_with( index="default", mappings=custom_mapping, diff --git a/integrations/elasticsearch/tests/test_embedding_retriever.py b/integrations/elasticsearch/tests/test_embedding_retriever.py index c42094ff0..2d03f0ec2 100644 --- a/integrations/elasticsearch/tests/test_embedding_retriever.py +++ b/integrations/elasticsearch/tests/test_embedding_retriever.py @@ -6,12 +6,8 @@ import pytest from haystack.dataclasses import Document from haystack.document_stores.types import FilterPolicy -from haystack_integrations.components.retrievers.elasticsearch import ( - ElasticsearchEmbeddingRetriever, -) -from haystack_integrations.document_stores.elasticsearch import ( - ElasticsearchDocumentStore, -) +from haystack_integrations.components.retrievers.elasticsearch import ElasticsearchEmbeddingRetriever +from haystack_integrations.document_stores.elasticsearch import ElasticsearchDocumentStore def test_init_default(): @@ -22,18 +18,14 @@ def test_init_default(): assert retriever._top_k == 10 assert retriever._num_candidates is None - retriever = ElasticsearchEmbeddingRetriever( - document_store=mock_store, filter_policy="replace" - ) + retriever = ElasticsearchEmbeddingRetriever(document_store=mock_store, filter_policy="replace") assert retriever._filter_policy == FilterPolicy.REPLACE with pytest.raises(ValueError): ElasticsearchEmbeddingRetriever(document_store=mock_store, filter_policy="keep") -@patch( - "haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch" -) +@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") def test_to_dict(_mock_elasticsearch_client): document_store = ElasticsearchDocumentStore(hosts="some fake host") retriever = ElasticsearchEmbeddingRetriever(document_store=document_store) @@ -59,9 +51,7 @@ def test_to_dict(_mock_elasticsearch_client): } -@patch( - "haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch" -) +@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") def test_from_dict(_mock_elasticsearch_client): t = "haystack_integrations.components.retrievers.elasticsearch.embedding_retriever.ElasticsearchEmbeddingRetriever" data = { @@ -84,9 +74,7 @@ def test_from_dict(_mock_elasticsearch_client): assert retriever._num_candidates is None -@patch( - "haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch" -) +@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") def 
test_from_dict_no_filter_policy(_mock_elasticsearch_client): t = "haystack_integrations.components.retrievers.elasticsearch.embedding_retriever.ElasticsearchEmbeddingRetriever" data = { @@ -111,9 +99,7 @@ def test_from_dict_no_filter_policy(_mock_elasticsearch_client): def test_run(): mock_store = Mock(spec=ElasticsearchDocumentStore) - mock_store._embedding_retrieval.return_value = [ - Document(content="Test doc", embedding=[0.1, 0.2]) - ] + mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] retriever = ElasticsearchEmbeddingRetriever(document_store=mock_store) res = retriever.run(query_embedding=[0.5, 0.7]) mock_store._embedding_retrieval.assert_called_once_with( diff --git a/integrations/elasticsearch/tests/test_filters.py b/integrations/elasticsearch/tests/test_filters.py index c53974613..86e5cba74 100644 --- a/integrations/elasticsearch/tests/test_filters.py +++ b/integrations/elasticsearch/tests/test_filters.py @@ -3,10 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import pytest from haystack.errors import FilterError -from haystack_integrations.document_stores.elasticsearch.filters import ( - _normalize_filters, - _normalize_ranges, -) +from haystack_integrations.document_stores.elasticsearch.filters import _normalize_filters, _normalize_ranges filters_data = [ ( @@ -17,16 +14,8 @@ { "operator": "OR", "conditions": [ - { - "field": "meta.genre", - "operator": "in", - "value": ["economy", "politics"], - }, - { - "field": "meta.publisher", - "operator": "==", - "value": "nytimes", - }, + {"field": "meta.genre", "operator": "in", "value": ["economy", "politics"]}, + {"field": "meta.publisher", "operator": "==", "value": "nytimes"}, ], }, {"field": "meta.date", "operator": ">=", "value": "2015-01-01"}, @@ -75,22 +64,8 @@ { "bool": { "should": [ - { - "bool": { - "must": [ - {"term": {"Type": "News Paper"}}, - {"range": {"Date": {"lt": "2020-01-01"}}}, - ] - } - }, - { - "bool": { - "must": [ - {"term": {"Type": "Blog Post"}}, - {"range": {"Date": {"gte": "2019-01-01"}}}, - ] - } - }, + {"bool": {"must": [{"term": {"Type": "News Paper"}}, {"range": {"Date": {"lt": "2020-01-01"}}}]}}, + {"bool": {"must": [{"term": {"Type": "Blog Post"}}, {"range": {"Date": {"gte": "2019-01-01"}}}]}}, ] } }, @@ -106,16 +81,8 @@ { "operator": "OR", "conditions": [ - { - "field": "meta.genre", - "operator": "in", - "value": ["economy", "politics"], - }, - { - "field": "meta.publisher", - "operator": "==", - "value": "nytimes", - }, + {"field": "meta.genre", "operator": "in", "value": ["economy", "politics"]}, + {"field": "meta.publisher", "operator": "==", "value": "nytimes"}, ], }, ], @@ -139,26 +106,8 @@ }, ), ( - { - "operator": "AND", - "conditions": [ - {"field": "text", "operator": "==", "value": "A Foo Document 1"} - ], - }, - { - "bool": { - "must": [ - { - "match": { - "text": { - "query": "A Foo Document 1", - "minimum_should_match": "100%", - } - } - } - ] - } - }, + {"operator": "AND", "conditions": [{"field": "text", "operator": "==", "value": "A Foo Document 1"}]}, + {"bool": {"must": [{"match": {"text": {"query": "A Foo Document 1", "minimum_should_match": "100%"}}}]}}, ), ( { @@ -177,14 +126,7 @@ { "bool": { "should": [ - { - "bool": { - "should": [ - {"term": {"name": "name_0"}}, - {"term": {"name": "name_1"}}, - ] - } - }, + {"bool": {"should": [{"term": {"name": "name_0"}}, {"term": {"name": "name_1"}}]}}, {"range": {"number": {"lt": 1.0}}}, ] } @@ -199,14 +141,7 @@ {"field": "meta.name", "operator": "in", "value": ["name_0", "name_1"]}, ], 
}, - { - "bool": { - "must": [ - {"terms": {"name": ["name_0", "name_1"]}}, - {"range": {"number": {"lte": 2, "gte": 0}}}, - ] - } - }, + {"bool": {"must": [{"terms": {"name": ["name_0", "name_1"]}}, {"range": {"number": {"lte": 2, "gte": 0}}}]}}, ), ( { @@ -226,11 +161,7 @@ {"field": "meta.name", "operator": "==", "value": "name_1"}, ], }, - { - "bool": { - "should": [{"term": {"name": "name_0"}}, {"term": {"name": "name_1"}}] - } - }, + {"bool": {"should": [{"term": {"name": "name_0"}}, {"term": {"name": "name_1"}}]}}, ), ( { @@ -240,20 +171,7 @@ {"field": "meta.name", "operator": "==", "value": "name_0"}, ], }, - { - "bool": { - "must_not": [ - { - "bool": { - "must": [ - {"term": {"number": 100}}, - {"term": {"name": "name_0"}}, - ] - } - } - ] - } - }, + {"bool": {"must_not": [{"bool": {"must": [{"term": {"number": 100}}, {"term": {"name": "name_0"}}]}}]}}, ), ] @@ -280,27 +198,15 @@ def test_normalize_filters_malformed(): # Missing comparison field with pytest.raises(FilterError): - _normalize_filters( - {"operator": "AND", "conditions": [{"operator": "==", "value": "article"}]} - ) + _normalize_filters({"operator": "AND", "conditions": [{"operator": "==", "value": "article"}]}) # Missing comparison operator with pytest.raises(FilterError): - _normalize_filters( - { - "operator": "AND", - "conditions": [{"field": "meta.type", "operator": "=="}], - } - ) + _normalize_filters({"operator": "AND", "conditions": [{"field": "meta.type", "operator": "=="}]}) # Missing comparison value with pytest.raises(FilterError): - _normalize_filters( - { - "operator": "AND", - "conditions": [{"field": "meta.type", "value": "article"}], - } - ) + _normalize_filters({"operator": "AND", "conditions": [{"field": "meta.type", "value": "article"}]}) def test_normalize_ranges(): diff --git a/integrations/fastembed/examples/example.py b/integrations/fastembed/examples/example.py index 7eefe2e20..e4d328210 100644 --- a/integrations/fastembed/examples/example.py +++ b/integrations/fastembed/examples/example.py @@ -1,10 +1,7 @@ from haystack import Document, Pipeline from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever from haystack.document_stores.in_memory import InMemoryDocumentStore -from haystack_integrations.components.embedders.fastembed import ( - FastembedDocumentEmbedder, - FastembedTextEmbedder, -) +from haystack_integrations.components.embedders.fastembed import FastembedDocumentEmbedder, FastembedTextEmbedder document_store = InMemoryDocumentStore(embedding_similarity_function="cosine") @@ -22,9 +19,7 @@ query_pipeline = Pipeline() query_pipeline.add_component("text_embedder", FastembedTextEmbedder()) -query_pipeline.add_component( - "retriever", InMemoryEmbeddingRetriever(document_store=document_store) -) +query_pipeline.add_component("retriever", InMemoryEmbeddingRetriever(document_store=document_store)) query_pipeline.connect("text_embedder.embedding", "retriever.query_embedding") query = "Who supports fastembed?" diff --git a/integrations/fastembed/examples/sparse_example.py b/integrations/fastembed/examples/sparse_example.py index 254180307..bce3b363d 100644 --- a/integrations/fastembed/examples/sparse_example.py +++ b/integrations/fastembed/examples/sparse_example.py @@ -4,9 +4,7 @@ # involving indexing and retrieval of documents. 
from haystack import Document -from haystack_integrations.components.embedders.fastembed import ( - FastembedSparseDocumentEmbedder, -) +from haystack_integrations.components.embedders.fastembed import FastembedSparseDocumentEmbedder document_list = [ Document( diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py index 2050cf30f..66f797549 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py @@ -27,14 +27,9 @@ def get_embedding_backend( return _FastembedEmbeddingBackendFactory._instances[embedding_backend_id] embedding_backend = _FastembedEmbeddingBackend( - model_name=model_name, - cache_dir=cache_dir, - threads=threads, - local_files_only=local_files_only, + model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only ) - _FastembedEmbeddingBackendFactory._instances[ - embedding_backend_id - ] = embedding_backend + _FastembedEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend return embedding_backend @@ -51,10 +46,7 @@ def __init__( local_files_only: bool = False, ): self.model = TextEmbedding( - model_name=model_name, - cache_dir=cache_dir, - threads=threads, - local_files_only=local_files_only, + model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only ) def embed(self, data: List[str], progress_bar=True, **kwargs) -> List[List[float]]: @@ -62,10 +54,7 @@ def embed(self, data: List[str], progress_bar=True, **kwargs) -> List[List[float embeddings = [] embeddings_iterable = self.model.embed(data, **kwargs) for np_array in tqdm( - embeddings_iterable, - disable=not progress_bar, - desc="Calculating embeddings", - total=len(data), + embeddings_iterable, disable=not progress_bar, desc="Calculating embeddings", total=len(data) ): embeddings.append(np_array.tolist()) return embeddings @@ -88,19 +77,12 @@ def get_embedding_backend( embedding_backend_id = f"{model_name}{cache_dir}{threads}" if embedding_backend_id in _FastembedSparseEmbeddingBackendFactory._instances: - return _FastembedSparseEmbeddingBackendFactory._instances[ - embedding_backend_id - ] + return _FastembedSparseEmbeddingBackendFactory._instances[embedding_backend_id] embedding_backend = _FastembedSparseEmbeddingBackend( - model_name=model_name, - cache_dir=cache_dir, - threads=threads, - local_files_only=local_files_only, + model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only ) - _FastembedSparseEmbeddingBackendFactory._instances[ - embedding_backend_id - ] = embedding_backend + _FastembedSparseEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend return embedding_backend @@ -117,15 +99,10 @@ def __init__( local_files_only: bool = False, ): self.model = SparseTextEmbedding( - model_name=model_name, - cache_dir=cache_dir, - threads=threads, - local_files_only=local_files_only, + model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only ) - def embed( - self, data: List[List[str]], progress_bar=True, **kwargs - ) -> List[SparseEmbedding]: + def embed(self, data: List[List[str]], progress_bar=True, **kwargs) -> List[SparseEmbedding]: # 
The embed method returns an Iterable[SparseEmbedding], so we convert to the Haystack SparseEmbedding type. # Each SparseEmbedding contains an `indices` key containing a list of int and # a `values` key containing a list of floats. @@ -133,16 +110,10 @@ def embed( sparse_embeddings = [] sparse_embeddings_iterable = self.model.embed(data, **kwargs) for sparse_embedding in tqdm( - sparse_embeddings_iterable, - disable=not progress_bar, - desc="Calculating sparse embeddings", - total=len(data), + sparse_embeddings_iterable, disable=not progress_bar, desc="Calculating sparse embeddings", total=len(data) ): sparse_embeddings.append( - SparseEmbedding( - indices=sparse_embedding.indices.tolist(), - values=sparse_embedding.values.tolist(), - ) + SparseEmbedding(indices=sparse_embedding.indices.tolist(), values=sparse_embedding.values.tolist()) ) return sparse_embeddings diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py index 332548cef..8b63582c5 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py @@ -129,29 +129,21 @@ def warm_up(self): Initializes the component. """ if not hasattr(self, "embedding_backend"): - self.embedding_backend = ( - _FastembedEmbeddingBackendFactory.get_embedding_backend( - model_name=self.model_name, - cache_dir=self.cache_dir, - threads=self.threads, - local_files_only=self.local_files_only, - ) + self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend( + model_name=self.model_name, + cache_dir=self.cache_dir, + threads=self.threads, + local_files_only=self.local_files_only, ) def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: texts_to_embed = [] for doc in documents: meta_values_to_embed = [ - str(doc.meta[key]) - for key in self.meta_fields_to_embed - if key in doc.meta and doc.meta[key] is not None + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None ] text_to_embed = ( - self.prefix - + self.embedding_separator.join( - [*meta_values_to_embed, doc.content or ""] - ) - + self.suffix + self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix ) texts_to_embed.append(text_to_embed) @@ -166,11 +158,7 @@ def run(self, documents: List[Document]): :returns: A dictionary with the following keys: - `documents`: List of Documents with each Document's `embedding` field set to the computed embeddings. """ - if ( - not isinstance(documents, list) - or documents - and not isinstance(documents[0], Document) - ): + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): msg = ( "FastembedDocumentEmbedder expects a list of Documents as input. " "In case you want to embed a list of strings, please use the FastembedTextEmbedder." 
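
The run() validation above rejects anything that is not a list of Documents, pointing string inputs to FastembedTextEmbedder instead. A minimal usage sketch (the model name and meta field are illustrative; warm_up() instantiates the cached fastembed backend):

from haystack import Document
from haystack_integrations.components.embedders.fastembed import FastembedDocumentEmbedder

embedder = FastembedDocumentEmbedder(
    model="BAAI/bge-small-en-v1.5",
    meta_fields_to_embed=["title"],  # joined with the content via embedding_separator
)
embedder.warm_up()

docs = [Document(content="fastembed is supported and maintained by Qdrant.", meta={"title": "About"})]
result = embedder.run(documents=docs)  # a list of plain strings would raise a TypeError here
assert result["documents"][0].embedding is not None
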
diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py index b8eaffc4b..4b72389fa 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py @@ -117,26 +117,20 @@ def warm_up(self): Initializes the component. """ if not hasattr(self, "embedding_backend"): - self.embedding_backend = ( - _FastembedSparseEmbeddingBackendFactory.get_embedding_backend( - model_name=self.model_name, - cache_dir=self.cache_dir, - threads=self.threads, - local_files_only=self.local_files_only, - ) + self.embedding_backend = _FastembedSparseEmbeddingBackendFactory.get_embedding_backend( + model_name=self.model_name, + cache_dir=self.cache_dir, + threads=self.threads, + local_files_only=self.local_files_only, ) def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: texts_to_embed = [] for doc in documents: meta_values_to_embed = [ - str(doc.meta[key]) - for key in self.meta_fields_to_embed - if key in doc.meta and doc.meta[key] is not None + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None ] - text_to_embed = self.embedding_separator.join( - [*meta_values_to_embed, doc.content or ""] - ) + text_to_embed = self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) texts_to_embed.append(text_to_embed) return texts_to_embed @@ -151,11 +145,7 @@ def run(self, documents: List[Document]): - `documents`: List of Documents with each Document's `sparse_embedding` field set to the computed embeddings. """ - if ( - not isinstance(documents, list) - or documents - and not isinstance(documents[0], Document) - ): + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): msg = ( "FastembedSparseDocumentEmbedder expects a list of Documents as input. " "In case you want to embed a list of strings, please use the FastembedTextEmbedder." diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py index dcdd99b5f..67348b2bd 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py @@ -81,13 +81,11 @@ def warm_up(self): Initializes the component. 
""" if not hasattr(self, "embedding_backend"): - self.embedding_backend = ( - _FastembedSparseEmbeddingBackendFactory.get_embedding_backend( - model_name=self.model_name, - cache_dir=self.cache_dir, - threads=self.threads, - local_files_only=self.local_files_only, - ) + self.embedding_backend = _FastembedSparseEmbeddingBackendFactory.get_embedding_backend( + model_name=self.model_name, + cache_dir=self.cache_dir, + threads=self.threads, + local_files_only=self.local_files_only, ) @component.output_types(sparse_embedding=SparseEmbedding) diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py index aec43f76f..a7f56ff97 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py @@ -88,13 +88,11 @@ def warm_up(self): Initializes the component. """ if not hasattr(self, "embedding_backend"): - self.embedding_backend = ( - _FastembedEmbeddingBackendFactory.get_embedding_backend( - model_name=self.model_name, - cache_dir=self.cache_dir, - threads=self.threads, - local_files_only=self.local_files_only, - ) + self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend( + model_name=self.model_name, + cache_dir=self.cache_dir, + threads=self.threads, + local_files_only=self.local_files_only, ) @component.output_types(embedding=List[float]) diff --git a/integrations/fastembed/tests/test_fastembed_backend.py b/integrations/fastembed/tests/test_fastembed_backend.py index 185030317..631d9f1e0 100644 --- a/integrations/fastembed/tests/test_fastembed_backend.py +++ b/integrations/fastembed/tests/test_fastembed_backend.py @@ -5,13 +5,9 @@ ) -@patch( - "haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.TextEmbedding" -) +@patch("haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.TextEmbedding") def test_factory_behavior(mock_instructor): # noqa: ARG001 - embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend( - model_name="BAAI/bge-small-en-v1.5" - ) + embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name="BAAI/bge-small-en-v1.5") same_embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend( model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None ) @@ -26,30 +22,21 @@ def test_factory_behavior(mock_instructor): # noqa: ARG001 _FastembedEmbeddingBackendFactory._instances = {} -@patch( - "haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.TextEmbedding" -) +@patch("haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.TextEmbedding") def test_model_initialization(mock_instructor): _FastembedEmbeddingBackendFactory.get_embedding_backend( model_name="BAAI/bge-small-en-v1.5", ) mock_instructor.assert_called_once_with( - model_name="BAAI/bge-small-en-v1.5", - cache_dir=None, - threads=None, - local_files_only=False, + model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None, local_files_only=False ) # restore the factory state _FastembedEmbeddingBackendFactory._instances = {} -@patch( - "haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.TextEmbedding" -) 
+@patch("haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.TextEmbedding") def test_embedding_function_with_kwargs(mock_instructor): # noqa: ARG001 - embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend( - model_name="BAAI/bge-small-en-v1.5" - ) + embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name="BAAI/bge-small-en-v1.5") data = ["sentence1", "sentence2"] embedding_backend.embed(data=data) diff --git a/integrations/fastembed/tests/test_fastembed_document_embedder.py b/integrations/fastembed/tests/test_fastembed_document_embedder.py index f7f9a5998..8afb89c69 100644 --- a/integrations/fastembed/tests/test_fastembed_document_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_document_embedder.py @@ -190,10 +190,7 @@ def test_warmup(self, mocked_factory): mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( - model_name="BAAI/bge-small-en-v1.5", - cache_dir=None, - threads=None, - local_files_only=False, + model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None, local_files_only=False ) @patch( @@ -215,9 +212,7 @@ def test_embed(self): """ embedder = FastembedDocumentEmbedder(model="BAAI/bge-base-en-v1.5") embedder.embedding_backend = MagicMock() - embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand( - len(x), 16 - ).tolist() # noqa: ARG005 + embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand(len(x), 16).tolist() # noqa: ARG005 documents = [Document(content=f"Sample-document text {i}") for i in range(5)] @@ -263,12 +258,7 @@ def test_embed_metadata(self): ) embedder.embedding_backend = MagicMock() - documents = [ - Document( - content=f"document-number {i}", meta={"meta_field": f"meta_value {i}"} - ) - for i in range(5) - ] + documents = [Document(content=f"document-number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)] embedder.run(documents=documents) diff --git a/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py b/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py index e6f351c74..b4caca364 100644 --- a/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py @@ -173,10 +173,7 @@ def test_warmup(self, mocked_factory): mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( - model_name="prithvida/Splade_PP_en_v1", - cache_dir=None, - threads=None, - local_files_only=False, + model_name="prithvida/Splade_PP_en_v1", cache_dir=None, threads=None, local_files_only=False ) @patch( @@ -198,9 +195,7 @@ def _generate_mocked_sparse_embedding(self, n): random_indice_length = np.random.randint(3, 15) data = { "indices": list(range(random_indice_length)), - "values": [ - np.random.random_sample() for _ in range(random_indice_length) - ], + "values": [np.random.random_sample() for _ in range(random_indice_length)], } list_of_sparse_vectors.append(data) return list_of_sparse_vectors @@ -211,10 +206,8 @@ def test_embed(self): """ embedder = FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1") embedder.embedding_backend = MagicMock() - embedder.embedding_backend.embed = ( - lambda x, **kwargs: self._generate_mocked_sparse_embedding( # noqa: ARG005 - len(x) - ) + embedder.embedding_backend.embed = lambda x, **kwargs: 
self._generate_mocked_sparse_embedding( # noqa: ARG005 + len(x) ) documents = [Document(content=f"Sample-document text {i}") for i in range(5)] @@ -264,12 +257,7 @@ def test_embed_metadata(self): ) embedder.embedding_backend = MagicMock() - documents = [ - Document( - content=f"document-number {i}", meta={"meta_field": f"meta_value {i}"} - ) - for i in range(5) - ] + documents = [Document(content=f"document-number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)] embedder.run(documents=documents) diff --git a/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py b/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py index a60481fa9..9e37df409 100644 --- a/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py @@ -134,10 +134,7 @@ def test_warmup(self, mocked_factory): mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( - model_name="prithvida/Splade_PP_en_v1", - cache_dir=None, - threads=None, - local_files_only=False, + model_name="prithvida/Splade_PP_en_v1", cache_dir=None, threads=None, local_files_only=False ) @patch( @@ -159,9 +156,7 @@ def _generate_mocked_sparse_embedding(self, n): random_indice_length = np.random.randint(3, 15) data = { "indices": list(range(random_indice_length)), - "values": [ - np.random.random_sample() for _ in range(random_indice_length) - ], + "values": [np.random.random_sample() for _ in range(random_indice_length)], } list_of_sparse_vectors.append(data) @@ -173,10 +168,8 @@ def test_embed(self): """ embedder = FastembedSparseTextEmbedder(model="BAAI/bge-base-en-v1.5") embedder.embedding_backend = MagicMock() - embedder.embedding_backend.embed = ( - lambda x, **kwargs: self._generate_mocked_sparse_embedding( # noqa: ARG005 - len(x) - ) + embedder.embedding_backend.embed = lambda x, **kwargs: self._generate_mocked_sparse_embedding( # noqa: ARG005 + len(x) ) text = "Good text to embed" @@ -198,9 +191,7 @@ def test_run_wrong_incorrect_format(self): list_integers_input = [1, 2, 3] - with pytest.raises( - TypeError, match="FastembedSparseTextEmbedder expects a string as input" - ): + with pytest.raises(TypeError, match="FastembedSparseTextEmbedder expects a string as input"): embedder.run(text=list_integers_input) @pytest.mark.integration diff --git a/integrations/fastembed/tests/test_fastembed_text_embedder.py b/integrations/fastembed/tests/test_fastembed_text_embedder.py index cf62394f7..f20a98b57 100644 --- a/integrations/fastembed/tests/test_fastembed_text_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_text_embedder.py @@ -153,10 +153,7 @@ def test_warmup(self, mocked_factory): mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( - model_name="BAAI/bge-small-en-v1.5", - cache_dir=None, - threads=None, - local_files_only=False, + model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None, local_files_only=False ) @patch( @@ -178,9 +175,7 @@ def test_embed(self): """ embedder = FastembedTextEmbedder(model="BAAI/bge-base-en-v1.5") embedder.embedding_backend = MagicMock() - embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand( - len(x), 16 - ).tolist() # noqa: ARG005 + embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand(len(x), 16).tolist() # noqa: ARG005 text = "Good text to embed" @@ -199,9 +194,7 @@ def 
test_run_wrong_incorrect_format(self): list_integers_input = [1, 2, 3] - with pytest.raises( - TypeError, match="FastembedTextEmbedder expects a string as input" - ): + with pytest.raises(TypeError, match="FastembedTextEmbedder expects a string as input"): embedder.run(text=list_integers_input) @pytest.mark.integration diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index 9ea2c086d..8b592a184 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -132,9 +132,7 @@ def __init__( self._tools = tools self._model = GenerativeModel(self._model_name, tools=self._tools) - def _generation_config_to_dict( - self, config: Union[GenerationConfig, Dict[str, Any]] - ) -> Dict[str, Any]: + def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: if isinstance(config, dict): return config return { @@ -170,18 +168,10 @@ def to_dict(self) -> Dict[str, Any]: # can't be easily serializated to a dict. We need to convert it to a protobuf class first. tool = tool.to_proto() # noqa: PLW2901 data["init_parameters"]["tools"].append(ToolProto.serialize(tool)) - if ( - generation_config := data["init_parameters"].get("generation_config") - ) is not None: - data["init_parameters"][ - "generation_config" - ] = self._generation_config_to_dict(generation_config) - if ( - safety_settings := data["init_parameters"].get("safety_settings") - ) is not None: - data["init_parameters"]["safety_settings"] = { - k.value: v.value for k, v in safety_settings.items() - } + if (generation_config := data["init_parameters"].get("generation_config")) is not None: + data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) + if (safety_settings := data["init_parameters"].get("safety_settings")) is not None: + data["init_parameters"]["safety_settings"] = {k.value: v.value for k, v in safety_settings.items()} return data @classmethod @@ -203,24 +193,14 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiChatGenerator": # to be able to convert them to the Python class. 
                 proto = ToolProto.deserialize(tool)
                 deserialized_tools.append(
-                    Tool(
-                        function_declarations=proto.function_declarations,
-                        code_execution=proto.code_execution,
-                    )
+                    Tool(function_declarations=proto.function_declarations, code_execution=proto.code_execution)
                 )
             data["init_parameters"]["tools"] = deserialized_tools
-        if (
-            generation_config := data["init_parameters"].get("generation_config")
-        ) is not None:
-            data["init_parameters"]["generation_config"] = GenerationConfig(
-                **generation_config
-            )
-        if (
-            safety_settings := data["init_parameters"].get("safety_settings")
-        ) is not None:
+        if (generation_config := data["init_parameters"].get("generation_config")) is not None:
+            data["init_parameters"]["generation_config"] = GenerationConfig(**generation_config)
+        if (safety_settings := data["init_parameters"].get("safety_settings")) is not None:
             data["init_parameters"]["safety_settings"] = {
-                HarmCategory(k): HarmBlockThreshold(v)
-                for k, v in safety_settings.items()
+                HarmCategory(k): HarmBlockThreshold(v) for k, v in safety_settings.items()
             }
         return default_from_dict(cls, data)
diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py
index 48d3f5e52..f7b2f9097 100644
--- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py
+++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py
@@ -99,9 +99,7 @@ def __init__(
         self._tools = tools
         self._model = GenerativeModel(self._model_name, tools=self._tools)
 
-    def _generation_config_to_dict(
-        self, config: Union[GenerationConfig, Dict[str, Any]]
-    ) -> Dict[str, Any]:
+    def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]:
         if isinstance(config, dict):
             return config
         return {
@@ -130,18 +128,10 @@ def to_dict(self) -> Dict[str, Any]:
         )
         if (tools := data["init_parameters"].get("tools")) is not None:
             data["init_parameters"]["tools"] = [Tool.serialize(t) for t in tools]
-        if (
-            generation_config := data["init_parameters"].get("generation_config")
-        ) is not None:
-            data["init_parameters"][
-                "generation_config"
-            ] = self._generation_config_to_dict(generation_config)
-        if (
-            safety_settings := data["init_parameters"].get("safety_settings")
-        ) is not None:
-            data["init_parameters"]["safety_settings"] = {
-                k.value: v.value for k, v in safety_settings.items()
-            }
+        if (generation_config := data["init_parameters"].get("generation_config")) is not None:
+            data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config)
+        if (safety_settings := data["init_parameters"].get("safety_settings")) is not None:
+            data["init_parameters"]["safety_settings"] = {k.value: v.value for k, v in safety_settings.items()}
         return data
 
     @classmethod
@@ -158,18 +148,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiGenerator":
         if (tools := data["init_parameters"].get("tools")) is not None:
             data["init_parameters"]["tools"] = [Tool.deserialize(t) for t in tools]
-        if (
-            generation_config := data["init_parameters"].get("generation_config")
-        ) is not None:
-            data["init_parameters"]["generation_config"] = GenerationConfig(
-                **generation_config
-            )
-        if (
-            safety_settings := data["init_parameters"].get("safety_settings")
-        ) is not None:
+        if (generation_config := data["init_parameters"].get("generation_config")) is not None:
+            data["init_parameters"]["generation_config"] = GenerationConfig(**generation_config)
+        if (safety_settings := data["init_parameters"].get("safety_settings")) is not None:
             data["init_parameters"]["safety_settings"] = {
-                HarmCategory(k): HarmBlockThreshold(v)
-                for k, v in safety_settings.items()
+                HarmCategory(k): HarmBlockThreshold(v) for k, v in safety_settings.items()
             }
         return default_from_dict(cls, data)
diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py
index b0f3d78c2..9b3124eab 100644
--- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py
+++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py
@@ -3,17 +3,10 @@
 import pytest
 from google.generativeai import GenerationConfig, GenerativeModel
-from google.generativeai.types import (
-    FunctionDeclaration,
-    HarmBlockThreshold,
-    HarmCategory,
-    Tool,
-)
+from google.generativeai.types import FunctionDeclaration, HarmBlockThreshold, HarmCategory, Tool
 from haystack.dataclasses.chat_message import ChatMessage
 
-from haystack_integrations.components.generators.google_ai import (
-    GoogleAIGeminiChatGenerator,
-)
+from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator
 
 
 def test_init(monkeypatch):
@@ -27,19 +20,14 @@ def test_init(monkeypatch):
         top_p=0.5,
         top_k=0.5,
     )
-    safety_settings = {
-        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
-    }
+    safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
     get_current_weather_func = FunctionDeclaration(
         name="get_current_weather",
         description="Get the current weather in a given location",
         parameters={
             "type_": "OBJECT",
             "properties": {
-                "location": {
-                    "type_": "STRING",
-                    "description": "The city and state, e.g. San Francisco, CA",
-                },
+                "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"},
                 "unit": {
                     "type_": "STRING",
                     "enum": [
@@ -80,19 +68,14 @@ def test_to_dict(monkeypatch):
         top_p=0.5,
         top_k=2,
     )
-    safety_settings = {
-        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
-    }
+    safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
     get_current_weather_func = FunctionDeclaration(
         name="get_current_weather",
         description="Get the current weather in a given location",
         parameters={
             "type_": "OBJECT",
             "properties": {
-                "location": {
-                    "type_": "STRING",
-                    "description": "The city and state, e.g. San Francisco, CA",
-                },
+                "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"},
                 "unit": {
                     "type_": "STRING",
                     "enum": [
@@ -107,9 +90,7 @@ def test_to_dict(monkeypatch):
 
     tool = Tool(function_declarations=[get_current_weather_func])
 
-    with patch(
-        "haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure"
-    ):
+    with patch("haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure"):
         gemini = GoogleAIGeminiChatGenerator(
             generation_config=generation_config,
             safety_settings=safety_settings,
@@ -118,11 +99,7 @@ def test_to_dict(monkeypatch):
     assert gemini.to_dict() == {
         "type": "haystack_integrations.components.generators.google_ai.chat.gemini.GoogleAIGeminiChatGenerator",
         "init_parameters": {
-            "api_key": {
-                "env_vars": ["GOOGLE_API_KEY"],
-                "strict": True,
-                "type": "env_var",
-            },
+            "api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"},
             "model": "gemini-pro-vision",
             "generation_config": {
                 "temperature": 0.5,
@@ -145,18 +122,12 @@ def test_from_dict(monkeypatch):
     monkeypatch.setenv("GOOGLE_API_KEY", "test")
 
-    with patch(
-        "haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure"
-    ):
+    with patch("haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure"):
         gemini = GoogleAIGeminiChatGenerator.from_dict(
             {
                 "type": "haystack_integrations.components.generators.google_ai.chat.gemini.GoogleAIGeminiChatGenerator",
                 "init_parameters": {
-                    "api_key": {
-                        "env_vars": ["GOOGLE_API_KEY"],
-                        "strict": True,
-                        "type": "env_var",
-                    },
+                    "api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"},
                     "model": "gemini-pro-vision",
                     "generation_config": {
                         "temperature": 0.5,
@@ -185,33 +156,21 @@ def test_from_dict(monkeypatch):
         top_p=0.5,
         top_k=2,
    )
-    assert gemini._safety_settings == {
-        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
-    }
+    assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
     assert len(gemini._tools) == 1
     assert len(gemini._tools[0].function_declarations) == 1
     assert gemini._tools[0].function_declarations[0].name == "get_current_weather"
+    assert gemini._tools[0].function_declarations[0].description == "Get the current weather in a given location"
     assert (
-        gemini._tools[0].function_declarations[0].description
-        == "Get the current weather in a given location"
-    )
-    assert (
-        gemini._tools[0]
-        .function_declarations[0]
-        .parameters.properties["location"]
-        .description
+        gemini._tools[0].function_declarations[0].parameters.properties["location"].description
         == "The city and state, e.g. San Francisco, CA"
     )
-    assert gemini._tools[0].function_declarations[0].parameters.properties[
-        "unit"
-    ].enum == ["celsius", "fahrenheit"]
+    assert gemini._tools[0].function_declarations[0].parameters.properties["unit"].enum == ["celsius", "fahrenheit"]
     assert gemini._tools[0].function_declarations[0].parameters.required == ["location"]
 
     assert isinstance(gemini._model, GenerativeModel)
 
 
-@pytest.mark.skipif(
-    not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set"
-)
+@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
 def test_run():
     # We're ignoring the unused function argument check since we must have that argument for the test
     # to run successfully, but we don't actually use it.
@@ -228,24 +187,18 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
     tool = Tool(function_declarations=[get_current_weather_func])
     gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool])
 
-    messages = [
-        ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")
-    ]
+    messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")]
     res = gemini_chat.run(messages=messages)
     assert len(res["replies"]) > 0
 
     weather = get_current_weather(**res["replies"][0].content)
-    messages += res["replies"] + [
-        ChatMessage.from_function(content=weather, name="get_current_weather")
-    ]
+    messages += res["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
 
     res = gemini_chat.run(messages=messages)
     assert len(res["replies"]) > 0
 
 
-@pytest.mark.skipif(
-    not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set"
-)
+@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
 def test_past_conversation():
     gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro")
     messages = [
diff --git a/integrations/google_ai/tests/generators/test_gemini.py b/integrations/google_ai/tests/generators/test_gemini.py
index 3ef585da7..35c7d196b 100644
--- a/integrations/google_ai/tests/generators/test_gemini.py
+++ b/integrations/google_ai/tests/generators/test_gemini.py
@@ -6,9 +6,7 @@
 from google.generativeai import GenerationConfig, GenerativeModel
 from google.generativeai.types import HarmBlockThreshold, HarmCategory
 
-from haystack_integrations.components.generators.google_ai import (
-    GoogleAIGeminiGenerator,
-)
+from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator
 
 
 def test_init(monkeypatch):
@@ -22,19 +20,14 @@ def test_init(monkeypatch):
         top_p=0.5,
         top_k=0.5,
     )
-    safety_settings = {
-        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
-    }
+    safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
     get_current_weather_func = FunctionDeclaration(
         name="get_current_weather",
         description="Get the current weather in a given location",
         parameters={
             "type_": "OBJECT",
             "properties": {
-                "location": {
-                    "type_": "STRING",
-                    "description": "The city and state, e.g. San Francisco, CA",
-                },
+                "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"},
                 "unit": {
                     "type_": "STRING",
                     "enum": [
@@ -48,9 +41,7 @@ def test_init(monkeypatch):
     )
     tool = Tool(function_declarations=[get_current_weather_func])
 
-    with patch(
-        "haystack_integrations.components.generators.google_ai.gemini.genai.configure"
-    ) as mock_genai_configure:
+    with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure") as mock_genai_configure:
         gemini = GoogleAIGeminiGenerator(
             generation_config=generation_config,
             safety_settings=safety_settings,
@@ -75,19 +66,14 @@ def test_to_dict(monkeypatch):
         top_p=0.5,
         top_k=2,
     )
-    safety_settings = {
-        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
-    }
+    safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
     get_current_weather_func = FunctionDeclaration(
         name="get_current_weather",
         description="Get the current weather in a given location",
         parameters={
             "type_": "OBJECT",
             "properties": {
-                "location": {
-                    "type_": "STRING",
-                    "description": "The city and state, e.g. San Francisco, CA",
-                },
+                "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"},
                 "unit": {
                     "type_": "STRING",
                     "enum": [
@@ -102,9 +88,7 @@ def test_to_dict(monkeypatch):
 
     tool = Tool(function_declarations=[get_current_weather_func])
 
-    with patch(
-        "haystack_integrations.components.generators.google_ai.gemini.genai.configure"
-    ):
+    with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure"):
         gemini = GoogleAIGeminiGenerator(
             generation_config=generation_config,
             safety_settings=safety_settings,
@@ -114,11 +98,7 @@ def test_to_dict(monkeypatch):
         "type": "haystack_integrations.components.generators.google_ai.gemini.GoogleAIGeminiGenerator",
         "init_parameters": {
             "model": "gemini-pro-vision",
-            "api_key": {
-                "env_vars": ["GOOGLE_API_KEY"],
-                "strict": True,
-                "type": "env_var",
-            },
+            "api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"},
             "generation_config": {
                 "temperature": 0.5,
                 "top_p": 0.5,
@@ -140,9 +120,7 @@ def test_from_dict(monkeypatch):
     monkeypatch.setenv("GOOGLE_API_KEY", "test")
 
-    with patch(
-        "haystack_integrations.components.generators.google_ai.gemini.genai.configure"
-    ):
+    with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure"):
         gemini = GoogleAIGeminiGenerator.from_dict(
             {
                 "type": "haystack_integrations.components.generators.google_ai.gemini.GoogleAIGeminiGenerator",
@@ -175,9 +153,7 @@ def test_from_dict(monkeypatch):
         top_p=0.5,
         top_k=0.5,
     )
-    assert gemini._safety_settings == {
-        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
-    }
+    assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
     assert gemini._tools == [
         Tool(
             function_declarations=[
@@ -208,9 +184,7 @@ def test_from_dict(monkeypatch):
 
     assert isinstance(gemini._model, GenerativeModel)
 
 
-@pytest.mark.skipif(
-    not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set"
-)
+@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
 def test_run():
     gemini = GoogleAIGeminiGenerator(model="gemini-pro")
     res = gemini.run("Tell me something cool")
diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/captioner.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/captioner.py
index 943ade1a1..14102eb4b 100644
--- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/captioner.py
+++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/captioner.py
@@ -41,14 +41,7 @@ class VertexAIImageCaptioner:
     ```
     """
 
-    def __init__(
-        self,
-        *,
-        model: str = "imagetext",
-        project_id: str,
-        location: Optional[str] = None,
-        **kwargs
-    ):
+    def __init__(self, *, model: str = "imagetext", project_id: str, location: Optional[str] = None, **kwargs):
         """
         Generate image captions using a Google Vertex AI model.
 
@@ -81,11 +74,7 @@ def to_dict(self) -> Dict[str, Any]:
             Dictionary with serialized data.
""" return default_to_dict( - self, - model=self._model_name, - project_id=self._project_id, - location=self._location, - **self._kwargs + self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs ) @classmethod diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index 778424422..f08a69b5f 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -99,14 +99,10 @@ def _function_to_dict(self, function: FunctionDeclaration) -> Dict[str, Any]: def _tool_to_dict(self, tool: Tool) -> Dict[str, Any]: return { - "function_declarations": [ - self._function_to_dict(f) for f in tool._raw_tool.function_declarations - ], + "function_declarations": [self._function_to_dict(f) for f in tool._raw_tool.function_declarations], } - def _generation_config_to_dict( - self, config: Union[GenerationConfig, Dict[str, Any]] - ) -> Dict[str, Any]: + def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: if isinstance(config, dict): return config return { @@ -136,12 +132,8 @@ def to_dict(self) -> Dict[str, Any]: ) if (tools := data["init_parameters"].get("tools")) is not None: data["init_parameters"]["tools"] = [self._tool_to_dict(t) for t in tools] - if ( - generation_config := data["init_parameters"].get("generation_config") - ) is not None: - data["init_parameters"][ - "generation_config" - ] = self._generation_config_to_dict(generation_config) + if (generation_config := data["init_parameters"].get("generation_config")) is not None: + data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) return data @classmethod @@ -156,12 +148,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiChatGenerator": """ if (tools := data["init_parameters"].get("tools")) is not None: data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools] - if ( - generation_config := data["init_parameters"].get("generation_config") - ) is not None: - data["init_parameters"]["generation_config"] = GenerationConfig.from_dict( - generation_config - ) + if (generation_config := data["init_parameters"].get("generation_config")) is not None: + data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config) return default_from_dict(cls, data) @@ -185,9 +173,7 @@ def _message_to_part(self, message: ChatMessage) -> Part: elif message.role == ChatRole.SYSTEM: return Part.from_text(message.content) elif message.role == ChatRole.FUNCTION: - return Part.from_function_response( - name=message.name, response=message.content - ) + return Part.from_function_response(name=message.name, response=message.content) elif message.role == ChatRole.USER: return self._convert_part(message.content) @@ -199,9 +185,7 @@ def _message_to_content(self, message: ChatMessage) -> Content: elif message.role == ChatRole.SYSTEM: part = Part.from_text(message.content) elif message.role == ChatRole.FUNCTION: - part = Part.from_function_response( - name=message.name, response=message.content - ) + part = Part.from_function_response(name=message.name, response=message.content) elif message.role == ChatRole.USER: part = self._convert_part(message.content) else: diff --git 
a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/code_generator.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/code_generator.py index 42145e6d6..c39c7f88b 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/code_generator.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/code_generator.py @@ -45,14 +45,7 @@ class VertexAICodeGenerator: ``` """ - def __init__( - self, - *, - model: str = "code-bison", - project_id: str, - location: Optional[str] = None, - **kwargs - ): + def __init__(self, *, model: str = "code-bison", project_id: str, location: Optional[str] = None, **kwargs): """ Generate code using a Google Vertex AI model. @@ -84,11 +77,7 @@ def to_dict(self) -> Dict[str, Any]: Dictionary with serialized data. """ return default_to_dict( - self, - model=self._model_name, - project_id=self._project_id, - location=self._location, - **self._kwargs + self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs ) @classmethod @@ -115,9 +104,5 @@ def run(self, prefix: str, suffix: Optional[str] = None): """ res = self._model.predict(prefix=prefix, suffix=suffix, **self._kwargs) # Handle the case where the model returns multiple candidates - replies = ( - [c.text for c in res.candidates] - if hasattr(res, "candidates") - else [res.text] - ) + replies = [c.text for c in res.candidates] if hasattr(res, "candidates") else [res.text] return {"replies": replies} diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py index 36c82b271..8a288a315 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py @@ -110,14 +110,10 @@ def _function_to_dict(self, function: FunctionDeclaration) -> Dict[str, Any]: def _tool_to_dict(self, tool: Tool) -> Dict[str, Any]: return { - "function_declarations": [ - self._function_to_dict(f) for f in tool._raw_tool.function_declarations - ], + "function_declarations": [self._function_to_dict(f) for f in tool._raw_tool.function_declarations], } - def _generation_config_to_dict( - self, config: Union[GenerationConfig, Dict[str, Any]] - ) -> Dict[str, Any]: + def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: if isinstance(config, dict): return config return { @@ -147,12 +143,8 @@ def to_dict(self) -> Dict[str, Any]: ) if (tools := data["init_parameters"].get("tools")) is not None: data["init_parameters"]["tools"] = [self._tool_to_dict(t) for t in tools] - if ( - generation_config := data["init_parameters"].get("generation_config") - ) is not None: - data["init_parameters"][ - "generation_config" - ] = self._generation_config_to_dict(generation_config) + if (generation_config := data["init_parameters"].get("generation_config")) is not None: + data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) return data @classmethod @@ -167,12 +159,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiGenerator": """ if (tools := data["init_parameters"].get("tools")) is not None: data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools] - 
if ( - generation_config := data["init_parameters"].get("generation_config") - ) is not None: - data["init_parameters"]["generation_config"] = GenerationConfig.from_dict( - generation_config - ) + if (generation_config := data["init_parameters"].get("generation_config")) is not None: + data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config) return default_from_dict(cls, data) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py index 9706b2c26..ae8c4892f 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py @@ -30,14 +30,7 @@ class VertexAIImageGenerator: ``` """ - def __init__( - self, - *, - model: str = "imagegeneration", - project_id: str, - location: Optional[str] = None, - **kwargs - ): + def __init__(self, *, model: str = "imagegeneration", project_id: str, location: Optional[str] = None, **kwargs): """ Generates images using a Google Vertex AI model. @@ -69,11 +62,7 @@ def to_dict(self) -> Dict[str, Any]: Dictionary with serialized data. """ return default_to_dict( - self, - model=self._model_name, - project_id=self._project_id, - location=self._location, - **self._kwargs + self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs ) @classmethod @@ -99,11 +88,6 @@ def run(self, prompt: str, negative_prompt: Optional[str] = None): - `images`: A list of ByteStream objects, each containing an image. """ negative_prompt = negative_prompt or self._kwargs.get("negative_prompt") - res = self._model.generate_images( - prompt=prompt, negative_prompt=negative_prompt, **self._kwargs - ) - images = [ - ByteStream(data=i._image_bytes, meta=i.generation_parameters) - for i in res.images - ] + res = self._model.generate_images(prompt=prompt, negative_prompt=negative_prompt, **self._kwargs) + images = [ByteStream(data=i._image_bytes, meta=i.generation_parameters) for i in res.images] return {"images": images} diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/question_answering.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/question_answering.py index 98e926e9f..392a41e00 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/question_answering.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/question_answering.py @@ -35,14 +35,7 @@ class VertexAIImageQA: ``` """ - def __init__( - self, - *, - model: str = "imagetext", - project_id: str, - location: Optional[str] = None, - **kwargs - ): + def __init__(self, *, model: str = "imagetext", project_id: str, location: Optional[str] = None, **kwargs): """ Answers questions about an image using a Google Vertex AI model. @@ -74,11 +67,7 @@ def to_dict(self) -> Dict[str, Any]: Dictionary with serialized data. 
""" return default_to_dict( - self, - model=self._model_name, - project_id=self._project_id, - location=self._location, - **self._kwargs + self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs ) @classmethod @@ -102,7 +91,5 @@ def run(self, image: ByteStream, question: str): :returns: A dictionary with the following keys: - `replies`: A list of answers to the question. """ - replies = self._model.ask_question( - image=Image(image.data), question=question, **self._kwargs - ) + replies = self._model.ask_question(image=Image(image.data), question=question, **self._kwargs) return {"replies": replies} diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/text_generator.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/text_generator.py index 40749b8dc..59061d91c 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/text_generator.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/text_generator.py @@ -45,14 +45,7 @@ class VertexAITextGenerator: ``` """ - def __init__( - self, - *, - model: str = "text-bison", - project_id: str, - location: Optional[str] = None, - **kwargs, - ): + def __init__(self, *, model: str = "text-bison", project_id: str, location: Optional[str] = None, **kwargs): """ Generate text using a Google Vertex AI model. @@ -84,25 +77,13 @@ def to_dict(self) -> Dict[str, Any]: Dictionary with serialized data. """ data = default_to_dict( - self, - model=self._model_name, - project_id=self._project_id, - location=self._location, - **self._kwargs, + self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs ) - if ( - grounding_source := data["init_parameters"].get("grounding_source") - ) is not None: + if (grounding_source := data["init_parameters"].get("grounding_source")) is not None: # Handle the grounding source dataclasses - class_type = ( - f"{grounding_source.__module__}.{grounding_source.__class__.__name__}" - ) - init_fields = { - f.name: getattr(grounding_source, f.name) - for f in fields(grounding_source) - if f.init - } + class_type = f"{grounding_source.__module__}.{grounding_source.__class__.__name__}" + init_fields = {f.name: getattr(grounding_source, f.name) for f in fields(grounding_source) if f.init} data["init_parameters"]["grounding_source"] = { "type": class_type, "init_parameters": init_fields, @@ -120,9 +101,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAITextGenerator": :returns: Deserialized component. """ - if ( - grounding_source := data["init_parameters"].get("grounding_source") - ) is not None: + if (grounding_source := data["init_parameters"].get("grounding_source")) is not None: module_name, class_name = grounding_source["type"].rsplit(".", 1) module = importlib.import_module(module_name) data["init_parameters"]["grounding_source"] = getattr(module, class_name)( @@ -130,11 +109,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAITextGenerator": ) return default_from_dict(cls, data) - @component.output_types( - replies=List[str], - safety_attributes=Dict[str, float], - citations=List[Dict[str, Any]], - ) + @component.output_types(replies=List[str], safety_attributes=Dict[str, float], citations=List[Dict[str, Any]]) def run(self, prompt: str): """Prompts the model to generate text. 
@@ -156,8 +131,4 @@ def run(self, prompt: str):
             safety_attributes.append(prediction["safetyAttributes"])
             citations.append(prediction["citationMetadata"]["citations"])
 
-        return {
-            "replies": replies,
-            "safety_attributes": safety_attributes,
-            "citations": citations,
-        }
+        return {"replies": replies, "safety_attributes": safety_attributes, "citations": citations}
diff --git a/integrations/google_vertex/tests/test_captioner.py b/integrations/google_vertex/tests/test_captioner.py
index 87fe57f3b..26249dbee 100644
--- a/integrations/google_vertex/tests/test_captioner.py
+++ b/integrations/google_vertex/tests/test_captioner.py
@@ -2,25 +2,16 @@
 
 from haystack.dataclasses.byte_stream import ByteStream
 
-from haystack_integrations.components.generators.google_vertex import (
-    VertexAIImageCaptioner,
-)
+from haystack_integrations.components.generators.google_vertex import VertexAIImageCaptioner
 
 
 @patch("haystack_integrations.components.generators.google_vertex.captioner.vertexai")
-@patch(
-    "haystack_integrations.components.generators.google_vertex.captioner.ImageTextModel"
-)
+@patch("haystack_integrations.components.generators.google_vertex.captioner.ImageTextModel")
 def test_init(mock_model_class, mock_vertexai):
     captioner = VertexAIImageCaptioner(
-        model="imagetext",
-        project_id="myproject-123456",
-        number_of_results=1,
-        language="it",
-    )
-    mock_vertexai.init.assert_called_once_with(
-        project="myproject-123456", location=None
+        model="imagetext", project_id="myproject-123456", number_of_results=1, language="it"
     )
+    mock_vertexai.init.assert_called_once_with(project="myproject-123456", location=None)
     mock_model_class.from_pretrained.assert_called_once_with("imagetext")
     assert captioner._model_name == "imagetext"
     assert captioner._project_id == "myproject-123456"
@@ -29,15 +20,10 @@ def test_init(mock_model_class, mock_vertexai):
 
 
 @patch("haystack_integrations.components.generators.google_vertex.captioner.vertexai")
-@patch(
-    "haystack_integrations.components.generators.google_vertex.captioner.ImageTextModel"
-)
+@patch("haystack_integrations.components.generators.google_vertex.captioner.ImageTextModel")
 def test_to_dict(_mock_model_class, _mock_vertexai):
     captioner = VertexAIImageCaptioner(
-        model="imagetext",
-        project_id="myproject-123456",
-        number_of_results=1,
-        language="it",
+        model="imagetext", project_id="myproject-123456", number_of_results=1, language="it"
     )
     assert captioner.to_dict() == {
         "type": "haystack_integrations.components.generators.google_vertex.captioner.VertexAIImageCaptioner",
         "init_parameters": {
@@ -52,9 +38,7 @@ def test_to_dict(_mock_model_class, _mock_vertexai):
 
 
 @patch("haystack_integrations.components.generators.google_vertex.captioner.vertexai")
-@patch(
-    "haystack_integrations.components.generators.google_vertex.captioner.ImageTextModel"
-)
+@patch("haystack_integrations.components.generators.google_vertex.captioner.ImageTextModel")
 def test_from_dict(_mock_model_class, _mock_vertexai):
     captioner = VertexAIImageCaptioner.from_dict(
         {
@@ -75,17 +59,12 @@ def test_from_dict(_mock_model_class, _mock_vertexai):
 
 
 @patch("haystack_integrations.components.generators.google_vertex.captioner.vertexai")
-@patch(
-    "haystack_integrations.components.generators.google_vertex.captioner.ImageTextModel"
-)
+@patch("haystack_integrations.components.generators.google_vertex.captioner.ImageTextModel")
 def test_run_calls_get_captions(mock_model_class, _mock_vertexai):
     mock_model = Mock()
     mock_model_class.from_pretrained.return_value = mock_model
     captioner = VertexAIImageCaptioner(
-        model="imagetext",
-        project_id="myproject-123456",
-        number_of_results=1,
-        language="it",
+        model="imagetext", project_id="myproject-123456", number_of_results=1, language="it"
     )
 
     image = ByteStream(data=b"image data")
diff --git a/integrations/google_vertex/tests/test_code_generator.py b/integrations/google_vertex/tests/test_code_generator.py
index ebae42fd7..129954062 100644
--- a/integrations/google_vertex/tests/test_code_generator.py
+++ b/integrations/google_vertex/tests/test_code_generator.py
@@ -2,27 +2,16 @@
 
 from vertexai.language_models import TextGenerationResponse
 
-from haystack_integrations.components.generators.google_vertex import (
-    VertexAICodeGenerator,
-)
+from haystack_integrations.components.generators.google_vertex import VertexAICodeGenerator
 
 
-@patch(
-    "haystack_integrations.components.generators.google_vertex.code_generator.vertexai"
-)
-@patch(
-    "haystack_integrations.components.generators.google_vertex.code_generator.CodeGenerationModel"
-)
+@patch("haystack_integrations.components.generators.google_vertex.code_generator.vertexai")
+@patch("haystack_integrations.components.generators.google_vertex.code_generator.CodeGenerationModel")
 def test_init(mock_model_class, mock_vertexai):
     generator = VertexAICodeGenerator(
-        model="code-bison",
-        project_id="myproject-123456",
-        candidate_count=3,
-        temperature=0.5,
-    )
-    mock_vertexai.init.assert_called_once_with(
-        project="myproject-123456", location=None
+        model="code-bison", project_id="myproject-123456", candidate_count=3, temperature=0.5
     )
+    mock_vertexai.init.assert_called_once_with(project="myproject-123456", location=None)
     mock_model_class.from_pretrained.assert_called_once_with("code-bison")
     assert generator._model_name == "code-bison"
     assert generator._project_id == "myproject-123456"
@@ -30,18 +19,11 @@ def test_init(mock_model_class, mock_vertexai):
     assert generator._kwargs == {"candidate_count": 3, "temperature": 0.5}
 
 
-@patch(
-    "haystack_integrations.components.generators.google_vertex.code_generator.vertexai"
-)
-@patch(
-    "haystack_integrations.components.generators.google_vertex.code_generator.CodeGenerationModel"
-)
+@patch("haystack_integrations.components.generators.google_vertex.code_generator.vertexai")
+@patch("haystack_integrations.components.generators.google_vertex.code_generator.CodeGenerationModel")
 def test_to_dict(_mock_model_class, _mock_vertexai):
     generator = VertexAICodeGenerator(
-        model="code-bison",
-        project_id="myproject-123456",
-        candidate_count=3,
-        temperature=0.5,
+        model="code-bison", project_id="myproject-123456", candidate_count=3, temperature=0.5
     )
     assert generator.to_dict() == {
         "type": "haystack_integrations.components.generators.google_vertex.code_generator.VertexAICodeGenerator",
@@ -55,12 +37,8 @@ def test_to_dict(_mock_model_class, _mock_vertexai):
 }
 
 
-@patch(
-    "haystack_integrations.components.generators.google_vertex.code_generator.vertexai"
-)
-@patch(
-    "haystack_integrations.components.generators.google_vertex.code_generator.CodeGenerationModel"
-)
+@patch("haystack_integrations.components.generators.google_vertex.code_generator.vertexai")
+@patch("haystack_integrations.components.generators.google_vertex.code_generator.CodeGenerationModel")
 def test_from_dict(_mock_model_class, _mock_vertexai):
     generator = VertexAICodeGenerator.from_dict(
         {
@@ -80,21 +58,14 @@ def test_from_dict(_mock_model_class, _mock_vertexai):
     assert generator._model is not None
 
 
-@patch(
-    "haystack_integrations.components.generators.google_vertex.code_generator.vertexai"
-)
-@patch(
-    "haystack_integrations.components.generators.google_vertex.code_generator.CodeGenerationModel"
-)
+@patch("haystack_integrations.components.generators.google_vertex.code_generator.vertexai")
+@patch("haystack_integrations.components.generators.google_vertex.code_generator.CodeGenerationModel")
 def test_run_calls_predict(mock_model_class, _mock_vertexai):
     mock_model = Mock()
     mock_model.predict.return_value = TextGenerationResponse("answer", None)
     mock_model_class.from_pretrained.return_value = mock_model
     generator = VertexAICodeGenerator(
-        model="code-bison",
-        project_id="myproject-123456",
-        candidate_count=1,
-        temperature=0.5,
+        model="code-bison", project_id="myproject-123456", candidate_count=1, temperature=0.5
     )
 
     prefix = "def print_json(data):\n"
diff --git a/integrations/google_vertex/tests/test_image_generator.py b/integrations/google_vertex/tests/test_image_generator.py
index e196f68db..42cc0a0a3 100644
--- a/integrations/google_vertex/tests/test_image_generator.py
+++ b/integrations/google_vertex/tests/test_image_generator.py
@@ -2,17 +2,11 @@
 
 from vertexai.preview.vision_models import ImageGenerationResponse
 
-from haystack_integrations.components.generators.google_vertex import (
-    VertexAIImageGenerator,
-)
+from haystack_integrations.components.generators.google_vertex import VertexAIImageGenerator
 
 
-@patch(
-    "haystack_integrations.components.generators.google_vertex.image_generator.vertexai"
-)
-@patch(
-    "haystack_integrations.components.generators.google_vertex.image_generator.ImageGenerationModel"
-)
+@patch("haystack_integrations.components.generators.google_vertex.image_generator.vertexai")
+@patch("haystack_integrations.components.generators.google_vertex.image_generator.ImageGenerationModel")
 def test_init(mock_model_class, mock_vertexai):
     generator = VertexAIImageGenerator(
         model="imagetext",
@@ -20,9 +14,7 @@ def test_init(mock_model_class, mock_vertexai):
         guidance_scale=12,
         number_of_images=3,
     )
-    mock_vertexai.init.assert_called_once_with(
-        project="myproject-123456", location=None
-    )
+    mock_vertexai.init.assert_called_once_with(project="myproject-123456", location=None)
     mock_model_class.from_pretrained.assert_called_once_with("imagetext")
     assert generator._model_name == "imagetext"
     assert generator._project_id == "myproject-123456"
@@ -33,12 +25,8 @@ def test_init(mock_model_class, mock_vertexai):
     }
 
 
-@patch(
-    "haystack_integrations.components.generators.google_vertex.image_generator.vertexai"
-)
-@patch(
-    "haystack_integrations.components.generators.google_vertex.image_generator.ImageGenerationModel"
-)
+@patch("haystack_integrations.components.generators.google_vertex.image_generator.vertexai")
+@patch("haystack_integrations.components.generators.google_vertex.image_generator.ImageGenerationModel")
 def test_to_dict(_mock_model_class, _mock_vertexai):
     generator = VertexAIImageGenerator(
         model="imagetext",
@@ -58,12 +46,8 @@ def test_to_dict(_mock_model_class, _mock_vertexai):
     }
 
 
-@patch(
-    "haystack_integrations.components.generators.google_vertex.image_generator.vertexai"
-)
-@patch(
-    "haystack_integrations.components.generators.google_vertex.image_generator.ImageGenerationModel"
-)
+@patch("haystack_integrations.components.generators.google_vertex.image_generator.vertexai")
+@patch("haystack_integrations.components.generators.google_vertex.image_generator.ImageGenerationModel")
 def test_from_dict(_mock_model_class, _mock_vertexai):
     generator = VertexAIImageGenerator.from_dict(
         {
@@ -86,12 +70,8 @@ def test_from_dict(_mock_model_class, _mock_vertexai):
     }
 
 
-@patch(
-    "haystack_integrations.components.generators.google_vertex.image_generator.vertexai"
-)
-@patch(
-    "haystack_integrations.components.generators.google_vertex.image_generator.ImageGenerationModel"
-)
+@patch("haystack_integrations.components.generators.google_vertex.image_generator.vertexai")
+@patch("haystack_integrations.components.generators.google_vertex.image_generator.ImageGenerationModel")
 def test_run_calls_generate_images(mock_model_class, _mock_vertexai):
     mock_model = Mock()
     mock_model.generate_images.return_value = ImageGenerationResponse(images=[])
@@ -108,8 +88,5 @@ def test_run_calls_generate_images(mock_model_class, _mock_vertexai):
     generator.run(prompt=prompt, negative_prompt=negative_prompt)
 
     mock_model.generate_images.assert_called_once_with(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        guidance_scale=12,
-        number_of_images=3,
+        prompt=prompt, negative_prompt=negative_prompt, guidance_scale=12, number_of_images=3
     )
diff --git a/integrations/google_vertex/tests/test_question_answering.py b/integrations/google_vertex/tests/test_question_answering.py
index e86bc71bc..3f414f0e0 100644
--- a/integrations/google_vertex/tests/test_question_answering.py
+++ b/integrations/google_vertex/tests/test_question_answering.py
@@ -5,21 +5,15 @@
 from haystack_integrations.components.generators.google_vertex import VertexAIImageQA
 
 
-@patch(
-    "haystack_integrations.components.generators.google_vertex.question_answering.vertexai"
-)
-@patch(
-    "haystack_integrations.components.generators.google_vertex.question_answering.ImageTextModel"
-)
+@patch("haystack_integrations.components.generators.google_vertex.question_answering.vertexai")
+@patch("haystack_integrations.components.generators.google_vertex.question_answering.ImageTextModel")
 def test_init(mock_model_class, mock_vertexai):
     generator = VertexAIImageQA(
         model="imagetext",
         project_id="myproject-123456",
         number_of_results=3,
     )
-    mock_vertexai.init.assert_called_once_with(
-        project="myproject-123456", location=None
-    )
+    mock_vertexai.init.assert_called_once_with(project="myproject-123456", location=None)
     mock_model_class.from_pretrained.assert_called_once_with("imagetext")
     assert generator._model_name == "imagetext"
     assert generator._project_id == "myproject-123456"
@@ -27,12 +21,8 @@ def test_init(mock_model_class, mock_vertexai):
     assert generator._kwargs == {"number_of_results": 3}
 
 
-@patch(
-    "haystack_integrations.components.generators.google_vertex.question_answering.vertexai"
-)
-@patch(
-    "haystack_integrations.components.generators.google_vertex.question_answering.ImageTextModel"
-)
+@patch("haystack_integrations.components.generators.google_vertex.question_answering.vertexai")
+@patch("haystack_integrations.components.generators.google_vertex.question_answering.ImageTextModel")
 def test_to_dict(_mock_model_class, _mock_vertexai):
     generator = VertexAIImageQA(
         model="imagetext",
@@ -50,12 +40,8 @@ def test_to_dict(_mock_model_class, _mock_vertexai):
     }
 
 
-@patch(
-    "haystack_integrations.components.generators.google_vertex.question_answering.vertexai"
-)
-@patch(
-    "haystack_integrations.components.generators.google_vertex.question_answering.ImageTextModel"
-)
+@patch("haystack_integrations.components.generators.google_vertex.question_answering.vertexai")
+@patch("haystack_integrations.components.generators.google_vertex.question_answering.ImageTextModel")
 def test_from_dict(_mock_model_class, _mock_vertexai):
     generator = VertexAIImageQA.from_dict(
         {
@@ -74,12 +60,8 @@ def test_from_dict(_mock_model_class, _mock_vertexai):
     assert generator._kwargs == {"number_of_results": 3}
 
 
-@patch(
-    "haystack_integrations.components.generators.google_vertex.question_answering.vertexai"
-)
-@patch(
-    "haystack_integrations.components.generators.google_vertex.question_answering.ImageTextModel"
-)
+@patch("haystack_integrations.components.generators.google_vertex.question_answering.vertexai")
+@patch("haystack_integrations.components.generators.google_vertex.question_answering.ImageTextModel")
 def test_run_calls_ask_question(mock_model_class, _mock_vertexai):
     mock_model = Mock()
     mock_model.ask_question.return_value = []
diff --git a/integrations/google_vertex/tests/test_text_generator.py b/integrations/google_vertex/tests/test_text_generator.py
index a54072456..3e5248dc7 100644
--- a/integrations/google_vertex/tests/test_text_generator.py
+++ b/integrations/google_vertex/tests/test_text_generator.py
@@ -2,51 +2,30 @@
 
 from vertexai.language_models import GroundingSource
 
-from haystack_integrations.components.generators.google_vertex import (
-    VertexAITextGenerator,
-)
+from haystack_integrations.components.generators.google_vertex import VertexAITextGenerator
 
 
-@patch(
-    "haystack_integrations.components.generators.google_vertex.text_generator.vertexai"
-)
-@patch(
-    "haystack_integrations.components.generators.google_vertex.text_generator.TextGenerationModel"
-)
+@patch("haystack_integrations.components.generators.google_vertex.text_generator.vertexai")
+@patch("haystack_integrations.components.generators.google_vertex.text_generator.TextGenerationModel")
 def test_init(mock_model_class, mock_vertexai):
     grounding_source = GroundingSource.VertexAISearch("1234", "us-central-1")
     generator = VertexAITextGenerator(
-        model="text-bison",
-        project_id="myproject-123456",
-        temperature=0.2,
-        grounding_source=grounding_source,
-    )
-    mock_vertexai.init.assert_called_once_with(
-        project="myproject-123456", location=None
+        model="text-bison", project_id="myproject-123456", temperature=0.2, grounding_source=grounding_source
     )
+    mock_vertexai.init.assert_called_once_with(project="myproject-123456", location=None)
     mock_model_class.from_pretrained.assert_called_once_with("text-bison")
     assert generator._model_name == "text-bison"
     assert generator._project_id == "myproject-123456"
     assert generator._location is None
-    assert generator._kwargs == {
-        "temperature": 0.2,
-        "grounding_source": grounding_source,
-    }
+    assert generator._kwargs == {"temperature": 0.2, "grounding_source": grounding_source}
 
 
-@patch(
-    "haystack_integrations.components.generators.google_vertex.text_generator.vertexai"
-)
-@patch(
-    "haystack_integrations.components.generators.google_vertex.text_generator.TextGenerationModel"
-)
+@patch("haystack_integrations.components.generators.google_vertex.text_generator.vertexai")
+@patch("haystack_integrations.components.generators.google_vertex.text_generator.TextGenerationModel")
 def test_to_dict(_mock_model_class, _mock_vertexai):
     grounding_source = GroundingSource.VertexAISearch("1234", "us-central-1")
     generator = VertexAITextGenerator(
-        model="text-bison",
-        project_id="myproject-123456",
-        temperature=0.2,
-        grounding_source=grounding_source,
+        model="text-bison", project_id="myproject-123456", temperature=0.2, grounding_source=grounding_source
     )
     assert generator.to_dict() == {
         "type": "haystack_integrations.components.generators.google_vertex.text_generator.VertexAITextGenerator",
         "init_parameters": {
@@ -68,12 +47,8 @@ def test_to_dict(_mock_model_class, _mock_vertexai):
     }
 
 
-@patch(
-    "haystack_integrations.components.generators.google_vertex.text_generator.vertexai"
-)
-@patch(
-    "haystack_integrations.components.generators.google_vertex.text_generator.TextGenerationModel"
-)
+@patch("haystack_integrations.components.generators.google_vertex.text_generator.vertexai")
+@patch("haystack_integrations.components.generators.google_vertex.text_generator.TextGenerationModel")
 def test_from_dict(_mock_model_class, _mock_vertexai):
     generator = VertexAITextGenerator.from_dict(
         {
@@ -104,27 +79,18 @@ def test_from_dict(_mock_model_class, _mock_vertexai):
     }
 
 
-@patch(
-    "haystack_integrations.components.generators.google_vertex.text_generator.vertexai"
-)
-@patch(
-    "haystack_integrations.components.generators.google_vertex.text_generator.TextGenerationModel"
-)
+@patch("haystack_integrations.components.generators.google_vertex.text_generator.vertexai")
+@patch("haystack_integrations.components.generators.google_vertex.text_generator.TextGenerationModel")
 def test_run_calls_get_captions(mock_model_class, _mock_vertexai):
     mock_model = Mock()
     mock_model.predict.return_value = MagicMock()
     mock_model_class.from_pretrained.return_value = mock_model
     grounding_source = GroundingSource.VertexAISearch("1234", "us-central-1")
     generator = VertexAITextGenerator(
-        model="text-bison",
-        project_id="myproject-123456",
-        temperature=0.2,
-        grounding_source=grounding_source,
+        model="text-bison", project_id="myproject-123456", temperature=0.2, grounding_source=grounding_source
     )
 
     prompt = "What is the answer?"
     generator.run(prompt=prompt)
-    mock_model.predict.assert_called_once_with(
-        prompt=prompt, temperature=0.2, grounding_source=grounding_source
-    )
+    mock_model.predict.assert_called_once_with(prompt=prompt, temperature=0.2, grounding_source=grounding_source)
diff --git a/integrations/gradient/src/haystack_integrations/components/embedders/gradient/gradient_document_embedder.py b/integrations/gradient/src/haystack_integrations/components/embedders/gradient/gradient_document_embedder.py
index 15ab47618..a868c6c1b 100644
--- a/integrations/gradient/src/haystack_integrations/components/embedders/gradient/gradient_document_embedder.py
+++ b/integrations/gradient/src/haystack_integrations/components/embedders/gradient/gradient_document_embedder.py
@@ -57,12 +57,8 @@ def __init__(
         *,
         model: str = "bge-large",
         batch_size: int = 32_768,
-        access_token: Secret = Secret.from_env_var(
-            "GRADIENT_ACCESS_TOKEN"
-        ),  # noqa: B008
-        workspace_id: Secret = Secret.from_env_var(
-            "GRADIENT_WORKSPACE_ID"
-        ),  # noqa: B008
+        access_token: Secret = Secret.from_env_var("GRADIENT_ACCESS_TOKEN"),  # noqa: B008
+        workspace_id: Secret = Secret.from_env_var("GRADIENT_WORKSPACE_ID"),  # noqa: B008
         host: Optional[str] = None,
         progress_bar: bool = True,
     ) -> None:
@@ -84,9 +80,7 @@ def __init__(
         self._workspace_id = workspace_id
 
         self._gradient = Gradient(
-            access_token=access_token.resolve_value(),
-            workspace_id=workspace_id.resolve_value(),
-            host=host,
+            access_token=access_token.resolve_value(), workspace_id=workspace_id.resolve_value(), host=host
         )
 
     def _get_telemetry_data(self) -> Dict[str, Any]:
@@ -122,9 +116,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GradientDocumentEmbedder":
         :returns:
             The deserialized component instance.
""" - deserialize_secrets_inplace( - data["init_parameters"], keys=["access_token", "workspace_id"] - ) + deserialize_secrets_inplace(data["init_parameters"], keys=["access_token", "workspace_id"]) return default_from_dict(cls, data) def warm_up(self) -> None: @@ -132,21 +124,14 @@ def warm_up(self) -> None: Initializes the component. """ if not hasattr(self, "_embedding_model"): - self._embedding_model = self._gradient.get_embeddings_model( - slug=self._model_name - ) + self._embedding_model = self._gradient.get_embeddings_model(slug=self._model_name) - def _generate_embeddings( - self, documents: List[Document], batch_size: int - ) -> List[List[float]]: + def _generate_embeddings(self, documents: List[Document], batch_size: int) -> List[List[float]]: """ Batches the documents and generates the embeddings. """ if self._progress_bar and tqdm_imported: - batches = [ - documents[i : i + batch_size] - for i in range(0, len(documents), batch_size) - ] + batches = [documents[i : i + batch_size] for i in range(0, len(documents), batch_size)] progress_bar = tqdm else: # no progress bar @@ -155,9 +140,7 @@ def _generate_embeddings( embeddings = [] for batch in progress_bar(batches): - response = self._embedding_model.embed( - inputs=[{"input": doc.content} for doc in batch] - ) + response = self._embedding_model.embed(inputs=[{"input": doc.content} for doc in batch]) embeddings.extend([e.embedding for e in response.embeddings]) return embeddings @@ -175,11 +158,7 @@ def run(self, documents: List[Document]): - `documents`: The embedded Documents. """ - if ( - not isinstance(documents, list) - or documents - and any(not isinstance(doc, Document) for doc in documents) - ): + if not isinstance(documents, list) or documents and any(not isinstance(doc, Document) for doc in documents): msg = "GradientDocumentEmbedder expects a list of Documents as input.\ In case you want to embed a list of strings, please use the GradientTextEmbedder." raise TypeError(msg) @@ -188,9 +167,7 @@ def run(self, documents: List[Document]): msg = "The embedding model has not been loaded. Please call warm_up() before running." 
             raise RuntimeError(msg)
 
-        embeddings = self._generate_embeddings(
-            documents=documents, batch_size=self._batch_size
-        )
+        embeddings = self._generate_embeddings(documents=documents, batch_size=self._batch_size)
         for doc, embedding in zip(documents, embeddings):
             doc.embedding = embedding
diff --git a/integrations/gradient/src/haystack_integrations/components/embedders/gradient/gradient_text_embedder.py b/integrations/gradient/src/haystack_integrations/components/embedders/gradient/gradient_text_embedder.py
index fd27c686f..3bcbb4db6 100644
--- a/integrations/gradient/src/haystack_integrations/components/embedders/gradient/gradient_text_embedder.py
+++ b/integrations/gradient/src/haystack_integrations/components/embedders/gradient/gradient_text_embedder.py
@@ -31,12 +31,8 @@ def __init__(
         self,
         *,
         model: str = "bge-large",
-        access_token: Secret = Secret.from_env_var(
-            "GRADIENT_ACCESS_TOKEN"
-        ),  # noqa: B008
-        workspace_id: Secret = Secret.from_env_var(
-            "GRADIENT_WORKSPACE_ID"
-        ),  # noqa: B008
+        access_token: Secret = Secret.from_env_var("GRADIENT_ACCESS_TOKEN"),  # noqa: B008
+        workspace_id: Secret = Secret.from_env_var("GRADIENT_WORKSPACE_ID"),  # noqa: B008
         host: Optional[str] = None,
     ) -> None:
         """
@@ -53,9 +49,7 @@ def __init__(
         self._workspace_id = workspace_id
 
         self._gradient = Gradient(
-            host=host,
-            access_token=access_token.resolve_value(),
-            workspace_id=workspace_id.resolve_value(),
+            host=host, access_token=access_token.resolve_value(), workspace_id=workspace_id.resolve_value()
         )
 
     def _get_telemetry_data(self) -> Dict[str, Any]:
@@ -88,9 +82,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GradientTextEmbedder":
         :returns:
             The deserialized component instance.
         """
-        deserialize_secrets_inplace(
-            data["init_parameters"], keys=["access_token", "workspace_id"]
-        )
+        deserialize_secrets_inplace(data["init_parameters"], keys=["access_token", "workspace_id"])
         return default_from_dict(cls, data)
 
     def warm_up(self) -> None:
@@ -98,9 +90,7 @@ def warm_up(self) -> None:
         Initializes the component.
         """
         if not hasattr(self, "_embedding_model"):
-            self._embedding_model = self._gradient.get_embeddings_model(
-                slug=self._model_name
-            )
+            self._embedding_model = self._gradient.get_embeddings_model(slug=self._model_name)
 
     @component.output_types(embedding=List[float])
     def run(self, text: str):
diff --git a/integrations/gradient/src/haystack_integrations/components/generators/gradient/base.py b/integrations/gradient/src/haystack_integrations/components/generators/gradient/base.py
index e05e87ece..71b39d309 100644
--- a/integrations/gradient/src/haystack_integrations/components/generators/gradient/base.py
+++ b/integrations/gradient/src/haystack_integrations/components/generators/gradient/base.py
@@ -30,9 +30,7 @@ class GradientGenerator:
     def __init__(
         self,
         *,
-        access_token: Secret = Secret.from_env_var(
-            "GRADIENT_ACCESS_TOKEN"
-        ),  # noqa: B008
+        access_token: Secret = Secret.from_env_var("GRADIENT_ACCESS_TOKEN"),  # noqa: B008
         base_model_slug: Optional[str] = None,
         host: Optional[str] = None,
         max_generated_token_count: Optional[int] = None,
@@ -40,9 +38,7 @@ def __init__(
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
-        workspace_id: Secret = Secret.from_env_var(
-            "GRADIENT_WORKSPACE_ID"
-        ),  # noqa: B008
+        workspace_id: Secret = Secret.from_env_var("GRADIENT_WORKSPACE_ID"),  # noqa: B008
     ) -> None:
         """
         Create a GradientGenerator component.
@@ -85,9 +81,7 @@ def __init__( self._model_adapter_id = model_adapter_id self._gradient = Gradient( - access_token=access_token.resolve_value(), - host=host, - workspace_id=workspace_id.resolve_value(), + access_token=access_token.resolve_value(), host=host, workspace_id=workspace_id.resolve_value() ) def to_dict(self) -> Dict[str, Any]: @@ -120,9 +114,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GradientGenerator": The deserialized component instance. """ - deserialize_secrets_inplace( - data["init_parameters"], keys=["access_token", "workspace_id"] - ) + deserialize_secrets_inplace(data["init_parameters"], keys=["access_token", "workspace_id"]) return default_from_dict(cls, data) def warm_up(self): @@ -131,13 +123,9 @@ def warm_up(self): """ if not hasattr(self, "_model"): if isinstance(self._base_model_slug, str): - self._model = self._gradient.get_base_model( - base_model_slug=self._base_model_slug - ) + self._model = self._gradient.get_base_model(base_model_slug=self._base_model_slug) if isinstance(self._model_adapter_id, str): - self._model = self._gradient.get_model_adapter( - model_adapter_id=self._model_adapter_id - ) + self._model = self._gradient.get_model_adapter(model_adapter_id=self._model_adapter_id) @component.output_types(replies=List[str]) def run(self, prompt: str): diff --git a/integrations/gradient/tests/test_gradient_document_embedder.py b/integrations/gradient/tests/test_gradient_document_embedder.py index 14bcc5bda..3bc739f3e 100644 --- a/integrations/gradient/tests/test_gradient_document_embedder.py +++ b/integrations/gradient/tests/test_gradient_document_embedder.py @@ -2,9 +2,7 @@ import numpy as np import pytest -from gradientai.openapi.client.models.generate_embedding_success import ( - GenerateEmbeddingSuccess, -) +from gradientai.openapi.client.models.generate_embedding_success import GenerateEmbeddingSuccess from haystack import Document from haystack.utils import Secret @@ -23,6 +21,7 @@ def tokens_from_env(monkeypatch): class TestGradientDocumentEmbedder: def test_init_from_env(self, tokens_from_env): + embedder = GradientDocumentEmbedder() assert embedder is not None assert embedder._gradient.workspace_id == workspace_id @@ -42,8 +41,7 @@ def test_init_without_workspace(self, monkeypatch): def test_init_from_params(self): embedder = GradientDocumentEmbedder( - access_token=Secret.from_token(access_token), - workspace_id=Secret.from_token(workspace_id), + access_token=Secret.from_token(access_token), workspace_id=Secret.from_token(workspace_id) ) assert embedder is not None assert embedder._gradient.workspace_id == workspace_id @@ -54,8 +52,7 @@ def test_init_from_params_precedence(self, monkeypatch): monkeypatch.setenv("GRADIENT_WORKSPACE_ID", "env_workspace_id") embedder = GradientDocumentEmbedder( - access_token=Secret.from_token(access_token), - workspace_id=Secret.from_token(workspace_id), + access_token=Secret.from_token(access_token), workspace_id=Secret.from_token(workspace_id) ) assert embedder is not None assert embedder._gradient.workspace_id == workspace_id @@ -68,20 +65,12 @@ def test_to_dict(self, tokens_from_env): assert data == { "type": t, "init_parameters": { - "access_token": { - "env_vars": ["GRADIENT_ACCESS_TOKEN"], - "strict": True, - "type": "env_var", - }, + "access_token": {"env_vars": ["GRADIENT_ACCESS_TOKEN"], "strict": True, "type": "env_var"}, "batch_size": 32768, "host": None, "model": "bge-large", "progress_bar": True, - "workspace_id": { - "env_vars": ["GRADIENT_WORKSPACE_ID"], - "strict": True, - "type": "env_var", - 
}, + "workspace_id": {"env_vars": ["GRADIENT_WORKSPACE_ID"], "strict": True, "type": "env_var"}, }, } @@ -89,37 +78,26 @@ def test_warmup(self, tokens_from_env): embedder = GradientDocumentEmbedder() embedder._gradient.get_embeddings_model = MagicMock() embedder.warm_up() - embedder._gradient.get_embeddings_model.assert_called_once_with( - slug="bge-large" - ) + embedder._gradient.get_embeddings_model.assert_called_once_with(slug="bge-large") def test_warmup_doesnt_reload(self, tokens_from_env): embedder = GradientDocumentEmbedder() - embedder._gradient.get_embeddings_model = MagicMock( - default_return_value="fake model" - ) + embedder._gradient.get_embeddings_model = MagicMock(default_return_value="fake model") embedder.warm_up() embedder.warm_up() - embedder._gradient.get_embeddings_model.assert_called_once_with( - slug="bge-large" - ) + embedder._gradient.get_embeddings_model.assert_called_once_with(slug="bge-large") def test_run_fail_if_not_warmed_up(self, tokens_from_env): embedder = GradientDocumentEmbedder() with pytest.raises(RuntimeError, match="warm_up()"): - embedder.run( - documents=[Document(content=f"document number {i}") for i in range(5)] - ) + embedder.run(documents=[Document(content=f"document number {i}") for i in range(5)]) def test_run(self, tokens_from_env): embedder = GradientDocumentEmbedder() embedder._embedding_model = NonCallableMagicMock() embedder._embedding_model.embed.return_value = GenerateEmbeddingSuccess( - embeddings=[ - {"embedding": np.random.rand(1024).tolist(), "index": i} - for i in range(5) - ] + embeddings=[{"embedding": np.random.rand(1024).tolist(), "index": i} for i in range(5)] ) documents = [Document(content=f"document number {i}") for i in range(5)] @@ -139,10 +117,7 @@ def test_run_batch(self, tokens_from_env): embedder._embedding_model = NonCallableMagicMock() embedder._embedding_model.embed.return_value = GenerateEmbeddingSuccess( - embeddings=[ - {"embedding": np.random.rand(1024).tolist(), "index": i} - for i in range(110) - ] + embeddings=[{"embedding": np.random.rand(1024).tolist(), "index": i} for i in range(110)] ) documents = [Document(content=f"document number {i}") for i in range(110)] @@ -163,15 +138,10 @@ def test_run_custom_batch(self, tokens_from_env): document_count = 101 embedder._embedding_model.embed.return_value = GenerateEmbeddingSuccess( - embeddings=[ - {"embedding": np.random.rand(1024).tolist(), "index": i} - for i in range(document_count) - ] + embeddings=[{"embedding": np.random.rand(1024).tolist(), "index": i} for i in range(document_count)] ) - documents = [ - Document(content=f"document number {i}") for i in range(document_count) - ] + documents = [Document(content=f"document number {i}") for i in range(document_count)] result = embedder.run(documents=documents) diff --git a/integrations/gradient/tests/test_gradient_rag_pipelines.py b/integrations/gradient/tests/test_gradient_rag_pipelines.py index d29067376..89ec7cfb2 100644 --- a/integrations/gradient/tests/test_gradient_rag_pipelines.py +++ b/integrations/gradient/tests/test_gradient_rag_pipelines.py @@ -9,17 +9,13 @@ from haystack.components.writers import DocumentWriter from haystack.document_stores.in_memory import InMemoryDocumentStore -from haystack_integrations.components.embedders.gradient import ( - GradientDocumentEmbedder, - GradientTextEmbedder, -) +from haystack_integrations.components.embedders.gradient import GradientDocumentEmbedder, GradientTextEmbedder from haystack_integrations.components.generators.gradient import GradientGenerator 
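Note (not part of the patch): the to_dict/from_dict tests above pin down the Secret serialization contract. A compact sketch of that round trip, assuming the two GRADIENT_* env vars are set so the constructor can resolve them:

from haystack_integrations.components.embedders.gradient import GradientDocumentEmbedder

embedder = GradientDocumentEmbedder()  # credentials resolve from GRADIENT_ACCESS_TOKEN / GRADIENT_WORKSPACE_ID
data = embedder.to_dict()
# Secrets serialize as env-var markers, never as raw token values
assert data["init_parameters"]["access_token"] == {"env_vars": ["GRADIENT_ACCESS_TOKEN"], "strict": True, "type": "env_var"}
restored = GradientDocumentEmbedder.from_dict(data)  # deserialize_secrets_inplace rebuilds the Secret objects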
@pytest.mark.integration @pytest.mark.skipif( - not os.environ.get("GRADIENT_ACCESS_TOKEN", None) - or not os.environ.get("GRADIENT_WORKSPACE_ID", None), + not os.environ.get("GRADIENT_ACCESS_TOKEN", None) or not os.environ.get("GRADIENT_WORKSPACE_ID", None), reason="Export env variables called GRADIENT_ACCESS_TOKEN and GRADIENT_WORKSPACE_ID \ containing the Gradient configuration settings to run this test.", ) @@ -38,15 +34,10 @@ def test_gradient_embedding_retrieval_rag_pipeline(tmp_path): embedder = GradientTextEmbedder() rag_pipeline.add_component(instance=embedder, name="text_embedder") rag_pipeline.add_component( - instance=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()), - name="retriever", - ) - rag_pipeline.add_component( - instance=PromptBuilder(template=prompt_template), name="prompt_builder" - ) - rag_pipeline.add_component( - instance=GradientGenerator(base_model_slug="llama2-7b-chat"), name="llm" + instance=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()), name="retriever" ) + rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder") + rag_pipeline.add_component(instance=GradientGenerator(base_model_slug="llama2-7b-chat"), name="llm") rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder") rag_pipeline.connect("text_embedder", "retriever") rag_pipeline.connect("retriever", "prompt_builder.documents") @@ -73,12 +64,8 @@ def test_gradient_embedding_retrieval_rag_pipeline(tmp_path): ] document_store = rag_pipeline.get_component("retriever").document_store indexing_pipeline = Pipeline() - indexing_pipeline.add_component( - instance=GradientDocumentEmbedder(), name="document_embedder" - ) - indexing_pipeline.add_component( - instance=DocumentWriter(document_store=document_store), name="document_writer" - ) + indexing_pipeline.add_component(instance=GradientDocumentEmbedder(), name="document_embedder") + indexing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="document_writer") indexing_pipeline.connect("document_embedder", "document_writer") indexing_pipeline.run({"document_embedder": {"documents": documents}}) diff --git a/integrations/gradient/tests/test_gradient_text_embedder.py b/integrations/gradient/tests/test_gradient_text_embedder.py index 4e7458f4d..b12587994 100644 --- a/integrations/gradient/tests/test_gradient_text_embedder.py +++ b/integrations/gradient/tests/test_gradient_text_embedder.py @@ -2,9 +2,7 @@ import numpy as np import pytest -from gradientai.openapi.client.models.generate_embedding_success import ( - GenerateEmbeddingSuccess, -) +from gradientai.openapi.client.models.generate_embedding_success import GenerateEmbeddingSuccess from haystack.utils import Secret from haystack_integrations.components.embedders.gradient import GradientTextEmbedder @@ -41,8 +39,7 @@ def test_init_without_workspace(self, monkeypatch): def test_init_from_params(self): embedder = GradientTextEmbedder( - access_token=Secret.from_token(access_token), - workspace_id=Secret.from_token(workspace_id), + access_token=Secret.from_token(access_token), workspace_id=Secret.from_token(workspace_id) ) assert embedder is not None assert embedder._gradient.workspace_id == workspace_id @@ -53,8 +50,7 @@ def test_init_from_params_precedence(self, monkeypatch): monkeypatch.setenv("GRADIENT_WORKSPACE_ID", "env_workspace_id") embedder = GradientTextEmbedder( - access_token=Secret.from_token(access_token), - workspace_id=Secret.from_token(workspace_id), + 
access_token=Secret.from_token(access_token), workspace_id=Secret.from_token(workspace_id) ) assert embedder is not None assert embedder._gradient.workspace_id == workspace_id @@ -66,18 +62,10 @@ def test_to_dict(self, tokens_from_env): assert data == { "type": "haystack_integrations.components.embedders.gradient.gradient_text_embedder.GradientTextEmbedder", "init_parameters": { - "access_token": { - "env_vars": ["GRADIENT_ACCESS_TOKEN"], - "strict": True, - "type": "env_var", - }, + "access_token": {"env_vars": ["GRADIENT_ACCESS_TOKEN"], "strict": True, "type": "env_var"}, "host": None, "model": "bge-large", - "workspace_id": { - "env_vars": ["GRADIENT_WORKSPACE_ID"], - "strict": True, - "type": "env_var", - }, + "workspace_id": {"env_vars": ["GRADIENT_WORKSPACE_ID"], "strict": True, "type": "env_var"}, }, } @@ -85,20 +73,14 @@ def test_warmup(self, tokens_from_env): embedder = GradientTextEmbedder() embedder._gradient.get_embeddings_model = MagicMock() embedder.warm_up() - embedder._gradient.get_embeddings_model.assert_called_once_with( - slug="bge-large" - ) + embedder._gradient.get_embeddings_model.assert_called_once_with(slug="bge-large") def test_warmup_doesnt_reload(self, tokens_from_env): embedder = GradientTextEmbedder() - embedder._gradient.get_embeddings_model = MagicMock( - default_return_value="fake model" - ) + embedder._gradient.get_embeddings_model = MagicMock(default_return_value="fake model") embedder.warm_up() embedder.warm_up() - embedder._gradient.get_embeddings_model.assert_called_once_with( - slug="bge-large" - ) + embedder._gradient.get_embeddings_model.assert_called_once_with(slug="bge-large") def test_run_fail_if_not_warmed_up(self, tokens_from_env): embedder = GradientTextEmbedder() @@ -109,15 +91,11 @@ def test_run_fail_if_not_warmed_up(self, tokens_from_env): def test_run_fail_when_no_embeddings_returned(self, tokens_from_env): embedder = GradientTextEmbedder() embedder._embedding_model = NonCallableMagicMock() - embedder._embedding_model.embed.return_value = GenerateEmbeddingSuccess( - embeddings=[] - ) + embedder._embedding_model.embed.return_value = GenerateEmbeddingSuccess(embeddings=[]) with pytest.raises(RuntimeError): _result = embedder.run(text="The food was delicious") - embedder._embedding_model.embed.assert_called_once_with( - inputs=[{"input": "The food was delicious"}] - ) + embedder._embedding_model.embed.assert_called_once_with(inputs=[{"input": "The food was delicious"}]) def test_run_empty_string(self, tokens_from_env): embedder = GradientTextEmbedder() @@ -140,9 +118,7 @@ def test_run(self, tokens_from_env): ) result = embedder.run(text="The food was delicious") - embedder._embedding_model.embed.assert_called_once_with( - inputs=[{"input": "The food was delicious"}] - ) + embedder._embedding_model.embed.assert_called_once_with(inputs=[{"input": "The food was delicious"}]) assert len(result["embedding"]) == 1024 # 1024 is the bge-large embedding size assert all(isinstance(x, float) for x in result["embedding"]) diff --git a/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/embedding_backend/instructor_backend.py b/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/embedding_backend/instructor_backend.py index bab9aa512..717534aba 100644 --- a/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/embedding_backend/instructor_backend.py +++ 
b/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/embedding_backend/instructor_backend.py @@ -15,20 +15,14 @@ class _InstructorEmbeddingBackendFactory: _instances: ClassVar[Dict[str, "_InstructorEmbeddingBackend"]] = {} @staticmethod - def get_embedding_backend( - model: str, device: Optional[str] = None, token: Optional[Secret] = None - ): + def get_embedding_backend(model: str, device: Optional[str] = None, token: Optional[Secret] = None): embedding_backend_id = f"{model}{device}{token}" if embedding_backend_id in _InstructorEmbeddingBackendFactory._instances: return _InstructorEmbeddingBackendFactory._instances[embedding_backend_id] - embedding_backend = _InstructorEmbeddingBackend( - model=model, device=device, token=token - ) - _InstructorEmbeddingBackendFactory._instances[ - embedding_backend_id - ] = embedding_backend + embedding_backend = _InstructorEmbeddingBackend(model=model, device=device, token=token) + _InstructorEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend return embedding_backend @@ -37,9 +31,7 @@ class _InstructorEmbeddingBackend: Class to manage INSTRUCTOR embeddings. """ - def __init__( - self, model: str, device: Optional[str] = None, token: Optional[Secret] = None - ): + def __init__(self, model: str, device: Optional[str] = None, token: Optional[Secret] = None): self.model = INSTRUCTOR( model_name_or_path=model, device=device, diff --git a/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_document_embedder.py b/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_document_embedder.py index 31632388d..734798f46 100644 --- a/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_document_embedder.py +++ b/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_document_embedder.py @@ -63,9 +63,7 @@ def __init__( self, model: str = "hkunlp/instructor-base", device: Optional[ComponentDevice] = None, - token: Optional[Secret] = Secret.from_env_var( - "HF_API_TOKEN", strict=False - ), # noqa: B008 + token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False), # noqa: B008 instruction: str = "Represent the document", batch_size: int = 32, progress_bar: bool = True, @@ -149,12 +147,8 @@ def warm_up(self): Initializes the component. """ if not hasattr(self, "embedding_backend"): - self.embedding_backend = ( - _InstructorEmbeddingBackendFactory.get_embedding_backend( - model=self.model, - device=self.device.to_torch_str(), - token=self.token, - ) + self.embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend( + model=self.model, device=self.device.to_torch_str(), token=self.token ) @component.output_types(documents=List[Document]) @@ -164,11 +158,7 @@ def run(self, documents: List[Document]): param documents: A list of Documents to embed. """ - if ( - not isinstance(documents, list) - or documents - and not isinstance(documents[0], Document) - ): + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): msg = ( "InstructorDocumentEmbedder expects a list of Documents as input. " "In case you want to embed a list of strings, please use the InstructorTextEmbedder." 
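Note (not part of the patch): the factory shown above caches one backend per model+device+token key, so a second warm_up() with the same arguments reuses the already-loaded INSTRUCTOR model. A lifecycle sketch, assuming the instructor_embedders package and the model weights are available locally:

from haystack import Document
from haystack.utils import ComponentDevice
from haystack_integrations.components.embedders.instructor_embedders import InstructorDocumentEmbedder

embedder = InstructorDocumentEmbedder(model="hkunlp/instructor-base", device=ComponentDevice.from_str("cpu"))
embedder.warm_up()  # builds or reuses the cached _InstructorEmbeddingBackend
result = embedder.run(documents=[Document(content="INSTRUCTOR pairs each text with an instruction")])
assert result["documents"][0].embedding is not None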
@@ -182,15 +172,11 @@ def run(self, documents: List[Document]): texts_to_embed = [] for doc in documents: meta_values_to_embed = [ - str(doc.meta[key]) - for key in self.meta_fields_to_embed - if key in doc.meta and doc.meta[key] is not None + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None ] text_to_embed = [ self.instruction, - self.embedding_separator.join( - [*meta_values_to_embed, doc.content or ""] - ), + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]), ] texts_to_embed.append(text_to_embed) diff --git a/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_text_embedder.py b/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_text_embedder.py index 64e47a11a..46132a8aa 100644 --- a/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_text_embedder.py +++ b/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_text_embedder.py @@ -42,9 +42,7 @@ def __init__( self, model: str = "hkunlp/instructor-base", device: Optional[ComponentDevice] = None, - token: Optional[Secret] = Secret.from_env_var( - "HF_API_TOKEN", strict=False - ), # noqa: B008 + token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False), # noqa: B008 instruction: str = "Represent the sentence", batch_size: int = 32, progress_bar: bool = True, @@ -113,12 +111,8 @@ def warm_up(self): Load the embedding backend. """ if not hasattr(self, "embedding_backend"): - self.embedding_backend = ( - _InstructorEmbeddingBackendFactory.get_embedding_backend( - model=self.model, - device=self.device.to_torch_str(), - token=self.token, - ) + self.embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend( + model=self.model, device=self.device.to_torch_str(), token=self.token ) @component.output_types(embedding=List[float]) diff --git a/integrations/instructor_embedders/tests/test_instructor_backend.py b/integrations/instructor_embedders/tests/test_instructor_backend.py index 0a966ad32..85c1f012a 100644 --- a/integrations/instructor_embedders/tests/test_instructor_backend.py +++ b/integrations/instructor_embedders/tests/test_instructor_backend.py @@ -13,13 +13,9 @@ def test_factory_behavior(mock_instructor): # noqa: ARG001 embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend( model="hkunlp/instructor-large", device="cpu" ) - same_embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend( - "hkunlp/instructor-large", "cpu" - ) - another_embedding_backend = ( - _InstructorEmbeddingBackendFactory.get_embedding_backend( - model="hkunlp/instructor-base", device="cpu" - ) + same_embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend("hkunlp/instructor-large", "cpu") + another_embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend( + model="hkunlp/instructor-base", device="cpu" ) assert same_embedding_backend is embedding_backend @@ -34,14 +30,10 @@ def test_factory_behavior(mock_instructor): # noqa: ARG001 ) def test_model_initialization(mock_instructor): _InstructorEmbeddingBackendFactory.get_embedding_backend( - model="hkunlp/instructor-base", - device="cpu", - token=Secret.from_token("fake-api-token"), + model="hkunlp/instructor-base", device="cpu", token=Secret.from_token("fake-api-token") ) 
mock_instructor.assert_called_once_with( - model_name_or_path="hkunlp/instructor-base", - device="cpu", - use_auth_token="fake-api-token", + model_name_or_path="hkunlp/instructor-base", device="cpu", use_auth_token="fake-api-token" ) # restore the factory state _InstructorEmbeddingBackendFactory._instances = {} @@ -51,15 +43,11 @@ def test_model_initialization(mock_instructor): "haystack_integrations.components.embedders.instructor_embedders.embedding_backend.instructor_backend.INSTRUCTOR" ) def test_embedding_function_with_kwargs(mock_instructor): # noqa: ARG001 - embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend( - model="hkunlp/instructor-base" - ) + embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend(model="hkunlp/instructor-base") data = [["instruction", "sentence1"], ["instruction", "sentence2"]] embedding_backend.embed(data=data, normalize_embeddings=True) - embedding_backend.model.encode.assert_called_once_with( - data, normalize_embeddings=True - ) + embedding_backend.model.encode.assert_called_once_with(data, normalize_embeddings=True) # restore the factory state _InstructorEmbeddingBackendFactory._instances = {} diff --git a/integrations/instructor_embedders/tests/test_instructor_document_embedder.py b/integrations/instructor_embedders/tests/test_instructor_document_embedder.py index fa3ed1cfb..44f740679 100644 --- a/integrations/instructor_embedders/tests/test_instructor_document_embedder.py +++ b/integrations/instructor_embedders/tests/test_instructor_document_embedder.py @@ -4,9 +4,7 @@ import pytest from haystack import Document from haystack.utils import ComponentDevice, Secret -from haystack_integrations.components.embedders.instructor_embedders import ( - InstructorDocumentEmbedder, -) +from haystack_integrations.components.embedders.instructor_embedders import InstructorDocumentEmbedder class TestInstructorDocumentEmbedder: @@ -43,10 +41,7 @@ def test_init_with_parameters(self): assert embedder.model == "hkunlp/instructor-base" assert embedder.device == ComponentDevice.from_str("cuda:0") assert embedder.token == Secret.from_token("fake-api-token") - assert ( - embedder.instruction - == "Represent the 'domain' 'text_type' for 'task_objective'" - ) + assert embedder.instruction == "Represent the 'domain' 'text_type' for 'task_objective'" assert embedder.batch_size == 64 assert embedder.progress_bar is False assert embedder.normalize_embeddings is True @@ -57,20 +52,14 @@ def test_to_dict(self): """ Test serialization of InstructorDocumentEmbedder to a dictionary, using default initialization parameters. 
""" - embedder = InstructorDocumentEmbedder( - model="hkunlp/instructor-base", device=ComponentDevice.from_str("cpu") - ) + embedder = InstructorDocumentEmbedder(model="hkunlp/instructor-base", device=ComponentDevice.from_str("cpu")) embedder_dict = embedder.to_dict() assert embedder_dict == { "type": "haystack_integrations.components.embedders.instructor_embedders.instructor_document_embedder.InstructorDocumentEmbedder", # noqa "init_parameters": { "model": "hkunlp/instructor-base", "device": ComponentDevice.from_str("cpu").to_dict(), - "token": { - "env_vars": ["HF_API_TOKEN"], - "strict": False, - "type": "env_var", - }, + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, "instruction": "Represent the document", "batch_size": 32, "progress_bar": True, @@ -100,11 +89,7 @@ def test_to_dict_with_custom_init_parameters(self): "init_parameters": { "model": "hkunlp/instructor-base", "device": ComponentDevice.from_str("cuda:0").to_dict(), - "token": { - "env_vars": ["HF_API_TOKEN"], - "strict": False, - "type": "env_var", - }, + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, "instruction": "Represent the financial document for retrieval", "batch_size": 64, "progress_bar": False, @@ -123,11 +108,7 @@ def test_from_dict(self): "init_parameters": { "model": "hkunlp/instructor-base", "device": ComponentDevice.from_str("cpu").to_dict(), - "token": { - "env_vars": ["HF_API_TOKEN"], - "strict": False, - "type": "env_var", - }, + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, "instruction": "Represent the 'domain' 'text_type' for 'task_objective'", "batch_size": 32, "progress_bar": True, @@ -140,10 +121,7 @@ def test_from_dict(self): assert embedder.model == "hkunlp/instructor-base" assert embedder.device == ComponentDevice.from_str("cpu") assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False) - assert ( - embedder.instruction - == "Represent the 'domain' 'text_type' for 'task_objective'" - ) + assert embedder.instruction == "Represent the 'domain' 'text_type' for 'task_objective'" assert embedder.batch_size == 32 assert embedder.progress_bar is True assert embedder.normalize_embeddings is False @@ -159,11 +137,7 @@ def test_from_dict_with_custom_init_parameters(self): "init_parameters": { "model": "hkunlp/instructor-base", "device": ComponentDevice.from_str("cuda:0").to_dict(), - "token": { - "env_vars": ["HF_API_TOKEN"], - "strict": False, - "type": "env_var", - }, + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, "instruction": "Represent the financial document for retrieval", "batch_size": 64, "progress_bar": False, @@ -190,9 +164,7 @@ def test_warmup(self, mocked_factory): """ Test for checking embedder instances after warm-up. 
""" - embedder = InstructorDocumentEmbedder( - model="hkunlp/instructor-base", device=ComponentDevice.from_str("cpu") - ) + embedder = InstructorDocumentEmbedder(model="hkunlp/instructor-base", device=ComponentDevice.from_str("cpu")) mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( @@ -220,9 +192,7 @@ def test_embed(self): """ embedder = InstructorDocumentEmbedder(model="hkunlp/instructor-large") embedder.embedding_backend = MagicMock() - embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand( - len(x), 16 - ).tolist() # noqa: ARG005 + embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand(len(x), 16).tolist() # noqa: ARG005 documents = [Document(content=f"Sample-document text {i}") for i in range(5)] @@ -244,16 +214,10 @@ def test_embed_incorrect_input_format(self): string_input = "text" list_integers_input = [1, 2, 3] - with pytest.raises( - TypeError, - match="InstructorDocumentEmbedder expects a list of Documents as input.", - ): + with pytest.raises(TypeError, match="InstructorDocumentEmbedder expects a list of Documents as input."): embedder.run(documents=string_input) - with pytest.raises( - TypeError, - match="InstructorDocumentEmbedder expects a list of Documents as input.", - ): + with pytest.raises(TypeError, match="InstructorDocumentEmbedder expects a list of Documents as input."): embedder.run(documents=list_integers_input) def test_embed_metadata(self): @@ -269,37 +233,17 @@ def test_embed_metadata(self): ) embedder.embedding_backend = MagicMock() - documents = [ - Document( - content=f"document-number {i}", meta={"meta_field": f"meta_value {i}"} - ) - for i in range(5) - ] + documents = [Document(content=f"document-number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)] embedder.run(documents=documents) embedder.embedding_backend.embed.assert_called_once_with( [ - [ - "Represent the financial document for retrieval", - "meta_value 0\ndocument-number 0", - ], - [ - "Represent the financial document for retrieval", - "meta_value 1\ndocument-number 1", - ], - [ - "Represent the financial document for retrieval", - "meta_value 2\ndocument-number 2", - ], - [ - "Represent the financial document for retrieval", - "meta_value 3\ndocument-number 3", - ], - [ - "Represent the financial document for retrieval", - "meta_value 4\ndocument-number 4", - ], + ["Represent the financial document for retrieval", "meta_value 0\ndocument-number 0"], + ["Represent the financial document for retrieval", "meta_value 1\ndocument-number 1"], + ["Represent the financial document for retrieval", "meta_value 2\ndocument-number 2"], + ["Represent the financial document for retrieval", "meta_value 3\ndocument-number 3"], + ["Represent the financial document for retrieval", "meta_value 4\ndocument-number 4"], ], batch_size=32, show_progress_bar=True, diff --git a/integrations/instructor_embedders/tests/test_instructor_text_embedder.py b/integrations/instructor_embedders/tests/test_instructor_text_embedder.py index a1e0b443b..55022f1ec 100644 --- a/integrations/instructor_embedders/tests/test_instructor_text_embedder.py +++ b/integrations/instructor_embedders/tests/test_instructor_text_embedder.py @@ -3,9 +3,7 @@ import numpy as np import pytest from haystack.utils import ComponentDevice, Secret -from haystack_integrations.components.embedders.instructor_embedders import ( - InstructorTextEmbedder, -) +from haystack_integrations.components.embedders.instructor_embedders 
import InstructorTextEmbedder class TestInstructorTextEmbedder: @@ -38,10 +36,7 @@ def test_init_with_parameters(self): assert embedder.model == "hkunlp/instructor-base" assert embedder.device == ComponentDevice.from_str("cuda:0") assert embedder.token == Secret.from_token("fake-api-token") - assert ( - embedder.instruction - == "Represent the 'domain' 'text_type' for 'task_objective'" - ) + assert embedder.instruction == "Represent the 'domain' 'text_type' for 'task_objective'" assert embedder.batch_size == 64 assert embedder.progress_bar is False assert embedder.normalize_embeddings is True @@ -50,20 +45,14 @@ def test_to_dict(self): """ Test serialization of InstructorTextEmbedder to a dictionary, using default initialization parameters. """ - embedder = InstructorTextEmbedder( - model="hkunlp/instructor-base", device=ComponentDevice.from_str("cpu") - ) + embedder = InstructorTextEmbedder(model="hkunlp/instructor-base", device=ComponentDevice.from_str("cpu")) embedder_dict = embedder.to_dict() assert embedder_dict == { "type": "haystack_integrations.components.embedders.instructor_embedders.instructor_text_embedder.InstructorTextEmbedder", # noqa "init_parameters": { "model": "hkunlp/instructor-base", "device": ComponentDevice.from_str("cpu").to_dict(), - "token": { - "env_vars": ["HF_API_TOKEN"], - "strict": False, - "type": "env_var", - }, + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, "instruction": "Represent the sentence", "batch_size": 32, "progress_bar": True, @@ -89,11 +78,7 @@ def test_to_dict_with_custom_init_parameters(self): "init_parameters": { "model": "hkunlp/instructor-base", "device": ComponentDevice.from_str("cuda:0").to_dict(), - "token": { - "env_vars": ["HF_API_TOKEN"], - "strict": False, - "type": "env_var", - }, + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, "instruction": "Represent the financial document for retrieval", "batch_size": 64, "progress_bar": False, @@ -110,11 +95,7 @@ def test_from_dict(self): "init_parameters": { "model": "hkunlp/instructor-base", "device": ComponentDevice.from_str("cpu").to_dict(), - "token": { - "env_vars": ["HF_API_TOKEN"], - "strict": False, - "type": "env_var", - }, + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, "instruction": "Represent the 'domain' 'text_type' for 'task_objective'", "batch_size": 32, "progress_bar": True, @@ -125,10 +106,7 @@ def test_from_dict(self): assert embedder.model == "hkunlp/instructor-base" assert embedder.device == ComponentDevice.from_str("cpu") assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False) - assert ( - embedder.instruction - == "Represent the 'domain' 'text_type' for 'task_objective'" - ) + assert embedder.instruction == "Represent the 'domain' 'text_type' for 'task_objective'" assert embedder.batch_size == 32 assert embedder.progress_bar is True assert embedder.normalize_embeddings is False @@ -142,11 +120,7 @@ def test_from_dict_with_custom_init_parameters(self): "init_parameters": { "model": "hkunlp/instructor-base", "device": ComponentDevice.from_str("cuda:0").to_dict(), - "token": { - "env_vars": ["HF_API_TOKEN"], - "strict": False, - "type": "env_var", - }, + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, "instruction": "Represent the financial document for retrieval", "batch_size": 64, "progress_bar": False, @@ -169,9 +143,7 @@ def test_warmup(self, mocked_factory): """ Test for checking embedder instances after warm-up. 
""" - embedder = InstructorTextEmbedder( - model="hkunlp/instructor-base", device=ComponentDevice.from_str("cpu") - ) + embedder = InstructorTextEmbedder(model="hkunlp/instructor-base", device=ComponentDevice.from_str("cpu")) mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( @@ -199,9 +171,7 @@ def test_embed(self): """ embedder = InstructorTextEmbedder(model="hkunlp/instructor-large") embedder.embedding_backend = MagicMock() - embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand( - len(x), 16 - ).tolist() # noqa: ARG005 + embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand(len(x), 16).tolist() # noqa: ARG005 text = "Good text to embed" @@ -220,9 +190,7 @@ def test_run_wrong_incorrect_format(self): list_integers_input = [1, 2, 3] - with pytest.raises( - TypeError, match="InstructorTextEmbedder expects a string as input" - ): + with pytest.raises(TypeError, match="InstructorTextEmbedder expects a string as input"): embedder.run(text=list_integers_input) @pytest.mark.integration diff --git a/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py b/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py index 1fa3fc1f4..6bcd94220 100644 --- a/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py +++ b/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py @@ -122,24 +122,16 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: texts_to_embed = [] for doc in documents: meta_values_to_embed = [ - str(doc.meta[key]) - for key in self.meta_fields_to_embed - if key in doc.meta and doc.meta[key] is not None + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None ] text_to_embed = ( - self.prefix - + self.embedding_separator.join( - [*meta_values_to_embed, doc.content or ""] - ) - + self.suffix + self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix ) texts_to_embed.append(text_to_embed) return texts_to_embed - def _embed_batch( - self, texts_to_embed: List[str], batch_size: int - ) -> Tuple[List[List[float]], Dict[str, Any]]: + def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]: """ Embed a list of texts in batches. """ @@ -147,14 +139,10 @@ def _embed_batch( all_embeddings = [] metadata = {} for i in tqdm( - range(0, len(texts_to_embed), batch_size), - disable=not self.progress_bar, - desc="Calculating embeddings", + range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings" ): batch = texts_to_embed[i : i + batch_size] - response = self._session.post( - JINA_API_URL, json={"input": batch, "model": self.model_name} - ).json() + response = self._session.post(JINA_API_URL, json={"input": batch, "model": self.model_name}).json() if "data" not in response: raise RuntimeError(response["detail"]) @@ -183,11 +171,7 @@ def run(self, documents: List[Document]): - `meta`: A dictionary with metadata including the model name and usage statistics. :raises TypeError: If the input is not a list of Documents. 
""" - if ( - not isinstance(documents, list) - or documents - and not isinstance(documents[0], Document) - ): + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): msg = ( "JinaDocumentEmbedder expects a list of Documents as input." "In case you want to embed a string, please use the JinaTextEmbedder." @@ -196,9 +180,7 @@ def run(self, documents: List[Document]): texts_to_embed = self._prepare_texts_to_embed(documents=documents) - embeddings, metadata = self._embed_batch( - texts_to_embed=texts_to_embed, batch_size=self.batch_size - ) + embeddings, metadata = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size) for doc, emb in zip(documents, embeddings): doc.embedding = emb diff --git a/integrations/jina/src/haystack_integrations/components/embedders/jina/text_embedder.py b/integrations/jina/src/haystack_integrations/components/embedders/jina/text_embedder.py index f604cf3ea..6398122a4 100644 --- a/integrations/jina/src/haystack_integrations/components/embedders/jina/text_embedder.py +++ b/integrations/jina/src/haystack_integrations/components/embedders/jina/text_embedder.py @@ -79,11 +79,7 @@ def to_dict(self) -> Dict[str, Any]: Dictionary with serialized data. """ return default_to_dict( - self, - api_key=self.api_key.to_dict(), - model=self.model_name, - prefix=self.prefix, - suffix=self.suffix, + self, api_key=self.api_key.to_dict(), model=self.model_name, prefix=self.prefix, suffix=self.suffix ) @classmethod @@ -118,9 +114,7 @@ def run(self, text: str): text_to_embed = self.prefix + text + self.suffix - resp = self._session.post( - JINA_API_URL, json={"input": [text_to_embed], "model": self.model_name} - ).json() + resp = self._session.post(JINA_API_URL, json={"input": [text_to_embed], "model": self.model_name}).json() if "data" not in resp: raise RuntimeError(resp["detail"]) diff --git a/integrations/jina/src/haystack_integrations/components/rankers/jina/ranker.py b/integrations/jina/src/haystack_integrations/components/rankers/jina/ranker.py index 8ae986235..97dace746 100644 --- a/integrations/jina/src/haystack_integrations/components/rankers/jina/ranker.py +++ b/integrations/jina/src/haystack_integrations/components/rankers/jina/ranker.py @@ -170,7 +170,4 @@ def run( else: break - return { - "documents": ranked_docs, - "meta": {"model": resp["model"], "usage": resp["usage"]}, - } + return {"documents": ranked_docs, "meta": {"model": resp["model"], "usage": resp["usage"]}} diff --git a/integrations/jina/tests/test_document_embedder.py b/integrations/jina/tests/test_document_embedder.py index 87b002a17..a9ba23ec0 100644 --- a/integrations/jina/tests/test_document_embedder.py +++ b/integrations/jina/tests/test_document_embedder.py @@ -16,17 +16,9 @@ def mock_session_post_response(*args, **kwargs): # noqa: ARG001 model = kwargs["json"]["model"] mock_response = requests.Response() mock_response.status_code = 200 - data = [ - {"object": "embedding", "index": i, "embedding": [0.1, 0.2, 0.3]} - for i in range(len(inputs)) - ] + data = [{"object": "embedding", "index": i, "embedding": [0.1, 0.2, 0.3]} for i in range(len(inputs))] mock_response._content = json.dumps( - { - "model": model, - "object": "list", - "usage": {"total_tokens": 4, "prompt_tokens": 4}, - "data": data, - } + {"model": model, "object": "list", "usage": {"total_tokens": 4, "prompt_tokens": 4}, "data": data} ).encode() return mock_response @@ -79,11 +71,7 @@ def test_to_dict(self, monkeypatch): assert data == { "type": 
"haystack_integrations.components.embedders.jina.document_embedder.JinaDocumentEmbedder", "init_parameters": { - "api_key": { - "env_vars": ["JINA_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": "env_var"}, "model": "jina-embeddings-v2-base-en", "prefix": "", "suffix": "", @@ -109,11 +97,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): assert data == { "type": "haystack_integrations.components.embedders.jina.document_embedder.JinaDocumentEmbedder", "init_parameters": { - "api_key": { - "env_vars": ["JINA_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": "env_var"}, "model": "model", "prefix": "prefix", "suffix": "suffix", @@ -126,17 +110,11 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): def test_prepare_texts_to_embed_w_metadata(self): documents = [ - Document( - content=f"document number {i}:\ncontent", - meta={"meta_field": f"meta_value {i}"}, - ) - for i in range(5) + Document(content=f"document number {i}:\ncontent", meta={"meta_field": f"meta_value {i}"}) for i in range(5) ] embedder = JinaDocumentEmbedder( - api_key=Secret.from_token("fake-api-key"), - meta_fields_to_embed=["meta_field"], - embedding_separator=" | ", + api_key=Secret.from_token("fake-api-key"), meta_fields_to_embed=["meta_field"], embedding_separator=" | " ) prepared_texts = embedder._prepare_texts_to_embed(documents) @@ -154,9 +132,7 @@ def test_prepare_texts_to_embed_w_suffix(self): documents = [Document(content=f"document number {i}") for i in range(5)] embedder = JinaDocumentEmbedder( - api_key=Secret.from_token("fake-api-key"), - prefix="my_prefix ", - suffix=" my_suffix", + api_key=Secret.from_token("fake-api-key"), prefix="my_prefix ", suffix=" my_suffix" ) prepared_texts = embedder._prepare_texts_to_embed(documents) @@ -172,16 +148,10 @@ def test_prepare_texts_to_embed_w_suffix(self): def test_embed_batch(self): texts = ["text 1", "text 2", "text 3", "text 4", "text 5"] - with patch( - "requests.sessions.Session.post", side_effect=mock_session_post_response - ): - embedder = JinaDocumentEmbedder( - api_key=Secret.from_token("fake-api-key"), model="model" - ) + with patch("requests.sessions.Session.post", side_effect=mock_session_post_response): + embedder = JinaDocumentEmbedder(api_key=Secret.from_token("fake-api-key"), model="model") - embeddings, metadata = embedder._embed_batch( - texts_to_embed=texts, batch_size=2 - ) + embeddings, metadata = embedder._embed_batch(texts_to_embed=texts, batch_size=2) assert isinstance(embeddings, list) assert len(embeddings) == len(texts) @@ -190,24 +160,16 @@ def test_embed_batch(self): assert len(embedding) == 3 assert all(isinstance(x, float) for x in embedding) - assert metadata == { - "model": "model", - "usage": {"prompt_tokens": 3 * 4, "total_tokens": 3 * 4}, - } + assert metadata == {"model": "model", "usage": {"prompt_tokens": 3 * 4, "total_tokens": 3 * 4}} def test_run(self): docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), - Document( - content="A transformer is a deep learning architecture", - meta={"topic": "ML"}, - ), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), ] model = "jina-embeddings-v2-base-en" - with patch( - "requests.sessions.Session.post", side_effect=mock_session_post_response - ): + with patch("requests.sessions.Session.post", side_effect=mock_session_post_response): embedder = 
JinaDocumentEmbedder( api_key=Secret.from_token("fake-api-key"), model=model, @@ -229,23 +191,15 @@ def test_run(self): assert isinstance(doc.embedding, list) assert len(doc.embedding) == 3 assert all(isinstance(x, float) for x in doc.embedding) - assert metadata == { - "model": model, - "usage": {"prompt_tokens": 4, "total_tokens": 4}, - } + assert metadata == {"model": model, "usage": {"prompt_tokens": 4, "total_tokens": 4}} def test_run_custom_batch_size(self): docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), - Document( - content="A transformer is a deep learning architecture", - meta={"topic": "ML"}, - ), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), ] model = "jina-embeddings-v2-base-en" - with patch( - "requests.sessions.Session.post", side_effect=mock_session_post_response - ): + with patch("requests.sessions.Session.post", side_effect=mock_session_post_response): embedder = JinaDocumentEmbedder( api_key=Secret.from_token("fake-api-key"), model=model, @@ -269,10 +223,7 @@ def test_run_custom_batch_size(self): assert len(doc.embedding) == 3 assert all(isinstance(x, float) for x in doc.embedding) - assert metadata == { - "model": model, - "usage": {"prompt_tokens": 2 * 4, "total_tokens": 2 * 4}, - } + assert metadata == {"model": model, "usage": {"prompt_tokens": 2 * 4, "total_tokens": 2 * 4}} def test_run_wrong_input_format(self): embedder = JinaDocumentEmbedder(api_key=Secret.from_token("fake-api-key")) @@ -280,14 +231,10 @@ def test_run_wrong_input_format(self): string_input = "text" list_integers_input = [1, 2, 3] - with pytest.raises( - TypeError, match="JinaDocumentEmbedder expects a list of Documents as input" - ): + with pytest.raises(TypeError, match="JinaDocumentEmbedder expects a list of Documents as input"): embedder.run(documents=string_input) - with pytest.raises( - TypeError, match="JinaDocumentEmbedder expects a list of Documents as input" - ): + with pytest.raises(TypeError, match="JinaDocumentEmbedder expects a list of Documents as input"): embedder.run(documents=list_integers_input) def test_run_on_empty_list(self): diff --git a/integrations/jina/tests/test_ranker.py b/integrations/jina/tests/test_ranker.py index 5c5a18fc3..7ece6806d 100644 --- a/integrations/jina/tests/test_ranker.py +++ b/integrations/jina/tests/test_ranker.py @@ -21,11 +21,7 @@ def mock_session_post_response(*args, **kwargs): # noqa: ARG001 for i, doc in enumerate(documents) ] mock_response._content = json.dumps( - { - "model": model, - "usage": {"total_tokens": 4, "prompt_tokens": 4}, - "results": results, - } + {"model": model, "usage": {"total_tokens": 4, "prompt_tokens": 4}, "results": results} ).encode() return mock_response @@ -40,12 +36,7 @@ def test_init_default(self, monkeypatch): assert embedder.model == "jina-reranker-v1-base-en" def test_init_with_parameters(self): - embedder = JinaRanker( - api_key=Secret.from_token("fake-api-key"), - model="model", - top_k=64, - score_threshold=0.5, - ) + embedder = JinaRanker(api_key=Secret.from_token("fake-api-key"), model="model", top_k=64, score_threshold=0.5) assert embedder.api_key == Secret.from_token("fake-api-key") assert embedder.model == "model" @@ -64,11 +55,7 @@ def test_to_dict(self, monkeypatch): assert data == { "type": "haystack_integrations.components.rankers.jina.ranker.JinaRanker", "init_parameters": { - "api_key": { - "env_vars": ["JINA_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": 
"env_var"}, "model": "jina-reranker-v1-base-en", "top_k": None, "score_threshold": None, @@ -82,11 +69,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): assert data == { "type": "haystack_integrations.components.rankers.jina.ranker.JinaRanker", "init_parameters": { - "api_key": { - "env_vars": ["JINA_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": "env_var"}, "model": "model", "top_k": 64, "score_threshold": 0.5, @@ -103,9 +86,7 @@ def test_run(self): query = "What is a transformer?" model = "jina-ranker" - with patch( - "requests.sessions.Session.post", side_effect=mock_session_post_response - ): + with patch("requests.sessions.Session.post", side_effect=mock_session_post_response): ranker = JinaRanker( api_key=Secret.from_token("fake-api-key"), model=model, @@ -121,10 +102,7 @@ def test_run(self): for i, doc in enumerate(ranked_documents): assert isinstance(doc, Document) assert doc.score == len(ranked_documents) - i - assert metadata == { - "model": model, - "usage": {"prompt_tokens": 4, "total_tokens": 4}, - } + assert metadata == {"model": model, "usage": {"prompt_tokens": 4, "total_tokens": 4}} def test_run_wrong_input_format(self): ranker = JinaRanker(api_key=Secret.from_token("fake-api-key")) diff --git a/integrations/jina/tests/test_text_embedder.py b/integrations/jina/tests/test_text_embedder.py index be78a6edc..7cb669c68 100644 --- a/integrations/jina/tests/test_text_embedder.py +++ b/integrations/jina/tests/test_text_embedder.py @@ -44,11 +44,7 @@ def test_to_dict(self, monkeypatch): assert data == { "type": "haystack_integrations.components.embedders.jina.text_embedder.JinaTextEmbedder", "init_parameters": { - "api_key": { - "env_vars": ["JINA_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": "env_var"}, "model": "jina-embeddings-v2-base-en", "prefix": "", "suffix": "", @@ -66,11 +62,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): assert data == { "type": "haystack_integrations.components.embedders.jina.text_embedder.JinaTextEmbedder", "init_parameters": { - "api_key": { - "env_vars": ["JINA_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": "env_var"}, "model": "model", "prefix": "prefix", "suffix": "suffix", @@ -88,23 +80,14 @@ def test_run(self): "model": "jina-embeddings-v2-base-en", "object": "list", "usage": {"total_tokens": 6, "prompt_tokens": 6}, - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [0.1, 0.2, 0.3], - } - ], + "data": [{"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}], } ).encode() mock_post.return_value = mock_response embedder = JinaTextEmbedder( - api_key=Secret.from_token("fake-api-key"), - model=model, - prefix="prefix ", - suffix=" suffix", + api_key=Secret.from_token("fake-api-key"), model=model, prefix="prefix ", suffix=" suffix" ) result = embedder.run(text="The food was delicious") @@ -120,7 +103,5 @@ def test_run_wrong_input_format(self): list_integers_input = [1, 2, 3] - with pytest.raises( - TypeError, match="JinaTextEmbedder expects a string as an input" - ): + with pytest.raises(TypeError, match="JinaTextEmbedder expects a string as an input"): embedder.run(text=list_integers_input) diff --git a/integrations/langfuse/example/basic_rag.py b/integrations/langfuse/example/basic_rag.py index 37ab3574a..492a14d49 100644 --- 
a/integrations/langfuse/example/basic_rag.py +++ b/integrations/langfuse/example/basic_rag.py @@ -6,10 +6,7 @@ from datasets import load_dataset from haystack import Document, Pipeline from haystack.components.builders import PromptBuilder -from haystack.components.embedders import ( - SentenceTransformersDocumentEmbedder, - SentenceTransformersTextEmbedder, -) +from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder from haystack.components.generators import OpenAIGenerator from haystack.components.retrievers import InMemoryEmbeddingRetriever from haystack.document_stores.in_memory import InMemoryDocumentStore @@ -37,16 +34,11 @@ def get_pipeline(document_store: InMemoryDocumentStore): # Add components to your pipeline basic_rag_pipeline.add_component("tracer", LangfuseConnector("Basic RAG Pipeline")) basic_rag_pipeline.add_component( - "text_embedder", - SentenceTransformersTextEmbedder( - model="sentence-transformers/all-MiniLM-L6-v2" - ), + "text_embedder", SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2") ) basic_rag_pipeline.add_component("retriever", retriever) basic_rag_pipeline.add_component("prompt_builder", prompt_builder) - basic_rag_pipeline.add_component( - "llm", OpenAIGenerator(model="gpt-3.5-turbo", generation_kwargs={"n": 2}) - ) + basic_rag_pipeline.add_component("llm", OpenAIGenerator(model="gpt-3.5-turbo", generation_kwargs={"n": 2})) # Now, connect the components to each other # NOTE: the tracer component doesn't need to be connected to anything in order to work @@ -60,18 +52,14 @@ def get_pipeline(document_store: InMemoryDocumentStore): if __name__ == "__main__": document_store = InMemoryDocumentStore() dataset = load_dataset("bilgeyucel/seven-wonders", split="train") - embedder = SentenceTransformersDocumentEmbedder( - "sentence-transformers/all-MiniLM-L6-v2" - ) + embedder = SentenceTransformersDocumentEmbedder("sentence-transformers/all-MiniLM-L6-v2") embedder.warm_up() docs_with_embeddings = embedder.run([Document(**ds) for ds in dataset]).get("documents") or [] # type: ignore document_store.write_documents(docs_with_embeddings) pipeline = get_pipeline(document_store) question = "What does Rhodes Statue look like?" - response = pipeline.run( - {"text_embedder": {"text": question}, "prompt_builder": {"question": question}} - ) + response = pipeline.run({"text_embedder": {"text": question}, "prompt_builder": {"question": question}}) print(response["llm"]["replies"][0]) print(response["tracer"]["trace_url"]) diff --git a/integrations/langfuse/example/chat.py b/integrations/langfuse/example/chat.py index 637cfbc73..443d65a13 100644 --- a/integrations/langfuse/example/chat.py +++ b/integrations/langfuse/example/chat.py @@ -9,6 +9,7 @@ from haystack_integrations.components.connectors.langfuse import LangfuseConnector if __name__ == "__main__": + pipe = Pipeline() pipe.add_component("tracer", LangfuseConnector("Chat example")) pipe.add_component("prompt_builder", ChatPromptBuilder()) @@ -17,19 +18,10 @@ pipe.connect("prompt_builder.prompt", "llm.messages") messages = [ - ChatMessage.from_system( - "Always respond in German even if some input data is in other languages." 
- ), + ChatMessage.from_system("Always respond in German even if some input data is in other languages."), ChatMessage.from_user("Tell me about {{location}}"), ] - response = pipe.run( - data={ - "prompt_builder": { - "template_variables": {"location": "Berlin"}, - "template": messages, - } - } - ) + response = pipe.run(data={"prompt_builder": {"template_variables": {"location": "Berlin"}, "template": messages}}) print(response["llm"]["replies"][0]) print(response["tracer"]["trace_url"]) diff --git a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py index a8777aa64..4bf0da2f8 100644 --- a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py +++ b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py @@ -19,10 +19,7 @@ class LangfuseSpan(Span): Internal class representing a bridge between the Haystack span tracing API and Langfuse. """ - def __init__( - self, - span: "Union[langfuse.client.StatefulSpanClient, langfuse.client.StatefulTraceClient]", - ) -> None: + def __init__(self, span: "Union[langfuse.client.StatefulSpanClient, langfuse.client.StatefulTraceClient]") -> None: """ Initialize a LangfuseSpan instance. @@ -87,9 +84,7 @@ class LangfuseTracer(Tracer): Internal class representing a bridge between the Haystack tracer and Langfuse. """ - def __init__( - self, tracer: "langfuse.Langfuse", name: str = "Haystack", public: bool = False - ) -> None: + def __init__(self, tracer: "langfuse.Langfuse", name: str = "Haystack", public: bool = False) -> None: """ Initialize a LangfuseTracer instance. @@ -104,14 +99,10 @@ def __init__( self._context: list[LangfuseSpan] = [] self._name = name self._public = public - self.enforce_flush = ( - os.getenv(HAYSTACK_LANGFUSE_ENFORCE_FLUSH_ENV_VAR, "true").lower() == "true" - ) + self.enforce_flush = os.getenv(HAYSTACK_LANGFUSE_ENFORCE_FLUSH_ENV_VAR, "true").lower() == "true" @contextlib.contextmanager - def trace( - self, operation_name: str, tags: Optional[Dict[str, Any]] = None - ) -> Iterator[Span]: + def trace(self, operation_name: str, tags: Optional[Dict[str, Any]] = None) -> Iterator[Span]: """ Start and manage a new trace span. :param operation_name: The name of the operation. 
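Note (not part of the patch): a minimal sketch of driving this context manager directly, mirroring the tests later in this diff; the Mock stands in for a real langfuse.Langfuse client:

from unittest.mock import Mock

from haystack_integrations.tracing.langfuse.tracer import LangfuseTracer

tracer = LangfuseTracer(tracer=Mock(), name="Haystack", public=False)
with tracer.trace("operation_name", tags={"haystack.pipeline.input_data": "hello"}) as span:
    pass  # the span is popped from tracer._context on exit
# enforce_flush defaults to "true", so the client's flush() is called when the span closes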
@@ -122,9 +113,7 @@ def trace( span_name = tags.get("haystack.component.name", operation_name) if tags.get("haystack.component.type") in _ALL_SUPPORTED_GENERATORS: - span = LangfuseSpan( - self.current_span().raw_span().generation(name=span_name) - ) + span = LangfuseSpan(self.current_span().raw_span().generation(name=span_name)) else: span = LangfuseSpan(self.current_span().raw_span().span(name=span_name)) @@ -144,9 +133,7 @@ def trace( replies = span._data.get("haystack.component.output", {}).get("replies") if replies: meta = replies[0].meta - span._span.update( - usage=meta.get("usage") or None, model=meta.get("model") - ) + span._span.update(usage=meta.get("usage") or None, model=meta.get("model")) pipeline_input = tags.get("haystack.pipeline.input_data", None) if pipeline_input: @@ -176,9 +163,7 @@ def current_span(self) -> Span: """ if not self._context: # The root span has to be a trace - self._context.append( - LangfuseSpan(self._tracer.trace(name=self._name, public=self._public)) - ) + self._context.append(LangfuseSpan(self._tracer.trace(name=self._name, public=self._public))) return self._context[-1] def get_trace_url(self) -> str: diff --git a/integrations/langfuse/tests/test_langfuse_span.py b/integrations/langfuse/tests/test_langfuse_span.py index e6f0d5e35..a5a5f2c13 100644 --- a/integrations/langfuse/tests/test_langfuse_span.py +++ b/integrations/langfuse/tests/test_langfuse_span.py @@ -8,6 +8,7 @@ class TestLangfuseSpan: + # LangfuseSpan can be initialized with a span object def test_initialized_with_span_object(self): mock_span = Mock() @@ -41,31 +42,19 @@ def test_set_content_tag_updates_input_and_output_with_messages(self): # test message input span = LangfuseSpan(mock_span) - span.set_content_tag( - "key.input", {"messages": [ChatMessage.from_user("message")]} - ) + span.set_content_tag("key.input", {"messages": [ChatMessage.from_user("message")]}) assert mock_span.update.call_count == 1 # check we converted ChatMessage to OpenAI format - assert mock_span.update.call_args_list[0][1] == { - "input": [{"role": "user", "content": "message"}] - } - assert span._data["key.input"] == { - "messages": [ChatMessage.from_user("message")] - } + assert mock_span.update.call_args_list[0][1] == {"input": [{"role": "user", "content": "message"}]} + assert span._data["key.input"] == {"messages": [ChatMessage.from_user("message")]} # test replies ChatMessage list mock_span.reset_mock() - span.set_content_tag( - "key.output", {"replies": [ChatMessage.from_system("reply")]} - ) + span.set_content_tag("key.output", {"replies": [ChatMessage.from_system("reply")]}) assert mock_span.update.call_count == 1 # check we converted ChatMessage to OpenAI format - assert mock_span.update.call_args_list[0][1] == { - "output": [{"role": "system", "content": "reply"}] - } - assert span._data["key.output"] == { - "replies": [ChatMessage.from_system("reply")] - } + assert mock_span.update.call_args_list[0][1] == {"output": [{"role": "system", "content": "reply"}]} + assert span._data["key.output"] == {"replies": [ChatMessage.from_system("reply")]} # test replies string list mock_span.reset_mock() diff --git a/integrations/langfuse/tests/test_tracer.py b/integrations/langfuse/tests/test_tracer.py index cd464d56b..241581a72 100644 --- a/integrations/langfuse/tests/test_tracer.py +++ b/integrations/langfuse/tests/test_tracer.py @@ -5,6 +5,7 @@ class TestLangfuseTracer: + # LangfuseTracer can be initialized with a Langfuse instance, a name and a boolean value for public. 
    def test_initialization(self):
        langfuse_instance = Mock()
@@ -21,9 +22,7 @@ def test_create_new_span(self):
         mock_raw_span.operation_name = "operation_name"
         mock_raw_span.metadata = {"tag1": "value1", "tag2": "value2"}

-        with patch(
-            "haystack_integrations.tracing.langfuse.tracer.LangfuseSpan"
-        ) as MockLangfuseSpan:
+        with patch("haystack_integrations.tracing.langfuse.tracer.LangfuseSpan") as MockLangfuseSpan:
             mock_span_instance = MockLangfuseSpan.return_value
             mock_span_instance.raw_span.return_value = mock_raw_span

@@ -35,12 +34,8 @@ def test_create_new_span(self):

         tracer = LangfuseTracer(tracer=mock_tracer, name="Haystack", public=False)

-        with tracer.trace(
-            "operation_name", tags={"tag1": "value1", "tag2": "value2"}
-        ) as span:
-            assert (
-                len(tracer._context) == 2
-            ), "The trace span should have been added to the the root context span"
+        with tracer.trace("operation_name", tags={"tag1": "value1", "tag2": "value2"}) as span:
+            assert len(tracer._context) == 2, "The trace span should have been added to the root context span"

             assert span.raw_span().operation_name == "operation_name"
             assert span.raw_span().metadata == {"tag1": "value1", "tag2": "value2"}

@@ -51,6 +46,7 @@ def test_create_new_span(self):

     # check that update method is called on the span instance with the provided key value pairs
     def test_update_span_with_pipeline_input_output_data(self):
         class MockTracer:
+
             def trace(self, name, **kwargs):
                 return MockSpan()

@@ -81,30 +77,17 @@ def end(self):
             pass

         tracer = LangfuseTracer(tracer=MockTracer(), name="Haystack", public=False)
-        with tracer.trace(
-            operation_name="operation_name",
-            tags={"haystack.pipeline.input_data": "hello"},
-        ) as span:
-            assert span.raw_span()._data["metadata"] == {
-                "haystack.pipeline.input_data": "hello"
-            }
-
-        with tracer.trace(
-            operation_name="operation_name",
-            tags={"haystack.pipeline.output_data": "bye"},
-        ) as span:
-            assert span.raw_span()._data["metadata"] == {
-                "haystack.pipeline.output_data": "bye"
-            }
+        with tracer.trace(operation_name="operation_name", tags={"haystack.pipeline.input_data": "hello"}) as span:
+            assert span.raw_span()._data["metadata"] == {"haystack.pipeline.input_data": "hello"}
+
+        with tracer.trace(operation_name="operation_name", tags={"haystack.pipeline.output_data": "bye"}) as span:
+            assert span.raw_span()._data["metadata"] == {"haystack.pipeline.output_data": "bye"}

     def test_update_span_gets_flushed_by_default(self):
         tracer_mock = Mock()

         tracer = LangfuseTracer(tracer=tracer_mock, name="Haystack", public=False)
-        with tracer.trace(
-            operation_name="operation_name",
-            tags={"haystack.pipeline.input_data": "hello"},
-        ) as span:
+        with tracer.trace(operation_name="operation_name", tags={"haystack.pipeline.input_data": "hello"}) as span:
             pass

         tracer_mock.flush.assert_called_once()

@@ -116,10 +99,7 @@ def test_update_span_flush_disable(self, monkeypatch):
         from haystack_integrations.tracing.langfuse.tracer import LangfuseTracer

         tracer = LangfuseTracer(tracer=tracer_mock, name="Haystack", public=False)
-        with tracer.trace(
-            operation_name="operation_name",
-            tags={"haystack.pipeline.input_data": "hello"},
-        ) as span:
+        with tracer.trace(operation_name="operation_name", tags={"haystack.pipeline.input_data": "hello"}) as span:
             pass

         tracer_mock.flush.assert_not_called()

@@ -128,10 +108,7 @@ def test_context_is_empty_after_tracing(self):
         tracer_mock = Mock()

         tracer = LangfuseTracer(tracer=tracer_mock, name="Haystack", public=False)
-        with tracer.trace(
-            operation_name="operation_name",
-            tags={"haystack.pipeline.input_data": "hello"},
-        ) as span:
+        with tracer.trace(operation_name="operation_name", tags={"haystack.pipeline.input_data": "hello"}) as span:
             pass

         assert tracer._context == []
diff --git a/integrations/langfuse/tests/test_tracing.py b/integrations/langfuse/tests/test_tracing.py
index ee190bc2c..111d89dfd 100644
--- a/integrations/langfuse/tests/test_tracing.py
+++ b/integrations/langfuse/tests/test_tracing.py
@@ -19,35 +19,24 @@
 @pytest.mark.integration
 @pytest.mark.skipif(
-    not os.environ.get("LANGFUSE_SECRET_KEY", None)
-    and not os.environ.get("LANGFUSE_PUBLIC_KEY", None),
+    not os.environ.get("LANGFUSE_SECRET_KEY", None) and not os.environ.get("LANGFUSE_PUBLIC_KEY", None),
     reason="Export an env var called LANGFUSE_SECRET_KEY and LANGFUSE_PUBLIC_KEY containing Langfuse credentials.",
 )
 def test_tracing_integration():
+
     pipe = Pipeline()
-    pipe.add_component(
-        "tracer", LangfuseConnector(name="Chat example", public=True)
-    )  # public so anyone can verify run
+    pipe.add_component("tracer", LangfuseConnector(name="Chat example", public=True))  # public so anyone can verify run
     pipe.add_component("prompt_builder", ChatPromptBuilder())
     pipe.add_component("llm", OpenAIChatGenerator(model="gpt-3.5-turbo"))
     pipe.connect("prompt_builder.prompt", "llm.messages")

     messages = [
-        ChatMessage.from_system(
-            "Always respond in German even if some input data is in other languages."
-        ),
+        ChatMessage.from_system("Always respond in German even if some input data is in other languages."),
         ChatMessage.from_user("Tell me about {{location}}"),
     ]

-    response = pipe.run(
-        data={
-            "prompt_builder": {
-                "template_variables": {"location": "Berlin"},
-                "template": messages,
-            }
-        }
-    )
+    response = pipe.run(data={"prompt_builder": {"template_variables": {"location": "Berlin"}, "template": messages}})
     assert "Berlin" in response["llm"]["replies"][0].content
     assert response["tracer"]["trace_url"]
     url = "https://cloud.langfuse.com/api/public/traces/"
@@ -58,15 +47,9 @@ def test_tracing_integration():
     try:
         # GET request with Basic Authentication on the Langfuse API
         response = requests.get(
-            url + uuid,
-            auth=HTTPBasicAuth(
-                os.environ.get("LANGFUSE_PUBLIC_KEY"),
-                os.environ.get("LANGFUSE_SECRET_KEY"),
-            ),
+            url + uuid, auth=HTTPBasicAuth(os.environ.get("LANGFUSE_PUBLIC_KEY"), os.environ.get("LANGFUSE_SECRET_KEY"))
         )

-        assert (
-            response.status_code == 200
-        ), f"Failed to retrieve data from Langfuse API: {response.status_code}"
+        assert response.status_code == 200, f"Failed to retrieve data from Langfuse API: {response.status_code}"
     except requests.exceptions.RequestException as e:
         assert False, f"Failed to retrieve data from Langfuse API: {e}"
diff --git a/integrations/llama_cpp/examples/llama_cpp_generator_example.py b/integrations/llama_cpp/examples/llama_cpp_generator_example.py
index fe1c56c44..96f8aec1d 100644
--- a/integrations/llama_cpp/examples/llama_cpp_generator_example.py
+++ b/integrations/llama_cpp/examples/llama_cpp_generator_example.py
@@ -1,8 +1,6 @@
 from haystack_integrations.components.generators.llama_cpp import LlamaCppGenerator

-generator = LlamaCppGenerator(
-    model="openchat-3.5-1210.Q3_K_S.gguf", n_ctx=512, n_batch=128
-)
+generator = LlamaCppGenerator(model="openchat-3.5-1210.Q3_K_S.gguf", n_ctx=512, n_batch=128)
 generator.warm_up()

 question = "Who is the best American actor?"
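[Editor's note — illustrative sketch, not part of the patch: the example file above ends at the question definition; a hedged guess at the run() call that follows. The prompt format and sampling values are assumptions; the replies/meta output keys match the generator.py hunk later in this patch.]

# Sketch only; prompt format and generation kwargs are illustrative assumptions.
result = generator.run(
    prompt=f"Q: {question} A:",
    generation_kwargs={"max_tokens": 128, "temperature": 0.1},
)
print(result["replies"][0])  # generated text
print(result["meta"][0])     # raw llama.cpp completion dict (model, usage, ...)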
diff --git a/integrations/llama_cpp/examples/rag_pipeline_example.py b/integrations/llama_cpp/examples/rag_pipeline_example.py index 626034b0a..dcab6dbd6 100644 --- a/integrations/llama_cpp/examples/rag_pipeline_example.py +++ b/integrations/llama_cpp/examples/rag_pipeline_example.py @@ -2,10 +2,7 @@ from haystack import Document, Pipeline from haystack.components.builders.answer_builder import AnswerBuilder from haystack.components.builders.prompt_builder import PromptBuilder -from haystack.components.embedders import ( - SentenceTransformersDocumentEmbedder, - SentenceTransformersTextEmbedder, -) +from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder from haystack.components.retrievers import InMemoryEmbeddingRetriever from haystack.components.writers import DocumentWriter from haystack.document_stores import InMemoryDocumentStore @@ -26,17 +23,13 @@ ] doc_store = InMemoryDocumentStore(embedding_similarity_function="cosine") -doc_embedder = SentenceTransformersDocumentEmbedder( - model="sentence-transformers/all-MiniLM-L6-v2" -) +doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2") # Indexing Pipeline indexing_pipeline = Pipeline() indexing_pipeline.add_component(instance=doc_embedder, name="doc_embedder") -indexing_pipeline.add_component( - instance=DocumentWriter(document_store=doc_store), name="doc_writer" -) +indexing_pipeline.add_component(instance=DocumentWriter(document_store=doc_store), name="doc_writer") indexing_pipeline.connect("doc_embedder", "doc_writer") indexing_pipeline.run({"doc_embedder": {"documents": docs}}) @@ -54,9 +47,7 @@ """ rag_pipeline = Pipeline() -text_embedder = SentenceTransformersTextEmbedder( - model="sentence-transformers/all-MiniLM-L6-v2" -) +text_embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2") model_path = "openchat-3.5-1210.Q3_K_S.gguf" generator = LlamaCppGenerator(model=model_path, n_ctx=4096, n_batch=128) @@ -65,13 +56,8 @@ instance=text_embedder, name="text_embedder", ) -rag_pipeline.add_component( - instance=InMemoryEmbeddingRetriever(document_store=doc_store, top_k=3), - name="retriever", -) -rag_pipeline.add_component( - instance=PromptBuilder(template=prompt_template), name="prompt_builder" -) +rag_pipeline.add_component(instance=InMemoryEmbeddingRetriever(document_store=doc_store, top_k=3), name="retriever") +rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder") rag_pipeline.add_component(instance=generator, name="llm") rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder") diff --git a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py index 39a5028b5..d43700215 100644 --- a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py +++ b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py @@ -70,9 +70,7 @@ def __init__( generation_kwargs = generation_kwargs or {} if "hf_tokenizer_path" in model_kwargs: - tokenizer = LlamaHFTokenizer.from_pretrained( - model_kwargs["hf_tokenizer_path"] - ) + tokenizer = LlamaHFTokenizer.from_pretrained(model_kwargs["hf_tokenizer_path"]) model_kwargs["tokenizer"] = tokenizer # check if the model_kwargs contain the essential parameters @@ -93,11 
+91,7 @@ def warm_up(self): self.model = Llama(**self.model_kwargs) @component.output_types(replies=List[ChatMessage]) - def run( - self, - messages: List[ChatMessage], - generation_kwargs: Optional[Dict[str, Any]] = None, - ): + def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): """ Run the text generation model on the given list of ChatMessages. @@ -110,25 +104,16 @@ def run( - `replies`: The responses from the model """ if self.model is None: - error_msg = ( - "The model has not been loaded. Please call warm_up() before running." - ) + error_msg = "The model has not been loaded. Please call warm_up() before running." raise RuntimeError(error_msg) if not messages: return {"replies": []} - updated_generation_kwargs = { - **self.generation_kwargs, - **(generation_kwargs or {}), - } - formatted_messages = [ - _convert_message_to_llamacpp_format(msg) for msg in messages - ] + updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + formatted_messages = [_convert_message_to_llamacpp_format(msg) for msg in messages] - response = self.model.create_chat_completion( - messages=formatted_messages, **updated_generation_kwargs - ) + response = self.model.create_chat_completion(messages=formatted_messages, **updated_generation_kwargs) replies = [ ChatMessage( content=choice["message"]["content"], diff --git a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/generator.py b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/generator.py index b28cc446b..1c504b6f3 100644 --- a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/generator.py +++ b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/generator.py @@ -83,23 +83,16 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): - `meta`: metadata about the request. """ if self.model is None: - error_msg = ( - "The model has not been loaded. Please call warm_up() before running." - ) + error_msg = "The model has not been loaded. Please call warm_up() before running." 
raise RuntimeError(error_msg) if not prompt: return {"replies": []} # merge generation kwargs from init method with those from run method - updated_generation_kwargs = { - **self.generation_kwargs, - **(generation_kwargs or {}), - } - - output = self.model.create_completion( - prompt=prompt, **updated_generation_kwargs - ) + updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + + output = self.model.create_completion(prompt=prompt, **updated_generation_kwargs) replies = [output["choices"][0]["text"]] return {"replies": replies, "meta": [output]} diff --git a/integrations/llama_cpp/tests/test_chat_generator.py b/integrations/llama_cpp/tests/test_chat_generator.py index 0c4697cef..7bd6ef122 100644 --- a/integrations/llama_cpp/tests/test_chat_generator.py +++ b/integrations/llama_cpp/tests/test_chat_generator.py @@ -34,16 +34,10 @@ def download_file(file_link, filename, capsys): def test_convert_message_to_llamacpp_format(): message = ChatMessage.from_system("You are good assistant") - assert _convert_message_to_llamacpp_format(message) == { - "role": "system", - "content": "You are good assistant", - } + assert _convert_message_to_llamacpp_format(message) == {"role": "system", "content": "You are good assistant"} message = ChatMessage.from_user("I have a question") - assert _convert_message_to_llamacpp_format(message) == { - "role": "user", - "content": "I have a question", - } + assert _convert_message_to_llamacpp_format(message) == {"role": "user", "content": "I have a question"} message = ChatMessage.from_function("Function call", "function_name") assert _convert_message_to_llamacpp_format(message) == { @@ -56,7 +50,9 @@ def test_convert_message_to_llamacpp_format(): class TestLlamaCppChatGenerator: @pytest.fixture def generator(self, model_path, capsys): - gguf_model_path = "https://huggingface.co/TheBloke/openchat-3.5-1210-GGUF/resolve/main/openchat-3.5-1210.Q3_K_S.gguf" + gguf_model_path = ( + "https://huggingface.co/TheBloke/openchat-3.5-1210-GGUF/resolve/main/openchat-3.5-1210.Q3_K_S.gguf" + ) filename = "openchat-3.5-1210.Q3_K_S.gguf" # Download GGUF model from HuggingFace @@ -70,9 +66,7 @@ def generator(self, model_path, capsys): @pytest.fixture def generator_mock(self): mock_model = MagicMock() - generator = LlamaCppChatGenerator( - model="test_model.gguf", n_ctx=2048, n_batch=512 - ) + generator = LlamaCppChatGenerator(model="test_model.gguf", n_ctx=2048, n_batch=512) generator.model = mock_model return generator, mock_model @@ -85,11 +79,7 @@ def test_default_init(self): assert generator.model_path == "test_model.gguf" assert generator.n_ctx == 0 assert generator.n_batch == 512 - assert generator.model_kwargs == { - "model_path": "test_model.gguf", - "n_ctx": 0, - "n_batch": 512, - } + assert generator.model_kwargs == {"model_path": "test_model.gguf", "n_ctx": 0, "n_batch": 512} assert generator.generation_kwargs == {} def test_custom_init(self): @@ -105,11 +95,7 @@ def test_custom_init(self): assert generator.model_path == "test_model.gguf" assert generator.n_ctx == 8192 assert generator.n_batch == 512 - assert generator.model_kwargs == { - "model_path": "test_model.gguf", - "n_ctx": 8192, - "n_batch": 512, - } + assert generator.model_kwargs == {"model_path": "test_model.gguf", "n_ctx": 8192, "n_batch": 512} assert generator.generation_kwargs == {} def test_ignores_model_path_if_specified_in_model_kwargs(self): @@ -128,12 +114,7 @@ def test_ignores_n_ctx_if_specified_in_model_kwargs(self): """ Test that n_ctx is ignored if already specified in 
model_kwargs. """ - generator = LlamaCppChatGenerator( - model="test_model.gguf", - n_ctx=512, - n_batch=512, - model_kwargs={"n_ctx": 8192}, - ) + generator = LlamaCppChatGenerator(model="test_model.gguf", n_ctx=512, n_batch=512, model_kwargs={"n_ctx": 8192}) assert generator.model_kwargs["n_ctx"] == 8192 def test_ignores_n_batch_if_specified_in_model_kwargs(self): @@ -141,10 +122,7 @@ def test_ignores_n_batch_if_specified_in_model_kwargs(self): Test that n_batch is ignored if already specified in model_kwargs. """ generator = LlamaCppChatGenerator( - model="test_model.gguf", - n_ctx=8192, - n_batch=512, - model_kwargs={"n_batch": 1024}, + model="test_model.gguf", n_ctx=8192, n_batch=512, model_kwargs={"n_batch": 1024} ) assert generator.model_kwargs["n_batch"] == 1024 @@ -152,9 +130,7 @@ def test_raises_error_without_warm_up(self): """ Test that the generator raises an error if warm_up() is not called before running. """ - generator = LlamaCppChatGenerator( - model="test_model.gguf", n_ctx=512, n_batch=512 - ) + generator = LlamaCppChatGenerator(model="test_model.gguf", n_ctx=512, n_batch=512) with pytest.raises(RuntimeError): generator.run("What is the capital of China?") @@ -177,11 +153,7 @@ def test_run_with_valid_message(self, generator_mock): "model": "Test Model Path", "created": 1715226164, "choices": [ - { - "index": 0, - "message": {"content": "Generated text", "role": "assistant"}, - "finish_reason": "stop", - } + {"index": 0, "message": {"content": "Generated text", "role": "assistant"}, "finish_reason": "stop"} ], "usage": {"prompt_tokens": 14, "completion_tokens": 57, "total_tokens": 71}, } @@ -203,19 +175,13 @@ def test_run_with_generation_kwargs(self, generator_mock): "model": "Test Model Path", "created": 1715226164, "choices": [ - { - "index": 0, - "message": {"content": "Generated text", "role": "assistant"}, - "finish_reason": "length", - } + {"index": 0, "message": {"content": "Generated text", "role": "assistant"}, "finish_reason": "length"} ], "usage": {"prompt_tokens": 14, "completion_tokens": 57, "total_tokens": 71}, } mock_model.create_chat_completion.return_value = mock_output generation_kwargs = {"max_tokens": 128} - result = generator.run( - [ChatMessage.from_system("Write a 200 word paragraph.")], generation_kwargs - ) + result = generator.run([ChatMessage.from_system("Write a 200 word paragraph.")], generation_kwargs) assert result["replies"][0].content == "Generated text" assert result["replies"][0].meta["finish_reason"] == "length" @@ -239,9 +205,7 @@ def test_run(self, generator): assert "replies" in result assert isinstance(result["replies"], list) assert len(result["replies"]) > 0 - assert any( - answer.lower() in reply.content.lower() for reply in result["replies"] - ) + assert any(answer.lower() in reply.content.lower() for reply in result["replies"]) @pytest.mark.integration def test_run_rag_pipeline(self, generator): @@ -250,9 +214,7 @@ def test_run_rag_pipeline(self, generator): """ document_store = InMemoryDocumentStore() documents = [ - Document( - content="There are over 7,000 languages spoken around the world today." 
- ), + Document(content="There are over 7,000 languages spoken around the world today."), Document( content="""Elephants have been observed to behave in a way that indicates a high level of self-awareness, such as recognizing themselves in mirrors.""" @@ -269,10 +231,7 @@ def test_run_rag_pipeline(self, generator): instance=InMemoryBM25Retriever(document_store=document_store, top_k=1), name="retriever", ) - pipeline.add_component( - instance=ChatPromptBuilder(variables=["query", "documents"]), - name="prompt_builder", - ) + pipeline.add_component(instance=ChatPromptBuilder(variables=["query", "documents"]), name="prompt_builder") pipeline.add_component(instance=generator, name="llm") pipeline.connect("retriever.documents", "prompt_builder.documents") pipeline.connect("prompt_builder.prompt", "llm.messages") @@ -318,11 +277,7 @@ def test_json_constraining(self, generator): """ Test that the generator can output valid JSON. """ - messages = [ - ChatMessage.from_system( - "Output valid json only. List 2 people with their name and age." - ) - ] + messages = [ChatMessage.from_system("Output valid json only. List 2 people with their name and age.")] json_schema = { "type": "object", "properties": { @@ -356,47 +311,30 @@ def test_json_constraining(self, generator): assert isinstance(json.loads(reply.content), dict) assert "people" in json.loads(reply.content) assert isinstance(json.loads(reply.content)["people"], list) - assert all( - isinstance(person, dict) - for person in json.loads(reply.content)["people"] - ) - assert all( - "name" in person for person in json.loads(reply.content)["people"] - ) - assert all( - "age" in person for person in json.loads(reply.content)["people"] - ) - assert all( - isinstance(person["name"], str) - for person in json.loads(reply.content)["people"] - ) - assert all( - isinstance(person["age"], int) - for person in json.loads(reply.content)["people"] - ) + assert all(isinstance(person, dict) for person in json.loads(reply.content)["people"]) + assert all("name" in person for person in json.loads(reply.content)["people"]) + assert all("age" in person for person in json.loads(reply.content)["people"]) + assert all(isinstance(person["name"], str) for person in json.loads(reply.content)["people"]) + assert all(isinstance(person["age"], int) for person in json.loads(reply.content)["people"]) class TestLlamaCppChatGeneratorFunctionary: def get_current_temperature(self, location): """Get the current temperature in a given location""" if "tokyo" in location.lower(): - return json.dumps( - {"location": "Tokyo", "temperature": "10", "unit": "celsius"} - ) + return json.dumps({"location": "Tokyo", "temperature": "10", "unit": "celsius"}) elif "san francisco" in location.lower(): - return json.dumps( - {"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"} - ) + return json.dumps({"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}) elif "paris" in location.lower(): - return json.dumps( - {"location": "Paris", "temperature": "22", "unit": "celsius"} - ) + return json.dumps({"location": "Paris", "temperature": "22", "unit": "celsius"}) else: return json.dumps({"location": location, "temperature": "unknown"}) @pytest.fixture def generator(self, model_path, capsys): - gguf_model_path = "https://huggingface.co/meetkai/functionary-small-v2.4-GGUF/resolve/main/functionary-small-v2.4.Q4_0.gguf" + gguf_model_path = ( + "https://huggingface.co/meetkai/functionary-small-v2.4-GGUF/resolve/main/functionary-small-v2.4.Q4_0.gguf" + ) filename = 
"functionary-small-v2.4.Q4_0.gguf" download_file(gguf_model_path, str(model_path / filename), capsys) model_path = str(model_path / filename) @@ -423,10 +361,7 @@ def test_function_call(self, generator): "parameters": { "type": "object", "properties": { - "username": { - "type": "string", - "description": "The username to retrieve information for.", - } + "username": {"type": "string", "description": "The username to retrieve information for."} }, "required": ["username"], }, @@ -439,10 +374,7 @@ def test_function_call(self, generator): messages = [ ChatMessage.from_user("Get information for user john_doe"), ] - response = generator.run( - messages=messages, - generation_kwargs={"tools": tools, "tool_choice": tool_choice}, - ) + response = generator.run(messages=messages, generation_kwargs={"tools": tools, "tool_choice": tool_choice}) assert "tool_calls" in response["replies"][0].meta tool_calls = response["replies"][0].meta["tool_calls"] @@ -466,10 +398,7 @@ def test_function_call_and_execute(self, generator): "type": "string", "description": "The city and state, e.g. San Francisco, CA", }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, }, "required": ["location"], }, @@ -495,24 +424,23 @@ def test_function_call_and_execute(self, generator): function_args = json.loads(tool_call["function"]["arguments"]) assert function_name in available_functions function_response = available_functions[function_name](**function_args) - function_message = ChatMessage.from_function( - function_response, function_name - ) + function_message = ChatMessage.from_function(function_response, function_name) messages.append(function_message) second_response = generator.run(messages=messages) assert "replies" in second_response assert len(second_response["replies"]) > 0 - assert any( - "San Francisco" in reply.content for reply in second_response["replies"] - ) + assert any("San Francisco" in reply.content for reply in second_response["replies"]) assert any("72" in reply.content for reply in second_response["replies"]) class TestLlamaCppChatGeneratorChatML: + @pytest.fixture def generator(self, model_path, capsys): - gguf_model_path = "https://huggingface.co/TheBloke/openchat-3.5-1210-GGUF/resolve/main/openchat-3.5-1210.Q3_K_S.gguf" + gguf_model_path = ( + "https://huggingface.co/TheBloke/openchat-3.5-1210-GGUF/resolve/main/openchat-3.5-1210.Q3_K_S.gguf" + ) filename = "openchat-3.5-1210.Q3_K_S.gguf" download_file(gguf_model_path, str(model_path / filename), capsys) model_path = str(model_path / filename) @@ -558,10 +486,7 @@ def test_function_call_chatml(self, generator): tool_choice = {"type": "function", "function": {"name": "UserDetail"}} - response = generator.run( - messages=messages, - generation_kwargs={"tools": tools, "tool_choice": tool_choice}, - ) + response = generator.run(messages=messages, generation_kwargs={"tools": tools, "tool_choice": tool_choice}) for reply in response["replies"]: assert "tool_calls" in reply.meta tool_calls = reply.meta["tool_calls"] diff --git a/integrations/llama_cpp/tests/test_generator.py b/integrations/llama_cpp/tests/test_generator.py index 9249dd6a8..04b8339e5 100644 --- a/integrations/llama_cpp/tests/test_generator.py +++ b/integrations/llama_cpp/tests/test_generator.py @@ -31,7 +31,9 @@ def download_file(file_link, filename, capsys): class TestLlamaCppGenerator: @pytest.fixture def generator(self, model_path, capsys): - ggml_model_path = 
"https://huggingface.co/TheBloke/openchat-3.5-1210-GGUF/resolve/main/openchat-3.5-1210.Q3_K_S.gguf" + ggml_model_path = ( + "https://huggingface.co/TheBloke/openchat-3.5-1210-GGUF/resolve/main/openchat-3.5-1210.Q3_K_S.gguf" + ) filename = "openchat-3.5-1210.Q3_K_S.gguf" # Download GGUF model from HuggingFace @@ -58,11 +60,7 @@ def test_default_init(self): assert generator.model_path == "test_model.gguf" assert generator.n_ctx == 0 assert generator.n_batch == 512 - assert generator.model_kwargs == { - "model_path": "test_model.gguf", - "n_ctx": 0, - "n_batch": 512, - } + assert generator.model_kwargs == {"model_path": "test_model.gguf", "n_ctx": 0, "n_batch": 512} assert generator.generation_kwargs == {} def test_custom_init(self): @@ -78,11 +76,7 @@ def test_custom_init(self): assert generator.model_path == "test_model.gguf" assert generator.n_ctx == 2048 assert generator.n_batch == 512 - assert generator.model_kwargs == { - "model_path": "test_model.gguf", - "n_ctx": 2048, - "n_batch": 512, - } + assert generator.model_kwargs == {"model_path": "test_model.gguf", "n_ctx": 2048, "n_batch": 512} assert generator.generation_kwargs == {} def test_ignores_model_path_if_specified_in_model_kwargs(self): @@ -101,24 +95,14 @@ def test_ignores_n_ctx_if_specified_in_model_kwargs(self): """ Test that n_ctx is ignored if already specified in model_kwargs. """ - generator = LlamaCppGenerator( - model="test_model.gguf", - n_ctx=512, - n_batch=512, - model_kwargs={"n_ctx": 1024}, - ) + generator = LlamaCppGenerator(model="test_model.gguf", n_ctx=512, n_batch=512, model_kwargs={"n_ctx": 1024}) assert generator.model_kwargs["n_ctx"] == 1024 def test_ignores_n_batch_if_specified_in_model_kwargs(self): """ Test that n_batch is ignored if already specified in model_kwargs. 
""" - generator = LlamaCppGenerator( - model="test_model.gguf", - n_ctx=512, - n_batch=512, - model_kwargs={"n_batch": 1024}, - ) + generator = LlamaCppGenerator(model="test_model.gguf", n_ctx=512, n_batch=512, model_kwargs={"n_batch": 1024}) assert generator.model_kwargs["n_batch"] == 1024 def test_raises_error_without_warm_up(self): @@ -201,14 +185,9 @@ def test_run_rag_pipeline(self, generator): """ rag_pipeline = Pipeline() rag_pipeline.add_component( - instance=InMemoryBM25Retriever( - document_store=InMemoryDocumentStore(), top_k=1 - ), - name="retriever", - ) - rag_pipeline.add_component( - instance=PromptBuilder(template=prompt_template), name="prompt_builder" + instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore(), top_k=1), name="retriever" ) + rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder") rag_pipeline.add_component(instance=generator, name="llm") rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder") rag_pipeline.connect("retriever", "prompt_builder.documents") @@ -222,9 +201,7 @@ def test_run_rag_pipeline(self, generator): Document(content="The capital of Canada is Ottawa."), Document(content="The capital of Ghana is Accra."), ] - rag_pipeline.get_component("retriever").document_store.write_documents( - documents - ) + rag_pipeline.get_component("retriever").document_store.write_documents(documents) # Query and assert questions_and_answers = [ diff --git a/integrations/mistral/examples/indexing_pipeline.py b/integrations/mistral/examples/indexing_pipeline.py index 5b99ba7d6..0329fab8c 100644 --- a/integrations/mistral/examples/indexing_pipeline.py +++ b/integrations/mistral/examples/indexing_pipeline.py @@ -7,9 +7,7 @@ from haystack.components.preprocessors import DocumentSplitter from haystack.components.writers import DocumentWriter from haystack.document_stores.in_memory import InMemoryDocumentStore -from haystack_integrations.components.embedders.mistral.document_embedder import ( - MistralDocumentEmbedder, -) +from haystack_integrations.components.embedders.mistral.document_embedder import MistralDocumentEmbedder document_store = InMemoryDocumentStore() fetcher = LinkContentFetcher() diff --git a/integrations/mistral/examples/streaming_chat_with_rag.py b/integrations/mistral/examples/streaming_chat_with_rag.py index c1c929606..6c7f015d8 100644 --- a/integrations/mistral/examples/streaming_chat_with_rag.py +++ b/integrations/mistral/examples/streaming_chat_with_rag.py @@ -11,12 +11,8 @@ from haystack.components.writers import DocumentWriter from haystack.dataclasses import ChatMessage from haystack.document_stores.in_memory import InMemoryDocumentStore -from haystack_integrations.components.embedders.mistral.document_embedder import ( - MistralDocumentEmbedder, -) -from haystack_integrations.components.embedders.mistral.text_embedder import ( - MistralTextEmbedder, -) +from haystack_integrations.components.embedders.mistral.document_embedder import MistralDocumentEmbedder +from haystack_integrations.components.embedders.mistral.text_embedder import MistralTextEmbedder from haystack_integrations.components.generators.mistral import MistralChatGenerator document_store = InMemoryDocumentStore() @@ -46,11 +42,7 @@ prompt_builder = ChatPromptBuilder(variables=["documents"]) llm = MistralChatGenerator(streaming_callback=print_streaming_chunk) -messages = [ - ChatMessage.from_user( - "Here are some the documents: {{documents}} \\n Answer: {{query}}" - ) -] +messages = 
[ChatMessage.from_user("Here are some of the documents: {{documents}} \\n Answer: {{query}}")]

 rag_pipeline = Pipeline()
 rag_pipeline.add_component("text_embedder", text_embedder)
@@ -68,10 +60,7 @@
 result = rag_pipeline.run(
     {
         "text_embedder": {"text": question},
-        "prompt_builder": {
-            "template_variables": {"query": question},
-            "template": messages,
-        },
+        "prompt_builder": {"template_variables": {"query": question}, "template": messages},
         "llm": {"generation_kwargs": {"max_tokens": 165}},
     }
 )
diff --git a/integrations/mistral/tests/test_mistral_chat_generator.py b/integrations/mistral/tests/test_mistral_chat_generator.py
index 0c926b25c..181397c00 100644
--- a/integrations/mistral/tests/test_mistral_chat_generator.py
+++ b/integrations/mistral/tests/test_mistral_chat_generator.py
@@ -7,9 +7,7 @@
 from haystack.components.generators.utils import print_streaming_chunk
 from haystack.dataclasses import ChatMessage, StreamingChunk
 from haystack.utils.auth import Secret
-from haystack_integrations.components.generators.mistral.chat.chat_generator import (
-    MistralChatGenerator,
-)
+from haystack_integrations.components.generators.mistral.chat.chat_generator import MistralChatGenerator
 from openai import OpenAIError
 from openai.types.chat import ChatCompletion, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
@@ -28,9 +26,7 @@ def mock_chat_completion():
     """
     Mock the OpenAI API completion response and reuse it for tests
     """
-    with patch(
-        "openai.resources.chat.completions.Completions.create"
-    ) as mock_chat_completion_create:
+    with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create:
         completion = ChatCompletion(
             id="foo",
             model="mistral-tiny",
@@ -40,9 +36,7 @@ def mock_chat_completion():
                     finish_reason="stop",
                     logprobs=None,
                     index=0,
-                    message=ChatCompletionMessage(
-                        content="Hello world!", role="assistant"
-                    ),
+                    message=ChatCompletionMessage(content="Hello world!", role="assistant"),
                 )
             ],
             created=int(datetime.now(tz=pytz.timezone("UTC")).timestamp()),
@@ -65,9 +59,7 @@ def test_init_default(self, monkeypatch):

     def test_init_fail_wo_api_key(self, monkeypatch):
         monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
-        with pytest.raises(
-            ValueError, match="None of the .* environment variables are set"
-        ):
+        with pytest.raises(ValueError, match="None of the .* environment variables are set"):
             MistralChatGenerator()

     def test_init_with_parameters(self):
@@ -81,10 +73,7 @@ def test_init_with_parameters(self):
         assert component.client.api_key == "test-api-key"
         assert component.model == "mistral-small"
         assert component.streaming_callback is print_streaming_chunk
-        assert component.generation_kwargs == {
-            "max_tokens": 10,
-            "some_test_param": "test-params",
-        }
+        assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

     def test_to_dict_default(self, monkeypatch):
         monkeypatch.setenv("MISTRAL_API_KEY", "test-api-key")
@@ -93,11 +82,7 @@ def test_to_dict_default(self, monkeypatch):
         assert data == {
             "type": "haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator",
             "init_parameters": {
-                "api_key": {
-                    "env_vars": ["MISTRAL_API_KEY"],
-                    "strict": True,
-                    "type": "env_var",
-                },
+                "api_key": {"env_vars": ["MISTRAL_API_KEY"], "strict": True, "type": "env_var"},
                 "model": "mistral-tiny",
                 "organization": None,
                 "streaming_callback": None,
@@ -124,10 +109,7 @@ def test_to_dict_with_parameters(self, monkeypatch):
                 "api_base_url": "test-base-url",
                 "organization": None,
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "generation_kwargs": { - "max_tokens": 10, - "some_test_param": "test-params", - }, + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, } @@ -136,28 +118,18 @@ def test_from_dict(self, monkeypatch): data = { "type": "haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator", "init_parameters": { - "api_key": { - "env_vars": ["MISTRAL_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["MISTRAL_API_KEY"], "strict": True, "type": "env_var"}, "model": "mistral-small", "api_base_url": "test-base-url", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "generation_kwargs": { - "max_tokens": 10, - "some_test_param": "test-params", - }, + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, } component = MistralChatGenerator.from_dict(data) assert component.model == "mistral-small" assert component.streaming_callback is print_streaming_chunk assert component.api_base_url == "test-base-url" - assert component.generation_kwargs == { - "max_tokens": 10, - "some_test_param": "test-params", - } + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} assert component.api_key == Secret.from_env_var("MISTRAL_API_KEY") def test_from_dict_fail_wo_env_var(self, monkeypatch): @@ -165,28 +137,17 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): data = { "type": "haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator", "init_parameters": { - "api_key": { - "env_vars": ["MISTRAL_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["MISTRAL_API_KEY"], "strict": True, "type": "env_var"}, "model": "mistral-small", "api_base_url": "test-base-url", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "generation_kwargs": { - "max_tokens": 10, - "some_test_param": "test-params", - }, + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, } - with pytest.raises( - ValueError, match="None of the .* environment variables are set" - ): + with pytest.raises(ValueError, match="None of the .* environment variables are set"): MistralChatGenerator.from_dict(data) - def test_run( - self, chat_messages, mock_chat_completion, monkeypatch - ): # noqa: ARG002 + def test_run(self, chat_messages, mock_chat_completion, monkeypatch): # noqa: ARG002 monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key") component = MistralChatGenerator() response = component.run(chat_messages) @@ -200,9 +161,7 @@ def test_run( def test_run_with_params(self, chat_messages, mock_chat_completion, monkeypatch): monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key") - component = MistralChatGenerator( - generation_kwargs={"max_tokens": 10, "temperature": 0.5} - ) + component = MistralChatGenerator(generation_kwargs={"max_tokens": 10, "temperature": 0.5}) response = component.run(chat_messages) # check that the component calls the OpenAI API with the correct parameters @@ -221,11 +180,7 @@ def test_check_abnormal_completions(self, caplog): component = MistralChatGenerator(api_key=Secret.from_token("test-api-key")) messages = [ ChatMessage.from_assistant( - "", - meta={ - "finish_reason": "content_filter" if i % 2 == 0 else "length", - "index": i, - }, + "", meta={"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i} ) for i, _ in 
enumerate(range(4)) ] @@ -289,9 +244,7 @@ def __call__(self, chunk: StreamingChunk) -> None: callback = Callback() component = MistralChatGenerator(streaming_callback=callback) - results = component.run( - [ChatMessage.from_user("What's the capital of France?")] - ) + results = component.run([ChatMessage.from_user("What's the capital of France?")]) assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] diff --git a/integrations/mistral/tests/test_mistral_document_embedder.py b/integrations/mistral/tests/test_mistral_document_embedder.py index 678f84f42..6e5c11759 100644 --- a/integrations/mistral/tests/test_mistral_document_embedder.py +++ b/integrations/mistral/tests/test_mistral_document_embedder.py @@ -6,9 +6,7 @@ import pytest from haystack import Document from haystack.utils import Secret -from haystack_integrations.components.embedders.mistral.document_embedder import ( - MistralDocumentEmbedder, -) +from haystack_integrations.components.embedders.mistral.document_embedder import MistralDocumentEmbedder pytestmark = pytest.mark.embedders @@ -58,11 +56,7 @@ def test_to_dict(self, monkeypatch): assert component_dict == { "type": "haystack_integrations.components.embedders.mistral.document_embedder.MistralDocumentEmbedder", "init_parameters": { - "api_key": { - "env_vars": ["MISTRAL_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["MISTRAL_API_KEY"], "strict": True, "type": "env_var"}, "model": "mistral-embed", "api_base_url": "https://api.mistral.ai/v1", "dimensions": None, @@ -93,11 +87,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): assert component_dict == { "type": "haystack_integrations.components.embedders.mistral.document_embedder.MistralDocumentEmbedder", "init_parameters": { - "api_key": { - "env_vars": ["ENV_VAR"], - "strict": False, - "type": "env_var", - }, + "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "model": "mistral-embed-v2", "dimensions": None, "api_base_url": "https://custom-api-base-url.com", @@ -121,10 +111,7 @@ def test_run(self): docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), - Document( - content="A transformer is a deep learning architecture", - meta={"topic": "ML"}, - ), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), ] result = embedder.run(docs) diff --git a/integrations/mistral/tests/test_mistral_text_embedder.py b/integrations/mistral/tests/test_mistral_text_embedder.py index e194e1de7..af004b022 100644 --- a/integrations/mistral/tests/test_mistral_text_embedder.py +++ b/integrations/mistral/tests/test_mistral_text_embedder.py @@ -5,9 +5,7 @@ import pytest from haystack.utils import Secret -from haystack_integrations.components.embedders.mistral.text_embedder import ( - MistralTextEmbedder, -) +from haystack_integrations.components.embedders.mistral.text_embedder import MistralTextEmbedder pytestmark = pytest.mark.embedders @@ -44,11 +42,7 @@ def test_to_dict(self, monkeypatch): assert component_dict == { "type": "haystack_integrations.components.embedders.mistral.text_embedder.MistralTextEmbedder", "init_parameters": { - "api_key": { - "env_vars": ["MISTRAL_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["MISTRAL_API_KEY"], "strict": True, "type": "env_var"}, "model": "mistral-embed", "api_base_url": "https://api.mistral.ai/v1", "dimensions": None, @@ -71,11 +65,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): assert 
component_dict == { "type": "haystack_integrations.components.embedders.mistral.text_embedder.MistralTextEmbedder", "init_parameters": { - "api_key": { - "env_vars": ["ENV_VAR"], - "strict": False, - "type": "env_var", - }, + "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "model": "mistral-embed-v2", "api_base_url": "https://custom-api-base-url.com", "dimensions": None, diff --git a/integrations/mongodb_atlas/examples/example.py b/integrations/mongodb_atlas/examples/example.py index 5ff7a6bbd..4cd3edc21 100644 --- a/integrations/mongodb_atlas/examples/example.py +++ b/integrations/mongodb_atlas/examples/example.py @@ -10,18 +10,11 @@ from haystack import Pipeline from haystack.components.converters import MarkdownToDocument -from haystack.components.embedders import ( - SentenceTransformersDocumentEmbedder, - SentenceTransformersTextEmbedder, -) +from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder from haystack.components.preprocessors import DocumentSplitter from haystack.components.writers import DocumentWriter -from haystack_integrations.components.retrievers.mongodb_atlas import ( - MongoDBAtlasEmbeddingRetriever, -) -from haystack_integrations.document_stores.mongodb_atlas import ( - MongoDBAtlasDocumentStore, -) +from haystack_integrations.components.retrievers.mongodb_atlas import MongoDBAtlasEmbeddingRetriever +from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore # To use the MongoDBAtlasDocumentStore, you must have a running MongoDB Atlas database. # For details, see https://www.mongodb.com/docs/atlas/getting-started/ @@ -43,9 +36,7 @@ indexing = Pipeline() indexing.add_component("converter", MarkdownToDocument()) -indexing.add_component( - "splitter", DocumentSplitter(split_by="sentence", split_length=2) -) +indexing.add_component("splitter", DocumentSplitter(split_by="sentence", split_length=2)) indexing.add_component("embedder", SentenceTransformersDocumentEmbedder()) indexing.add_component("writer", DocumentWriter(document_store)) indexing.connect("converter", "splitter") @@ -58,9 +49,7 @@ # Create the querying Pipeline and try a query querying = Pipeline() querying.add_component("embedder", SentenceTransformersTextEmbedder()) -querying.add_component( - "retriever", MongoDBAtlasEmbeddingRetriever(document_store=document_store, top_k=3) -) +querying.add_component("retriever", MongoDBAtlasEmbeddingRetriever(document_store=document_store, top_k=3)) querying.connect("embedder", "retriever") results = querying.run({"embedder": {"text": "What is a cross-encoder?"}}) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py index 6fa2a4420..fed0a4c28 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py @@ -1,5 +1,3 @@ -from haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever import ( - MongoDBAtlasEmbeddingRetriever, -) +from haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever import MongoDBAtlasEmbeddingRetriever __all__ = ["MongoDBAtlasEmbeddingRetriever"] diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py 
b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py index 3ffe925e8..91a42e135 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -7,9 +7,7 @@ from haystack.dataclasses import Document from haystack.document_stores.types import FilterPolicy from haystack.document_stores.types.filter_policy import apply_filter_policy -from haystack_integrations.document_stores.mongodb_atlas import ( - MongoDBAtlasDocumentStore, -) +from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore @component @@ -69,9 +67,7 @@ def __init__( self.filters = filters or {} self.top_k = top_k self.filter_policy = ( - filter_policy - if isinstance(filter_policy, FilterPolicy) - else FilterPolicy.from_str(filter_policy) + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) ) def to_dict(self) -> Dict[str, Any]: @@ -105,9 +101,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasEmbeddingRetriever": # Pipelines serialized with old versions of the component might not # have the filter_policy field. if filter_policy := data["init_parameters"].get("filter_policy"): - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str( - filter_policy - ) + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 0d5d351b1..93eb87005 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -10,9 +10,7 @@ from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy from haystack.utils import Secret, deserialize_secrets_inplace -from haystack_integrations.document_stores.mongodb_atlas.filters import ( - _normalize_filters, -) +from haystack_integrations.document_stores.mongodb_atlas.filters import _normalize_filters from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne from pymongo.collection import Collection from pymongo.driver_info import DriverInfo @@ -58,9 +56,7 @@ class MongoDBAtlasDocumentStore: def __init__( self, *, - mongo_connection_string: Secret = Secret.from_env_var( - "MONGO_CONNECTION_STRING" - ), # noqa: B008 + mongo_connection_string: Secret = Secret.from_env_var("MONGO_CONNECTION_STRING"), # noqa: B008 database_name: str, collection_name: str, vector_search_index: str, @@ -82,9 +78,7 @@ def __init__( :raises ValueError: If the collection name contains invalid characters. """ - if collection_name and not bool( - re.match(r"^[a-zA-Z0-9\-_]+$", collection_name) - ): + if collection_name and not bool(re.match(r"^[a-zA-Z0-9\-_]+$", collection_name)): msg = f'Invalid collection name: "{collection_name}". It can only contain letters, numbers, -, or _.' 
            raise ValueError(msg)

@@ -100,8 +94,7 @@ def __init__(
     def connection(self) -> MongoClient:
         if self._connection is None:
             self._connection = MongoClient(
-                self.mongo_connection_string.resolve_value(),
-                driver=DriverInfo(name="MongoDBAtlasHaystackIntegration"),
+                self.mongo_connection_string.resolve_value(), driver=DriverInfo(name="MongoDBAtlasHaystackIntegration")
             )
         return self._connection
@@ -142,9 +135,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasDocumentStore":
         :returns:
             Deserialized component.
         """
-        deserialize_secrets_inplace(
-            data["init_parameters"], keys=["mongo_connection_string"]
-        )
+        deserialize_secrets_inplace(data["init_parameters"], keys=["mongo_connection_string"])
         return default_from_dict(cls, data)

     def count_documents(self) -> int:
@@ -155,9 +146,7 @@ def count_documents(self) -> int:
         """
         return self.collection.count_documents({})

-    def filter_documents(
-        self, filters: Optional[Dict[str, Any]] = None
-    ) -> List[Document]:
+    def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
         """
         Returns the documents that match the filters provided.
@@ -170,14 +159,10 @@ def filter_documents(
         filters = _normalize_filters(filters) if filters else None
         documents = list(self.collection.find(filters))
         for doc in documents:
-            doc.pop(
-                "_id", None
-            )  # MongoDB's internal id doesn't belong into a Haystack document, so we remove it.
+            doc.pop("_id", None)  # MongoDB's internal id doesn't belong in a Haystack document, so we remove it.
         return [Document.from_dict(doc) for doc in documents]

-    def write_documents(
-        self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE
-    ) -> int:
+    def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int:
         """
         Writes documents into the MongoDB Atlas collection.
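[Editor's note — illustrative sketch, not part of the patch: how the three DuplicatePolicy branches handled by the write_documents() hunks around this point behave. Assumes MONGO_CONNECTION_STRING is exported; the database, collection, and index names are placeholders.]

# Sketch only, under the assumptions stated above.
from haystack.dataclasses import Document
from haystack.document_stores.types import DuplicatePolicy
from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore

store = MongoDBAtlasDocumentStore(
    database_name="haystack_integration_test",
    collection_name="test_collection",
    vector_search_index="cosine_index",
)
doc = Document(content="I love cheese")
store.write_documents([doc], policy=DuplicatePolicy.FAIL)       # plain InsertOne; duplicate ids raise an error
n = store.write_documents([doc], policy=DuplicatePolicy.SKIP)   # $setOnInsert upsert; existing ids excluded from n
store.write_documents([doc], policy=DuplicatePolicy.OVERWRITE)  # ReplaceOne upsert; existing documents replaced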
@@ -191,9 +176,7 @@ def write_documents( if len(documents) > 0: if not isinstance(documents[0], Document): - msg = ( - "param 'documents' must contain a list of objects of type Document" - ) + msg = "param 'documents' must contain a list of objects of type Document" raise ValueError(msg) if policy == DuplicatePolicy.NONE: @@ -216,21 +199,13 @@ def write_documents( written_docs = len(documents) if policy == DuplicatePolicy.SKIP: - operations = [ - UpdateOne({"id": doc["id"]}, {"$setOnInsert": doc}, upsert=True) - for doc in mongo_documents - ] - existing_documents = self.collection.count_documents( - {"id": {"$in": [doc.id for doc in documents]}} - ) + operations = [UpdateOne({"id": doc["id"]}, {"$setOnInsert": doc}, upsert=True) for doc in mongo_documents] + existing_documents = self.collection.count_documents({"id": {"$in": [doc.id for doc in documents]}}) written_docs -= existing_documents elif policy == DuplicatePolicy.FAIL: operations = [InsertOne(doc) for doc in mongo_documents] else: - operations = [ - ReplaceOne({"id": doc["id"]}, upsert=True, replacement=doc) - for doc in mongo_documents - ] + operations = [ReplaceOne({"id": doc["id"]}, upsert=True, replacement=doc) for doc in mongo_documents] try: self.collection.bulk_write(operations) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py index bf2ce4582..4583d6cd3 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py @@ -126,9 +126,7 @@ def _less_than_equal(field: str, value: Any) -> Dict[str, Any]: def _not_in(field: str, value: Any) -> Dict[str, Any]: if not isinstance(value, list): - msg = ( - f"{field}'s value must be a list when using 'not in' comparator in Pinecone" - ) + msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone" raise FilterError(msg) return {field: {"$nin": value}} diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py index 2a9493589..453d9d16c 100644 --- a/integrations/mongodb_atlas/tests/test_document_store.py +++ b/integrations/mongodb_atlas/tests/test_document_store.py @@ -11,9 +11,7 @@ from haystack.document_stores.types import DuplicatePolicy from haystack.testing.document_store import DocumentStoreBaseTests from haystack.utils import Secret -from haystack_integrations.document_stores.mongodb_atlas import ( - MongoDBAtlasDocumentStore, -) +from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore from pandas import DataFrame from pymongo import MongoClient from pymongo.driver_info import DriverInfo @@ -42,8 +40,7 @@ def document_store(self): collection_name = "test_collection_" + str(uuid4()) connection: MongoClient = MongoClient( - os.environ["MONGO_CONNECTION_STRING"], - driver=DriverInfo(name="MongoDBAtlasHaystackIntegration"), + os.environ["MONGO_CONNECTION_STRING"], driver=DriverInfo(name="MongoDBAtlasHaystackIntegration") ) database = connection[database_name] if collection_name in database.list_collection_names(): @@ -66,9 +63,7 @@ def test_write_documents(self, document_store: MongoDBAtlasDocumentStore): document_store.write_documents(docs, DuplicatePolicy.FAIL) def test_write_blob(self, document_store: MongoDBAtlasDocumentStore): - bytestream = ByteStream( - b"test", 
meta={"meta_key": "meta_value"}, mime_type="mime_type" - ) + bytestream = ByteStream(b"test", meta={"meta_key": "meta_value"}, mime_type="mime_type") docs = [Document(blob=bytestream)] document_store.write_documents(docs) retrieved_docs = document_store.filter_documents() @@ -83,11 +78,7 @@ def test_write_dataframe(self, document_store: MongoDBAtlasDocumentStore): def test_to_dict(self, document_store): serialized_store = document_store.to_dict() - assert ( - serialized_store["init_parameters"] - .pop("collection_name") - .startswith("test_collection_") - ) + assert serialized_store["init_parameters"].pop("collection_name").startswith("test_collection_") assert serialized_store == { "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", "init_parameters": { @@ -121,9 +112,7 @@ def test_from_dict(self): }, } ) - assert docstore.mongo_connection_string == Secret.from_env_var( - "MONGO_CONNECTION_STRING" - ) + assert docstore.mongo_connection_string == Secret.from_env_var("MONGO_CONNECTION_STRING") assert docstore.database_name == "haystack_integration_test" assert docstore.collection_name == "test_embeddings_collection" assert docstore.vector_search_index == "cosine_index" @@ -144,11 +133,7 @@ def test_complex_filter(self, document_store, filterable_docs): "operator": "AND", "conditions": [ {"field": "meta.page", "operator": "==", "value": "90"}, - { - "field": "meta.chapter", - "operator": "==", - "value": "conclusion", - }, + {"field": "meta.chapter", "operator": "==", "value": "conclusion"}, ], }, ], @@ -162,8 +147,6 @@ def test_complex_filter(self, document_store, filterable_docs): d for d in filterable_docs if (d.meta.get("number") == 100 and d.meta.get("chapter") == "intro") - or ( - d.meta.get("page") == "90" and d.meta.get("chapter") == "conclusion" - ) + or (d.meta.get("page") == "90" and d.meta.get("chapter") == "conclusion") ], ) diff --git a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py index a92573763..a03c735e0 100644 --- a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py +++ b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py @@ -6,9 +6,7 @@ import pytest from haystack.document_stores.errors import DocumentStoreError -from haystack_integrations.document_stores.mongodb_atlas import ( - MongoDBAtlasDocumentStore, -) +from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore @pytest.mark.skipif( @@ -24,9 +22,7 @@ def test_embedding_retrieval_cosine_similarity(self): vector_search_index="cosine_index", ) query_embedding = [0.1] * 768 - results = document_store._embedding_retrieval( - query_embedding=query_embedding, top_k=2, filters={} - ) + results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={}) assert len(results) == 2 assert results[0].content == "Document A" assert results[1].content == "Document B" @@ -39,9 +35,7 @@ def test_embedding_retrieval_dot_product(self): vector_search_index="dotProduct_index", ) query_embedding = [0.1] * 768 - results = document_store._embedding_retrieval( - query_embedding=query_embedding, top_k=2, filters={} - ) + results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={}) assert len(results) == 2 assert results[0].content == "Document A" assert results[1].content == "Document B" @@ -54,9 +48,7 @@ def test_embedding_retrieval_euclidean(self): vector_search_index="euclidean_index", ) 
query_embedding = [0.1] * 768 - results = document_store._embedding_retrieval( - query_embedding=query_embedding, top_k=2, filters={} - ) + results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={}) assert len(results) == 2 assert results[0].content == "Document C" assert results[1].content == "Document B" @@ -108,9 +100,7 @@ def test_embedding_retrieval_with_filters(self): ) query_embedding = [0.1] * 768 filters = {"field": "content", "operator": "!=", "value": "Document A"} - results = document_store._embedding_retrieval( - query_embedding=query_embedding, top_k=2, filters=filters - ) + results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters=filters) assert len(results) == 2 for doc in results: assert doc.content != "Document A" diff --git a/integrations/mongodb_atlas/tests/test_retriever.py b/integrations/mongodb_atlas/tests/test_retriever.py index 2dfe058bb..56eec928f 100644 --- a/integrations/mongodb_atlas/tests/test_retriever.py +++ b/integrations/mongodb_atlas/tests/test_retriever.py @@ -7,15 +7,12 @@ from haystack.dataclasses import Document from haystack.document_stores.types import FilterPolicy from haystack.utils.auth import EnvVarSecret -from haystack_integrations.components.retrievers.mongodb_atlas import ( - MongoDBAtlasEmbeddingRetriever, -) -from haystack_integrations.document_stores.mongodb_atlas import ( - MongoDBAtlasDocumentStore, -) +from haystack_integrations.components.retrievers.mongodb_atlas import MongoDBAtlasEmbeddingRetriever +from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore class TestRetriever: + @pytest.fixture def mock_client(self): with patch( @@ -23,9 +20,7 @@ def mock_client(self): ) as mock_mongo_client: mock_connection = MagicMock() mock_database = MagicMock() - mock_collection_names = MagicMock( - return_value=["test_embeddings_collection"] - ) + mock_collection_names = MagicMock(return_value=["test_embeddings_collection"]) mock_database.list_collection_names = mock_collection_names mock_connection.__getitem__.return_value = mock_database mock_mongo_client.return_value = mock_connection @@ -39,15 +34,11 @@ def test_init_default(self): assert retriever.top_k == 10 assert retriever.filter_policy == FilterPolicy.REPLACE - retriever = MongoDBAtlasEmbeddingRetriever( - document_store=mock_store, filter_policy="merge" - ) + retriever = MongoDBAtlasEmbeddingRetriever(document_store=mock_store, filter_policy="merge") assert retriever.filter_policy == FilterPolicy.MERGE with pytest.raises(ValueError): - MongoDBAtlasEmbeddingRetriever( - document_store=mock_store, filter_policy="wrong_policy" - ) + MongoDBAtlasEmbeddingRetriever(document_store=mock_store, filter_policy="wrong_policy") def test_init(self): mock_store = Mock(spec=MongoDBAtlasDocumentStore) @@ -74,9 +65,7 @@ def test_init_filter_policy_merge(self): assert retriever.top_k == 5 assert retriever.filter_policy == FilterPolicy.MERGE - def test_to_dict( - self, mock_client, monkeypatch - ): # noqa: ARG002 mock_client appears unused but is required + def test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client appears unused but is required monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str") document_store = MongoDBAtlasDocumentStore( @@ -85,9 +74,7 @@ def test_to_dict( vector_search_index="cosine_index", ) - retriever = MongoDBAtlasEmbeddingRetriever( - document_store=document_store, filters={"field": "value"}, top_k=5 - ) + retriever = 
MongoDBAtlasEmbeddingRetriever(document_store=document_store, filters={"field": "value"}, top_k=5) res = retriever.to_dict() assert res == { "type": "haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever.MongoDBAtlasEmbeddingRetriever", # noqa: E501 @@ -111,9 +98,7 @@ def test_to_dict( }, } - def test_from_dict( - self, mock_client, monkeypatch - ): # noqa: ARG002 mock_client appears unused but is required + def test_from_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client appears unused but is required monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str") data = { @@ -150,9 +135,7 @@ def test_from_dict( assert retriever.top_k == 5 assert retriever.filter_policy == FilterPolicy.REPLACE - def test_from_dict_no_filter_policy( - self, monkeypatch - ): # mock_client appears unused but is required + def test_from_dict_no_filter_policy(self, monkeypatch): # mock_client appears unused but is required monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str") data = { @@ -196,9 +179,7 @@ def test_run(self): retriever = MongoDBAtlasEmbeddingRetriever(document_store=mock_store) res = retriever.run(query_embedding=[0.3, 0.5]) - mock_store._embedding_retrieval.assert_called_once_with( - query_embedding=[0.3, 0.5], filters={}, top_k=10 - ) + mock_store._embedding_retrieval.assert_called_once_with(query_embedding=[0.3, 0.5], filters={}, top_k=10) assert res == {"documents": [doc]} @@ -208,16 +189,12 @@ def test_run_merge_policy_filter(self): mock_store._embedding_retrieval.return_value = [doc] retriever = MongoDBAtlasEmbeddingRetriever( - document_store=mock_store, - filters={"foo": "boo"}, - filter_policy=FilterPolicy.MERGE, + document_store=mock_store, filters={"foo": "boo"}, filter_policy=FilterPolicy.MERGE ) res = retriever.run(query_embedding=[0.3, 0.5], filters={"field": "value"}) mock_store._embedding_retrieval.assert_called_once_with( - query_embedding=[0.3, 0.5], - filters={"field": "value", "foo": "boo"}, - top_k=10, + query_embedding=[0.3, 0.5], filters={"field": "value", "foo": "boo"}, top_k=10 ) assert res == {"documents": [doc]} diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/_nim_backend.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/_nim_backend.py index 3ab0c4649..ee25df7fd 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/_nim_backend.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/_nim_backend.py @@ -47,8 +47,6 @@ def embed(self, texts: List[str]) -> Tuple[List[List[float]], Dict[str, Any]]: data = res.json() # Sort the embeddings by index, we don't know whether they're out of order or not - embeddings = [ - e["embedding"] for e in sorted(data["data"], key=lambda e: e["index"]) - ] + embeddings = [e["embedding"] for e in sorted(data["data"], key=lambda e: e["index"])] return embeddings, {"usage": data["usage"]} diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py index 4c69d9e0e..4cc805c01 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py @@ -147,24 +147,16 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: texts_to_embed = [] for doc in documents: 
meta_values_to_embed = [ - str(doc.meta[key]) - for key in self.meta_fields_to_embed - if key in doc.meta and doc.meta[key] is not None + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None ] text_to_embed = ( - self.prefix - + self.embedding_separator.join( - [*meta_values_to_embed, doc.content or ""] - ) - + self.suffix + self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix ) texts_to_embed.append(text_to_embed) return texts_to_embed - def _embed_batch( - self, texts_to_embed: List[str], batch_size: int - ) -> Tuple[List[List[float]], Dict[str, Any]]: + def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]: all_embeddings: List[List[float]] = [] usage_prompt_tokens = 0 usage_total_tokens = 0 @@ -172,9 +164,7 @@ def _embed_batch( assert self.backend is not None for i in tqdm( - range(0, len(texts_to_embed), batch_size), - disable=not self.progress_bar, - desc="Calculating embeddings", + range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings" ): batch = texts_to_embed[i : i + batch_size] @@ -184,12 +174,7 @@ def _embed_batch( usage_prompt_tokens += meta.get("usage", {}).get("prompt_tokens", 0) usage_total_tokens += meta.get("usage", {}).get("total_tokens", 0) - return all_embeddings, { - "usage": { - "prompt_tokens": usage_prompt_tokens, - "total_tokens": usage_total_tokens, - } - } + return all_embeddings, {"usage": {"prompt_tokens": usage_prompt_tokens, "total_tokens": usage_total_tokens}} @component.output_types(documents=List[Document], meta=Dict[str, Any]) def run(self, documents: List[Document]): @@ -212,11 +197,7 @@ def run(self, documents: List[Document]): if not self._initialized: msg = "The embedding model has not been loaded. Please call warm_up() before running." raise RuntimeError(msg) - elif ( - not isinstance(documents, list) - or documents - and not isinstance(documents[0], Document) - ): + elif not isinstance(documents, list) or documents and not isinstance(documents[0], Document): msg = ( "NvidiaDocumentEmbedder expects a list of Documents as input." "In case you want to embed a string, please use the NvidiaTextEmbedder." 
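Note on the document_embedder.py hunks above: they only reflow _prepare_texts_to_embed; the text-assembly logic is unchanged. As a standalone illustration of that step, here is a minimal sketch using the same filtering and join, with hypothetical documents and illustrative prefix/suffix/separator values:

from haystack import Document

# Hypothetical sample input; "skipped" shows that None meta values are dropped.
documents = [Document(content="content 0", meta={"meta_field": "meta_value 0", "skipped": None})]
meta_fields_to_embed = ["meta_field", "skipped"]
prefix, suffix, embedding_separator = "", "", "\n"

texts_to_embed = []
for doc in documents:
    # Same condition as the hunk: the key must exist and its value must not be None.
    meta_values = [str(doc.meta[key]) for key in meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None]
    # Meta values are joined in front of the content, then wrapped with prefix/suffix.
    texts_to_embed.append(prefix + embedding_separator.join([*meta_values, doc.content or ""]) + suffix)

print(texts_to_embed)  # ['meta_value 0\ncontent 0']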
diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/_nim_backend.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/_nim_backend.py index 82b57bd6e..5253b3254 100644 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/_nim_backend.py +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/_nim_backend.py @@ -73,9 +73,7 @@ def generate(self, prompt: str) -> Tuple[List[str], List[Dict[str, Any]]]: if "finish_reason" in choice: choice_meta["finish_reason"] = choice["finish_reason"] if "completion_tokens" in completions["usage"]: - choice_meta["usage"]["completion_tokens"] = completions["usage"][ - "completion_tokens" - ] + choice_meta["usage"]["completion_tokens"] = completions["usage"]["completion_tokens"] meta.append(choice_meta) diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py index 65ef3d01f..6aea421dd 100644 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py @@ -64,9 +64,7 @@ def __init__( to know the supported arguments. """ self._model = model - self._api_url = url_validation( - api_url, _DEFAULT_API_URL, ["v1/chat/completions"] - ) + self._api_url = url_validation(api_url, _DEFAULT_API_URL, ["v1/chat/completions"]) self._api_key = api_key self._model_arguments = model_arguments or {} diff --git a/integrations/nvidia/tests/test_base_url.py b/integrations/nvidia/tests/test_base_url.py index 41b5cb7ab..072807685 100644 --- a/integrations/nvidia/tests/test_base_url.py +++ b/integrations/nvidia/tests/test_base_url.py @@ -1,8 +1,5 @@ import pytest -from haystack_integrations.components.embedders.nvidia import ( - NvidiaDocumentEmbedder, - NvidiaTextEmbedder, -) +from haystack_integrations.components.embedders.nvidia import NvidiaDocumentEmbedder, NvidiaTextEmbedder from haystack_integrations.components.generators.nvidia import NvidiaGenerator diff --git a/integrations/nvidia/tests/test_document_embedder.py b/integrations/nvidia/tests/test_document_embedder.py index 1205732e5..856ae4652 100644 --- a/integrations/nvidia/tests/test_document_embedder.py +++ b/integrations/nvidia/tests/test_document_embedder.py @@ -3,10 +3,7 @@ import pytest from haystack import Document from haystack.utils import Secret -from haystack_integrations.components.embedders.nvidia import ( - EmbeddingTruncateMode, - NvidiaDocumentEmbedder, -) +from haystack_integrations.components.embedders.nvidia import EmbeddingTruncateMode, NvidiaDocumentEmbedder from haystack_integrations.components.embedders.nvidia.backend import EmbedderBackend @@ -51,9 +48,7 @@ def test_init_with_parameters(self): assert embedder.api_key == Secret.from_token("fake-api-key") assert embedder.model == "nvolveqa_40k" - assert ( - embedder.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia/test" - ) + assert embedder.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia/test" assert embedder.prefix == "prefix" assert embedder.suffix == "suffix" assert embedder.batch_size == 30 @@ -74,11 +69,7 @@ def test_to_dict(self, monkeypatch): assert data == { "type": "haystack_integrations.components.embedders.nvidia.document_embedder.NvidiaDocumentEmbedder", "init_parameters": { - "api_key": { - "env_vars": ["NVIDIA_API_KEY"], - "strict": True, - "type": 
"env_var", - }, + "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, "api_url": "https://ai.api.nvidia.com/v1/retrieval/nvidia", "model": "playground_nvolveqa_40k", "prefix": "", @@ -108,11 +99,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): assert data == { "type": "haystack_integrations.components.embedders.nvidia.document_embedder.NvidiaDocumentEmbedder", "init_parameters": { - "api_key": { - "env_vars": ["NVIDIA_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, "api_url": "https://example.com/v1", "model": "playground_nvolveqa_40k", "prefix": "prefix", @@ -130,11 +117,7 @@ def from_dict(self, monkeypatch): data = { "type": "haystack_integrations.components.embedders.nvidia.document_embedder.NvidiaDocumentEmbedder", "init_parameters": { - "api_key": { - "env_vars": ["NVIDIA_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, "api_url": "https://example.com", "model": "playground_nvolveqa_40k", "prefix": "prefix", @@ -159,11 +142,7 @@ def from_dict(self, monkeypatch): def test_prepare_texts_to_embed_w_metadata(self): documents = [ - Document( - content=f"document number {i}:\ncontent", - meta={"meta_field": f"meta_value {i}"}, - ) - for i in range(5) + Document(content=f"document number {i}:\ncontent", meta={"meta_field": f"meta_value {i}"}) for i in range(5) ] embedder = NvidiaDocumentEmbedder( @@ -229,10 +208,7 @@ def test_embed_batch(self): def test_run(self): docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), - Document( - content="A transformer is a deep learning architecture", - meta={"topic": "ML"}, - ), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), ] model = "playground_nvolveqa_40k" @@ -265,10 +241,7 @@ def test_run(self): def test_run_custom_batch_size(self): docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), - Document( - content="A transformer is a deep learning architecture", - meta={"topic": "ML"}, - ), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), ] model = "playground_nvolveqa_40k" embedder = NvidiaDocumentEmbedder( @@ -300,9 +273,7 @@ def test_run_custom_batch_size(self): assert metadata == {"usage": {"prompt_tokens": 2 * 4, "total_tokens": 2 * 4}} def test_run_wrong_input_format(self): - embedder = NvidiaDocumentEmbedder( - "playground_nvolveqa_40k", api_key=Secret.from_token("fake-api-key") - ) + embedder = NvidiaDocumentEmbedder("playground_nvolveqa_40k", api_key=Secret.from_token("fake-api-key")) embedder.warm_up() embedder.backend = MockBackend("aa", None) @@ -310,22 +281,14 @@ def test_run_wrong_input_format(self): string_input = "text" list_integers_input = [1, 2, 3] - with pytest.raises( - TypeError, - match="NvidiaDocumentEmbedder expects a list of Documents as input", - ): + with pytest.raises(TypeError, match="NvidiaDocumentEmbedder expects a list of Documents as input"): embedder.run(documents=string_input) - with pytest.raises( - TypeError, - match="NvidiaDocumentEmbedder expects a list of Documents as input", - ): + with pytest.raises(TypeError, match="NvidiaDocumentEmbedder expects a list of Documents as input"): embedder.run(documents=list_integers_input) def test_run_on_empty_list(self): - embedder = NvidiaDocumentEmbedder( - "playground_nvolveqa_40k", api_key=Secret.from_token("fake-api-key") - ) + embedder 
= NvidiaDocumentEmbedder("playground_nvolveqa_40k", api_key=Secret.from_token("fake-api-key")) embedder.warm_up() embedder.backend = MockBackend("aa", None) @@ -341,8 +304,7 @@ def test_run_on_empty_list(self): reason="Export an env var called NVIDIA_API_KEY containing the Nvidia API key to run this test.", ) @pytest.mark.skipif( - not os.environ.get("NVIDIA_NIM_EMBEDDER_MODEL", None) - or not os.environ.get("NVIDIA_NIM_ENDPOINT_URL", None), + not os.environ.get("NVIDIA_NIM_EMBEDDER_MODEL", None) or not os.environ.get("NVIDIA_NIM_ENDPOINT_URL", None), reason="Export an env var called NVIDIA_NIM_EMBEDDER_MODEL containing the hosted model name and " "NVIDIA_NIM_ENDPOINT_URL containing the local URL to call.", ) @@ -358,10 +320,7 @@ def test_run_integration_with_nim_backend(self): embedder.warm_up() docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), - Document( - content="A transformer is a deep learning architecture", - meta={"topic": "ML"}, - ), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), ] result = embedder.run(docs) @@ -388,10 +347,7 @@ def test_run_integration_with_api_catalog(self): docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), - Document( - content="A transformer is a deep learning architecture", - meta={"topic": "ML"}, - ), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), ] result = embedder.run(docs) diff --git a/integrations/nvidia/tests/test_generator.py b/integrations/nvidia/tests/test_generator.py index 998203c40..3ddeebe88 100644 --- a/integrations/nvidia/tests/test_generator.py +++ b/integrations/nvidia/tests/test_generator.py @@ -55,11 +55,7 @@ def test_to_dict(self, monkeypatch): "type": "haystack_integrations.components.generators.nvidia.generator.NvidiaGenerator", "init_parameters": { "api_url": "https://integrate.api.nvidia.com/v1", - "api_key": { - "env_vars": ["NVIDIA_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, "model": "playground_nemotron_steerlm_8b", "model_arguments": {}, }, @@ -83,11 +79,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): assert data == { "type": "haystack_integrations.components.generators.nvidia.generator.NvidiaGenerator", "init_parameters": { - "api_key": { - "env_vars": ["NVIDIA_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, "api_url": "https://my.url.com/v1", "model": "playground_nemotron_steerlm_8b", "model_arguments": { @@ -102,8 +94,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): } @pytest.mark.skipif( - not os.environ.get("NVIDIA_NIM_GENERATOR_MODEL", None) - or not os.environ.get("NVIDIA_NIM_ENDPOINT_URL", None), + not os.environ.get("NVIDIA_NIM_GENERATOR_MODEL", None) or not os.environ.get("NVIDIA_NIM_ENDPOINT_URL", None), reason="Export an env var called NVIDIA_NIM_GENERATOR_MODEL containing the hosted model name and " "NVIDIA_NIM_ENDPOINT_URL containing the local URL to call.", ) diff --git a/integrations/nvidia/tests/test_text_embedder.py b/integrations/nvidia/tests/test_text_embedder.py index eeb09badb..42d60dee2 100644 --- a/integrations/nvidia/tests/test_text_embedder.py +++ b/integrations/nvidia/tests/test_text_embedder.py @@ -2,10 +2,7 @@ import pytest from haystack.utils import Secret -from haystack_integrations.components.embedders.nvidia import ( - EmbeddingTruncateMode, - 
NvidiaTextEmbedder, -) +from haystack_integrations.components.embedders.nvidia import EmbeddingTruncateMode, NvidiaTextEmbedder from haystack_integrations.components.embedders.nvidia.backend import EmbedderBackend @@ -41,9 +38,7 @@ def test_init_with_parameters(self): ) assert embedder.api_key == Secret.from_token("fake-api-key") assert embedder.model == "nvolveqa_40k" - assert ( - embedder.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia/test" - ) + assert embedder.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia/test" assert embedder.prefix == "prefix" assert embedder.suffix == "suffix" @@ -60,11 +55,7 @@ def test_to_dict(self, monkeypatch): assert data == { "type": "haystack_integrations.components.embedders.nvidia.text_embedder.NvidiaTextEmbedder", "init_parameters": { - "api_key": { - "env_vars": ["NVIDIA_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, "api_url": "https://ai.api.nvidia.com/v1/retrieval/nvidia", "model": "nvolveqa_40k", "prefix": "", @@ -86,11 +77,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): assert data == { "type": "haystack_integrations.components.embedders.nvidia.text_embedder.NvidiaTextEmbedder", "init_parameters": { - "api_key": { - "env_vars": ["NVIDIA_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, "api_url": "https://example.com/v1", "model": "nvolveqa_40k", "prefix": "prefix", @@ -104,11 +91,7 @@ def from_dict(self, monkeypatch): data = { "type": "haystack_integrations.components.embedders.nvidia.text_embedder.NvidiaTextEmbedder", "init_parameters": { - "api_key": { - "env_vars": ["NVIDIA_API_KEY"], - "strict": True, - "type": "env_var", - }, + "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, "api_url": "https://example.com", "model": "nvolveqa_40k", "prefix": "prefix", @@ -125,10 +108,7 @@ def from_dict(self, monkeypatch): def test_run(self): embedder = NvidiaTextEmbedder( - "playground_nvolveqa_40k", - api_key=Secret.from_token("fake-api-key"), - prefix="prefix ", - suffix=" suffix", + "playground_nvolveqa_40k", api_key=Secret.from_token("fake-api-key"), prefix="prefix ", suffix=" suffix" ) embedder.warm_up() @@ -143,22 +123,17 @@ def test_run(self): } def test_run_wrong_input_format(self): - embedder = NvidiaTextEmbedder( - "playground_nvolveqa_40k", api_key=Secret.from_token("fake-api-key") - ) + embedder = NvidiaTextEmbedder("playground_nvolveqa_40k", api_key=Secret.from_token("fake-api-key")) embedder.warm_up() embedder.backend = MockBackend("aa", None) list_integers_input = [1, 2, 3] - with pytest.raises( - TypeError, match="NvidiaTextEmbedder expects a string as an input" - ): + with pytest.raises(TypeError, match="NvidiaTextEmbedder expects a string as an input"): embedder.run(text=list_integers_input) @pytest.mark.skipif( - not os.environ.get("NVIDIA_NIM_EMBEDDER_MODEL", None) - or not os.environ.get("NVIDIA_NIM_ENDPOINT_URL", None), + not os.environ.get("NVIDIA_NIM_EMBEDDER_MODEL", None) or not os.environ.get("NVIDIA_NIM_ENDPOINT_URL", None), reason="Export an env var called NVIDIA_NIM_EMBEDDER_MODEL containing the hosted model name and " "NVIDIA_NIM_ENDPOINT_URL containing the local URL to call.", ) diff --git a/integrations/ollama/examples/chat_generator_example.py b/integrations/ollama/examples/chat_generator_example.py index 32b69d9a3..065a6a1b4 100644 --- 
a/integrations/ollama/examples/chat_generator_example.py +++ b/integrations/ollama/examples/chat_generator_example.py @@ -16,9 +16,7 @@ ), ChatMessage.from_user("How do I get started?"), ] -client = OllamaChatGenerator( - model="orca-mini", timeout=45, url="http://localhost:11434/api/chat" -) +client = OllamaChatGenerator(model="orca-mini", timeout=45, url="http://localhost:11434/api/chat") response = client.run(messages, generation_kwargs={"temperature": 0.2}) diff --git a/integrations/ollama/examples/embedders_example.py b/integrations/ollama/examples/embedders_example.py index 2a2ca9467..21a36e2a2 100644 --- a/integrations/ollama/examples/embedders_example.py +++ b/integrations/ollama/examples/embedders_example.py @@ -7,12 +7,8 @@ from haystack import Document, Pipeline from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever from haystack.document_stores.in_memory import InMemoryDocumentStore -from haystack_integrations.components.embedders.ollama.document_embedder import ( - OllamaDocumentEmbedder, -) -from haystack_integrations.components.embedders.ollama.text_embedder import ( - OllamaTextEmbedder, -) +from haystack_integrations.components.embedders.ollama.document_embedder import OllamaDocumentEmbedder +from haystack_integrations.components.embedders.ollama.text_embedder import OllamaTextEmbedder document_store = InMemoryDocumentStore(embedding_similarity_function="cosine") @@ -28,9 +24,7 @@ query_pipeline = Pipeline() query_pipeline.add_component("text_embedder", OllamaTextEmbedder()) -query_pipeline.add_component( - "retriever", InMemoryEmbeddingRetriever(document_store=document_store) -) +query_pipeline.add_component("retriever", InMemoryEmbeddingRetriever(document_store=document_store)) query_pipeline.connect("text_embedder.embedding", "retriever.query_embedding") query = "Who lives in Berlin?" 
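The embedders_example.py hunk above ends with the pipeline wiring; the rest of that example file is unchanged and not shown in this diff. For orientation only, a query pipeline wired this way is typically run as follows (hypothetical usage, not part of the patch; the actual tail of the file may differ):

# Hypothetical usage of the query pipeline wired above.
result = query_pipeline.run({"text_embedder": {"text": query}})
for doc in result["retriever"]["documents"]:
    print(doc.content, doc.score)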
diff --git a/integrations/ollama/examples/generator_example.py b/integrations/ollama/examples/generator_example.py index b2c7893cf..b674e676b 100644 --- a/integrations/ollama/examples/generator_example.py +++ b/integrations/ollama/examples/generator_example.py @@ -15,9 +15,7 @@ document_store.write_documents( [ Document(content="Super Mario was an important politician"), - Document( - content="Mario owns several castles and uses them to conduct important political business" - ), + Document(content="Mario owns several castles and uses them to conduct important political business"), Document( content="Super Mario was a successful military leader who fought off several invasion attempts by " "his arch rival - Bowser" diff --git a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py index fd9251cb1..b5783c611 100644 --- a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py +++ b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py @@ -59,17 +59,11 @@ def __init__( self.suffix = suffix self.prefix = prefix - def _create_json_payload( - self, text: str, generation_kwargs: Optional[Dict[str, Any]] - ) -> Dict[str, Any]: + def _create_json_payload(self, text: str, generation_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]: """ Returns A dictionary of JSON arguments for a POST request to an Ollama service """ - return { - "model": self.model, - "prompt": text, - "options": {**self.generation_kwargs, **(generation_kwargs or {})}, - } + return {"model": self.model, "prompt": text, "options": {**self.generation_kwargs, **(generation_kwargs or {})}} def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: """ @@ -87,21 +81,14 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: meta_values_to_embed = [] text_to_embed = ( - self.prefix - + self.embedding_separator.join( - [*meta_values_to_embed, doc.content or ""] - ) - + self.suffix + self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix ).replace("\n", " ") texts_to_embed.append(text_to_embed) return texts_to_embed def _embed_batch( - self, - texts_to_embed: List[str], - batch_size: int, - generation_kwargs: Optional[Dict[str, Any]] = None, + self, texts_to_embed: List[str], batch_size: int, generation_kwargs: Optional[Dict[str, Any]] = None ): """ Ollama Embedding only allows single uploads, not batching. Currently the batch size is set to 1. @@ -113,9 +100,7 @@ def _embed_batch( meta: Dict[str, Any] = {"model": ""} for i in tqdm( - range(0, len(texts_to_embed), batch_size), - disable=not self.progress_bar, - desc="Calculating embeddings", + range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings" ): batch = texts_to_embed[i] # Single batch only payload = self._create_json_payload(batch, generation_kwargs) @@ -129,11 +114,7 @@ def _embed_batch( return all_embeddings, meta @component.output_types(documents=List[Document], meta=Dict[str, Any]) - def run( - self, - documents: List[Document], - generation_kwargs: Optional[Dict[str, Any]] = None, - ): + def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None): """ Runs an Ollama Model to compute embeddings of the provided documents. 
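One aside before the remaining hunks of this file: collapsing the return dict of _create_json_payload onto a single line does not change its merge semantics. Per-call generation_kwargs still override the instance-level defaults, because they are unpacked last. A minimal sketch with hypothetical option values:

# Hypothetical values; only the dict-merge behavior is being illustrated.
instance_kwargs = {"temperature": 0.1, "num_ctx": 2048}  # stands in for self.generation_kwargs
call_kwargs = {"temperature": 0.7}                       # stands in for per-run generation_kwargs

payload = {
    "model": "nomic-embed-text",
    "prompt": "example text",
    "options": {**instance_kwargs, **(call_kwargs or {})},
}
# The later unpacking wins on key collisions:
assert payload["options"] == {"temperature": 0.7, "num_ctx": 2048}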
@@ -147,11 +128,7 @@ def run( - `documents`: Documents with embedding information attached - `meta`: The metadata collected during the embedding process """ - if ( - not isinstance(documents, list) - or documents - and not isinstance(documents[0], Document) - ): + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): msg = ( "OllamaDocumentEmbedder expects a list of Documents as input." "In case you want to embed a list of strings, please use the OllamaTextEmbedder." @@ -160,9 +137,7 @@ def run( texts_to_embed = self._prepare_texts_to_embed(documents=documents) embeddings, meta = self._embed_batch( - texts_to_embed=texts_to_embed, - batch_size=self.batch_size, - generation_kwargs=generation_kwargs, + texts_to_embed=texts_to_embed, batch_size=self.batch_size, generation_kwargs=generation_kwargs ) for doc, emb in zip(documents, embeddings): diff --git a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py index c591fc5df..5a28ba393 100644 --- a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py +++ b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py @@ -44,17 +44,11 @@ def __init__( self.url = url self.model = model - def _create_json_payload( - self, text: str, generation_kwargs: Optional[Dict[str, Any]] - ) -> Dict[str, Any]: + def _create_json_payload(self, text: str, generation_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]: """ Returns A dictionary of JSON arguments for a POST request to an Ollama service """ - return { - "model": self.model, - "prompt": text, - "options": {**self.generation_kwargs, **(generation_kwargs or {})}, - } + return {"model": self.model, "prompt": text, "options": {**self.generation_kwargs, **(generation_kwargs or {})}} @component.output_types(embedding=List[float], meta=Dict[str, Any]) def run(self, text: str, generation_kwargs: Optional[Dict[str, Any]] = None): diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py index f9dfded91..a95d8c4fb 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py @@ -69,9 +69,7 @@ def __init__( def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]: return {"role": message.role.value, "content": message.content} - def _create_json_payload( - self, messages: List[ChatMessage], stream=False, generation_kwargs=None - ) -> Dict[str, Any]: + def _create_json_payload(self, messages: List[ChatMessage], stream=False, generation_kwargs=None) -> Dict[str, Any]: """ Returns A dictionary of JSON arguments for a POST request to an Ollama service """ @@ -84,22 +82,16 @@ def _create_json_payload( "options": generation_kwargs, } - def _build_message_from_ollama_response( - self, ollama_response: Response - ) -> ChatMessage: + def _build_message_from_ollama_response(self, ollama_response: Response) -> ChatMessage: """ Converts the non-streaming response from the Ollama API to a ChatMessage. 
""" json_content = ollama_response.json() message = ChatMessage.from_assistant(content=json_content["message"]["content"]) - message.meta.update( - {key: value for key, value in json_content.items() if key != "message"} - ) + message.meta.update({key: value for key, value in json_content.items() if key != "message"}) return message - def _convert_to_streaming_response( - self, chunks: List[StreamingChunk] - ) -> Dict[str, List[Any]]: + def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]: """ Converts a list of chunks response required Haystack format. """ @@ -160,9 +152,7 @@ def run( json_payload = self._create_json_payload(messages, stream, generation_kwargs) - response = requests.post( - url=self.url, json=json_payload, timeout=self.timeout, stream=stream - ) + response = requests.post(url=self.url, json=json_payload, timeout=self.timeout, stream=stream) # throw error on unsuccessful response response.raise_for_status() diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py index 33b7ba354..50c65b650 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py @@ -4,10 +4,7 @@ import requests from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import StreamingChunk -from haystack.utils.callable_serialization import ( - deserialize_callable, - serialize_callable, -) +from haystack.utils.callable_serialization import deserialize_callable, serialize_callable from requests import Response @@ -80,11 +77,7 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. """ - callback_name = ( - serialize_callable(self.streaming_callback) - if self.streaming_callback - else None - ) + callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None return default_to_dict( self, timeout=self.timeout, @@ -110,14 +103,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "OllamaGenerator": init_params = data.get("init_parameters", {}) serialized_callback_handler = init_params.get("streaming_callback") if serialized_callback_handler: - data["init_parameters"]["streaming_callback"] = deserialize_callable( - serialized_callback_handler - ) + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) - def _create_json_payload( - self, prompt: str, stream: bool, generation_kwargs=None - ) -> Dict[str, Any]: + def _create_json_payload(self, prompt: str, stream: bool, generation_kwargs=None) -> Dict[str, Any]: """ Returns a dictionary of JSON arguments for a POST request to an Ollama service. """ @@ -144,17 +133,13 @@ def _convert_to_response(self, ollama_response: Response) -> Dict[str, List[Any] return {"replies": replies, "meta": [meta]} - def _convert_to_streaming_response( - self, chunks: List[StreamingChunk] - ) -> Dict[str, List[Any]]: + def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]: """ Converts a list of chunks response required Haystack format. 
""" replies = ["".join([c.content for c in chunks])] - meta = { - key: value for key, value in chunks[0].meta.items() if key != "response" - } + meta = {key: value for key, value in chunks[0].meta.items() if key != "response"} return {"replies": replies, "meta": [meta]} @@ -207,9 +192,7 @@ def run( json_payload = self._create_json_payload(prompt, stream, generation_kwargs) - response = requests.post( - url=self.url, json=json_payload, timeout=self.timeout, stream=stream - ) + response = requests.post(url=self.url, json=json_payload, timeout=self.timeout, stream=stream) # throw error on unsuccessful response response.raise_for_status() diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py index 8fa492430..e09208bb8 100644 --- a/integrations/ollama/tests/test_chat_generator.py +++ b/integrations/ollama/tests/test_chat_generator.py @@ -10,12 +10,9 @@ @pytest.fixture def chat_messages() -> List[ChatMessage]: return [ - ChatMessage.from_user( - "Tell me about why Super Mario is the greatest superhero" - ), + ChatMessage.from_user("Tell me about why Super Mario is the greatest superhero"), ChatMessage.from_assistant( - "Super Mario has prevented Bowser from destroying the world", - {"something": "something"}, + "Super Mario has prevented Bowser from destroying the world", {"something": "something"} ), ] @@ -49,14 +46,8 @@ def test_create_json_payload(self, chat_messages): ) expected = { "messages": [ - { - "role": "user", - "content": "Tell me about why Super Mario is the greatest superhero", - }, - { - "role": "assistant", - "content": "Super Mario has prevented Bowser from destroying the world", - }, + {"role": "user", "content": "Tell me about why Super Mario is the greatest superhero"}, + {"role": "assistant", "content": "Super Mario has prevented Bowser from destroying the world"}, ], "model": "some_model", "stream": False, @@ -83,9 +74,7 @@ def test_build_message_from_ollama_response(self): "eval_duration": 4799921000, } - observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response( - mock_ollama_response - ) + observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(mock_ollama_response) assert observed.role == "assistant" assert observed.content == "Hello! How are you today?" 
@@ -114,31 +103,20 @@ def test_run_with_chat_history(self): chat_generator = OllamaChatGenerator() chat_history = [ - { - "role": "user", - "content": "What is the largest city in the United Kingdom by population?", - }, - { - "role": "assistant", - "content": "London is the largest city in the United Kingdom by population", - }, + {"role": "user", "content": "What is the largest city in the United Kingdom by population?"}, + {"role": "assistant", "content": "London is the largest city in the United Kingdom by population"}, {"role": "user", "content": "And what is the second largest?"}, ] chat_messages = [ - ChatMessage( - role=ChatRole(message["role"]), content=message["content"], name=None - ) + ChatMessage(role=ChatRole(message["role"]), content=message["content"], name=None) for message in chat_history ] response = chat_generator.run(chat_messages) assert isinstance(response, dict) assert isinstance(response["replies"], list) - assert ( - "Manchester" in response["replies"][-1].content - or "Glasgow" in response["replies"][-1].content - ) + assert "Manchester" in response["replies"][-1].content or "Glasgow" in response["replies"][-1].content @pytest.mark.integration def test_run_model_unavailable(self): @@ -156,21 +134,13 @@ def test_run_with_streaming(self): chat_generator = OllamaChatGenerator(streaming_callback=streaming_callback) chat_history = [ - { - "role": "user", - "content": "What is the largest city in the United Kingdom by population?", - }, - { - "role": "assistant", - "content": "London is the largest city in the United Kingdom by population", - }, + {"role": "user", "content": "What is the largest city in the United Kingdom by population?"}, + {"role": "assistant", "content": "London is the largest city in the United Kingdom by population"}, {"role": "user", "content": "And what is the second largest?"}, ] chat_messages = [ - ChatMessage( - role=ChatRole(message["role"]), content=message["content"], name=None - ) + ChatMessage(role=ChatRole(message["role"]), content=message["content"], name=None) for message in chat_history ] @@ -180,7 +150,4 @@ def test_run_with_streaming(self): assert isinstance(response, dict) assert isinstance(response["replies"], list) - assert ( - "Manchester" in response["replies"][-1].content - or "Glasgow" in response["replies"][-1].content - ) + assert "Manchester" in response["replies"][-1].content or "Glasgow" in response["replies"][-1].content diff --git a/integrations/ollama/tests/test_document_embedder.py b/integrations/ollama/tests/test_document_embedder.py index 087ae1845..012ad9eae 100644 --- a/integrations/ollama/tests/test_document_embedder.py +++ b/integrations/ollama/tests/test_document_embedder.py @@ -47,7 +47,5 @@ def test_run(self): reply = embedder.run(list_of_docs) assert isinstance(reply, dict) - assert all( - isinstance(element, float) for element in reply["documents"][0].embedding - ) + assert all(isinstance(element, float) for element in reply["documents"][0].embedding) assert reply["meta"]["model"] == "nomic-embed-text" diff --git a/integrations/ollama/tests/test_generator.py b/integrations/ollama/tests/test_generator.py index 17edcff0e..4af2bdb82 100644 --- a/integrations/ollama/tests/test_generator.py +++ b/integrations/ollama/tests/test_generator.py @@ -103,10 +103,7 @@ def test_to_dict_with_parameters(self): "model": "llama2", "url": "going_to_51_pegasi_b_for_weekend", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "generation_kwargs": { - "max_tokens": 10, - "some_test_param": 
"test-params", - }, + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, } @@ -121,20 +118,14 @@ def test_from_dict(self): "model": "llama2", "url": "going_to_51_pegasi_b_for_weekend", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "generation_kwargs": { - "max_tokens": 10, - "some_test_param": "test-params", - }, + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, } component = OllamaGenerator.from_dict(data) assert component.model == "llama2" assert component.streaming_callback is print_streaming_chunk assert component.url == "going_to_51_pegasi_b_for_weekend" - assert component.generation_kwargs == { - "max_tokens": 10, - "some_test_param": "test-params", - } + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} @pytest.mark.parametrize( "configuration", diff --git a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py index b75d7e1cd..29e242234 100644 --- a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py +++ b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py @@ -20,7 +20,6 @@ class OpenSearchBM25Retriever: BM25 computes a weighted word overlap between the query string and a document to determine its similarity. """ - def __init__( self, *, @@ -92,9 +91,7 @@ def __init__( self._scale_score = scale_score self._all_terms_must_match = all_terms_must_match self._filter_policy = ( - filter_policy - if isinstance(filter_policy, FilterPolicy) - else FilterPolicy.from_str(filter_policy) + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) ) self._custom_query = custom_query self._raise_on_failure = raise_on_failure @@ -136,9 +133,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchBM25Retriever": # Pipelines serialized with old versions of the component might not # have the filter_policy field. 
if "filter_policy" in data["init_parameters"]: - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str( - data["init_parameters"]["filter_policy"] - ) + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) diff --git a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py index 3e839085a..eba5596f2 100644 --- a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py +++ b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py @@ -82,9 +82,7 @@ def __init__( self._filters = filters or {} self._top_k = top_k self._filter_policy = ( - filter_policy - if isinstance(filter_policy, FilterPolicy) - else FilterPolicy.from_str(filter_policy) + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) ) self._custom_query = custom_query self._raise_on_failure = raise_on_failure @@ -124,9 +122,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchEmbeddingRetriever": # Pipelines serialized with old versions of the component might not # have the filter_policy field. if "filter_policy" in data["init_parameters"]: - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str( - data["init_parameters"]["filter_policy"] - ) + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) diff --git a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/auth.py b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/auth.py index c9427afa9..8249c16ca 100644 --- a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/auth.py +++ b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/auth.py @@ -57,9 +57,7 @@ def _get_aws_session( profile_name=aws_profile_name, ) except BotoCoreError as e: - provided_aws_config = { - k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS - } + provided_aws_config = {k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS} msg = f"Failed to initialize the session with provided AWS credentials {provided_aws_config}" raise AWSConfigurationError(msg) from e @@ -78,9 +76,7 @@ class AWSAuth: default_factory=lambda: Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False) ) aws_secret_access_key: Optional[Secret] = field( - default_factory=lambda: Secret.from_env_var( - "AWS_SECRET_ACCESS_KEY", strict=False - ) + default_factory=lambda: Secret.from_env_var("AWS_SECRET_ACCESS_KEY", strict=False) ) aws_session_token: Optional[Secret] = field( default_factory=lambda: Secret.from_env_var("AWS_SESSION_TOKEN", strict=False) @@ -88,9 +84,7 @@ class AWSAuth: aws_region_name: Optional[Secret] = field( default_factory=lambda: Secret.from_env_var("AWS_DEFAULT_REGION", strict=False) ) - aws_profile_name: Optional[Secret] = field( - default_factory=lambda: Secret.from_env_var("AWS_PROFILE", strict=False) - ) + aws_profile_name: Optional[Secret] = field(default_factory=lambda: Secret.from_env_var("AWS_PROFILE", strict=False)) aws_service: str = field(default="es") def __post_init__(self) -> None: @@ -107,9 
+101,7 @@ def to_dict(self) -> Dict[str, Any]: for _field in fields(self): field_value = getattr(self, _field.name) if _field.type == Optional[Secret]: - _fields[_field.name] = ( - field_value.to_dict() if field_value is not None else None - ) + _fields[_field.name] = field_value.to_dict() if field_value is not None else None else: _fields[_field.name] = field_value @@ -123,13 +115,7 @@ def from_dict(cls, data: Dict[str, Any]) -> Optional["AWSAuth"]: init_parameters = data.get("init_parameters", {}) deserialize_secrets_inplace( init_parameters, - [ - "aws_access_key_id", - "aws_secret_access_key", - "aws_session_token", - "aws_region_name", - "aws_profile_name", - ], + ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"], ) return default_from_dict(cls, data) diff --git a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py index cafab0ba5..465897608 100644 --- a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py +++ b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py @@ -106,11 +106,7 @@ def __init__( def _get_default_mappings(self) -> Dict[str, Any]: default_mappings: Dict[str, Any] = { "properties": { - "embedding": { - "type": "knn_vector", - "index": True, - "dimension": self._embedding_dim, - }, + "embedding": {"type": "knn_vector", "index": True, "dimension": self._embedding_dim}, "content": {"type": "text"}, }, "dynamic_templates": [ @@ -175,9 +171,7 @@ def create_index( settings = self._settings if not self.client.indices.exists(index=index): - self.client.indices.create( - index=index, body={"mappings": mappings, "settings": settings} - ) + self.client.indices.create(index=index, body={"mappings": mappings, "settings": settings}) def to_dict(self) -> Dict[str, Any]: # This is not the best solution to serialise this class but is the fastest to implement. @@ -200,9 +194,7 @@ def to_dict(self) -> Dict[str, Any]: settings=self._settings, create_index=self._create_index, return_embedding=self._return_embedding, - http_auth=self._http_auth.to_dict() - if isinstance(self._http_auth, AWSAuth) - else self._http_auth, + http_auth=self._http_auth.to_dict() if isinstance(self._http_auth, AWSAuth) else self._http_auth, use_ssl=self._use_ssl, verify_certs=self._verify_certs, timeout=self._timeout, @@ -240,14 +232,10 @@ def _search_documents(self, **kwargs) -> List[Document]: index=self._index, body=kwargs, ) - documents: List[Document] = [ - self._deserialize_document(hit) for hit in res["hits"]["hits"] - ] + documents: List[Document] = [self._deserialize_document(hit) for hit in res["hits"]["hits"]] return documents - def filter_documents( - self, filters: Optional[Dict[str, Any]] = None - ) -> List[Document]: + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: if filters and "operator" not in filters and "conditions" not in filters: filters = convert(filters) @@ -259,9 +247,7 @@ def filter_documents( return documents - def write_documents( - self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE - ) -> int: + def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int: """ Writes Documents to OpenSearch. 
If policy is not specified or set to DuplicatePolicy.NONE, it will raise an exception if a document with the @@ -269,9 +255,7 @@ def write_documents( """ if len(documents) > 0: if not isinstance(documents[0], Document): - msg = ( - "param 'documents' must contain a list of objects of type Document" - ) + msg = "param 'documents' must contain a list of objects of type Document" raise ValueError(msg) if policy == DuplicatePolicy.NONE: @@ -304,15 +288,9 @@ def write_documents( other_errors.append(e) continue error_type = e["create"]["error"]["type"] - if ( - policy == DuplicatePolicy.FAIL - and error_type == "version_conflict_engine_exception" - ): + if policy == DuplicatePolicy.FAIL and error_type == "version_conflict_engine_exception": duplicate_errors_ids.append(e["create"]["_id"]) - elif ( - policy == DuplicatePolicy.SKIP - and error_type == "version_conflict_engine_exception" - ): + elif policy == DuplicatePolicy.SKIP and error_type == "version_conflict_engine_exception": # when the policy is skip, duplication errors are OK and we should not raise an exception continue else: @@ -323,9 +301,7 @@ def write_documents( raise DuplicateDocumentError(msg) if len(other_errors) > 0: - msg = ( - f"Failed to write documents to OpenSearch. Errors:\n{other_errors}" - ) + msg = f"Failed to write documents to OpenSearch. Errors:\n{other_errors}" raise DocumentStoreError(msg) return documents_written @@ -415,9 +391,7 @@ def _bm25_retrieval( body["query"]["bool"]["filter"] = normalize_filters(filters) if isinstance(custom_query, dict): - body = self._render_custom_query( - custom_query, {"$query": query, "$filters": normalize_filters(filters)} - ) + body = self._render_custom_query(custom_query, {"$query": query, "$filters": normalize_filters(filters)}) else: operator = "AND" if all_terms_must_match else "OR" @@ -452,9 +426,7 @@ def _bm25_retrieval( if scale_score: for doc in documents: - doc.score = float( - 1 / (1 + np.exp(-np.asarray(doc.score / BM25_SCALING_FACTOR))) - ) # type:ignore + doc.score = float(1 / (1 + np.exp(-np.asarray(doc.score / BM25_SCALING_FACTOR)))) # type:ignore return documents @@ -513,11 +485,7 @@ def _embedding_retrieval( if isinstance(custom_query, dict): body = self._render_custom_query( - custom_query, - { - "$query_embedding": query_embedding, - "$filters": normalize_filters(filters), - }, + custom_query, {"$query_embedding": query_embedding, "$filters": normalize_filters(filters)} ) else: @@ -551,9 +519,7 @@ def _embedding_retrieval( docs = self._search_documents(**body) return docs - def _render_custom_query( - self, custom_query: Any, substitutions: Dict[str, Any] - ) -> Any: + def _render_custom_query(self, custom_query: Any, substitutions: Dict[str, Any]) -> Any: """ Recursively replaces the placeholders in the custom_query with the actual values. @@ -562,15 +528,9 @@ def _render_custom_query( :returns: The custom query with the placeholders replaced. 
""" if isinstance(custom_query, dict): - return { - key: self._render_custom_query(value, substitutions) - for key, value in custom_query.items() - } + return {key: self._render_custom_query(value, substitutions) for key, value in custom_query.items()} elif isinstance(custom_query, list): - return [ - self._render_custom_query(entry, substitutions) - for entry in custom_query - ] + return [self._render_custom_query(entry, substitutions) for entry in custom_query] elif isinstance(custom_query, str): return substitutions.get(custom_query, custom_query) diff --git a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/filters.py b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/filters.py index 0662094b0..3aae3aacd 100644 --- a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/filters.py +++ b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/filters.py @@ -53,9 +53,7 @@ def _equal(field: str, value: Any) -> Dict[str, Any]: "terms_set": { field: { "terms": value, - "minimum_should_match_script": { - "source": f"Math.max(params.num_terms, doc['{field}'].size())" - }, + "minimum_should_match_script": {"source": f"Math.max(params.num_terms, doc['{field}'].size())"}, } } } @@ -73,13 +71,7 @@ def _not_equal(field: str, value: Any) -> Dict[str, Any]: return {"bool": {"must_not": {"terms": {field: value}}}} if field in ["text", "dataframe"]: # We want to fully match the text field. - return { - "bool": { - "must_not": { - "match": {field: {"query": value, "minimum_should_match": "100%"}} - } - } - } + return {"bool": {"must_not": {"match": {field: {"query": value, "minimum_should_match": "100%"}}}}} return {"bool": {"must_not": {"term": {field: value}}}} @@ -90,14 +82,7 @@ def _greater_than(field: str, value: Any) -> Dict[str, Any]: # if it has a field set and not set at the same time. # This will cause the filter to match no Document. # This way we keep the behavior consistent with other Document Stores. - return { - "bool": { - "must": [ - {"exists": {"field": field}}, - {"bool": {"must_not": {"exists": {"field": field}}}}, - ] - } - } + return {"bool": {"must": [{"exists": {"field": field}}, {"bool": {"must_not": {"exists": {"field": field}}}}]}} if isinstance(value, str): try: datetime.fromisoformat(value) @@ -119,14 +104,7 @@ def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]: # if it has a field set and not set at the same time. # This will cause the filter to match no Document. # This way we keep the behavior consistent with other Document Stores. - return { - "bool": { - "must": [ - {"exists": {"field": field}}, - {"bool": {"must_not": {"exists": {"field": field}}}}, - ] - } - } + return {"bool": {"must": [{"exists": {"field": field}}, {"bool": {"must_not": {"exists": {"field": field}}}}]}} if isinstance(value, str): try: datetime.fromisoformat(value) @@ -148,14 +126,7 @@ def _less_than(field: str, value: Any) -> Dict[str, Any]: # if it has a field set and not set at the same time. # This will cause the filter to match no Document. # This way we keep the behavior consistent with other Document Stores. 
-        return {
-            "bool": {
-                "must": [
-                    {"exists": {"field": field}},
-                    {"bool": {"must_not": {"exists": {"field": field}}}},
-                ]
-            }
-        }
+        return {"bool": {"must": [{"exists": {"field": field}}, {"bool": {"must_not": {"exists": {"field": field}}}}]}}
    if isinstance(value, str):
        try:
            datetime.fromisoformat(value)
@@ -177,14 +148,7 @@ def _less_than_equal(field: str, value: Any) -> Dict[str, Any]:
        # if it has a field set and not set at the same time.
        # This will cause the filter to match no Document.
        # This way we keep the behavior consistent with other Document Stores.
-        return {
-            "bool": {
-                "must": [
-                    {"exists": {"field": field}},
-                    {"bool": {"must_not": {"exists": {"field": field}}}},
-                ]
-            }
-        }
+        return {"bool": {"must": [{"exists": {"field": field}}, {"bool": {"must_not": {"exists": {"field": field}}}}]}}
    if isinstance(value, str):
        try:
            datetime.fromisoformat(value)
@@ -271,9 +235,7 @@ def _normalize_ranges(conditions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        ]
        ```
    """
-    range_conditions = [
-        next(iter(c["range"].items())) for c in conditions if "range" in c
-    ]
+    range_conditions = [next(iter(c["range"].items())) for c in conditions if "range" in c]
    if range_conditions:
        conditions = [c for c in conditions if "range" not in c]
        range_conditions_dict: Dict[str, Any] = {}
diff --git a/integrations/opensearch/tests/test_auth.py b/integrations/opensearch/tests/test_auth.py
index 1c3d3a42e..25bda7d66 100644
--- a/integrations/opensearch/tests/test_auth.py
+++ b/integrations/opensearch/tests/test_auth.py
@@ -36,31 +36,15 @@ def test_to_dict(self):
        assert res == {
            "type": "haystack_integrations.document_stores.opensearch.auth.AWSAuth",
            "init_parameters": {
-                "aws_access_key_id": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_ACCESS_KEY_ID"],
-                    "strict": False,
-                },
+                "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
                "aws_secret_access_key": {
                    "type": "env_var",
                    "env_vars": ["AWS_SECRET_ACCESS_KEY"],
                    "strict": False,
                },
-                "aws_session_token": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_SESSION_TOKEN"],
-                    "strict": False,
-                },
-                "aws_region_name": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_DEFAULT_REGION"],
-                    "strict": False,
-                },
-                "aws_profile_name": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_PROFILE"],
-                    "strict": False,
-                },
+                "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
+                "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
+                "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
                "aws_service": "es",
            },
        }
@@ -69,31 +53,15 @@ def test_from_dict(self):
        data = {
            "type": "haystack_integrations.document_stores.opensearch.auth.AWSAuth",
            "init_parameters": {
-                "aws_access_key_id": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_ACCESS_KEY_ID"],
-                    "strict": False,
-                },
+                "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
                "aws_secret_access_key": {
                    "type": "env_var",
                    "env_vars": ["AWS_SECRET_ACCESS_KEY"],
                    "strict": False,
                },
-                "aws_session_token": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_SESSION_TOKEN"],
-                    "strict": False,
-                },
-                "aws_region_name": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_DEFAULT_REGION"],
-                    "strict": False,
-                },
-                "aws_profile_name": {
-                    "type": "env_var",
-                    "env_vars": ["AWS_PROFILE"],
-                    "strict": False,
-                },
+                "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
+                "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
+                "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
                "aws_service": "es",
            },
        }
@@ -136,9 +104,7 @@ def test_from_dict_disable_env_variables(self):
        assert aws_auth.aws_service == "aoss"
        assert isinstance(aws_auth._urllib3_aws_v4_signer_auth, Urllib3AWSV4SignerAuth)

-    @patch(
-        "haystack_integrations.document_stores.opensearch.auth.AWSAuth._get_urllib3_aws_v4_signer_auth"
-    )
+    @patch("haystack_integrations.document_stores.opensearch.auth.AWSAuth._get_urllib3_aws_v4_signer_auth")
    def test_call(self, _get_urllib3_aws_v4_signer_auth_mock):
        signer_auth_mock = Mock(spec=Urllib3AWSV4SignerAuth)
        _get_urllib3_aws_v4_signer_auth_mock.return_value = signer_auth_mock
diff --git a/integrations/opensearch/tests/test_bm25_retriever.py b/integrations/opensearch/tests/test_bm25_retriever.py
index ba5ed60d7..c015d360a 100644
--- a/integrations/opensearch/tests/test_bm25_retriever.py
+++ b/integrations/opensearch/tests/test_bm25_retriever.py
@@ -6,13 +6,9 @@
 import pytest
 from haystack.dataclasses import Document
 from haystack.document_stores.types import FilterPolicy
-from haystack_integrations.components.retrievers.opensearch import (
-    OpenSearchBM25Retriever,
-)
+from haystack_integrations.components.retrievers.opensearch import OpenSearchBM25Retriever
 from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore
-from haystack_integrations.document_stores.opensearch.document_store import (
-    DEFAULT_MAX_CHUNK_BYTES,
-)
+from haystack_integrations.document_stores.opensearch.document_store import DEFAULT_MAX_CHUNK_BYTES


 def test_init_default():
@@ -24,9 +20,7 @@ def test_init_default():
    assert not retriever._scale_score
    assert retriever._filter_policy == FilterPolicy.REPLACE

-    retriever = OpenSearchBM25Retriever(
-        document_store=mock_store, filter_policy="replace"
-    )
+    retriever = OpenSearchBM25Retriever(document_store=mock_store, filter_policy="replace")
    assert retriever._filter_policy == FilterPolicy.REPLACE

    with pytest.raises(ValueError):
@@ -36,9 +30,7 @@ def test_init_default():
 @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch")
 def test_to_dict(_mock_opensearch_client):
    document_store = OpenSearchDocumentStore(hosts="some fake host")
-    retriever = OpenSearchBM25Retriever(
-        document_store=document_store, custom_query={"some": "custom query"}
-    )
+    retriever = OpenSearchBM25Retriever(document_store=document_store, custom_query={"some": "custom query"})
    res = retriever.to_dict()
    assert res == {
        "type": "haystack_integrations.components.retrievers.opensearch.bm25_retriever.OpenSearchBM25Retriever",
@@ -50,20 +42,11 @@ def test_to_dict(_mock_opensearch_client):
                "index": "default",
                "mappings": {
                    "dynamic_templates": [
-                        {
-                            "strings": {
-                                "mapping": {"type": "keyword"},
-                                "match_mapping_type": "string",
-                            }
-                        }
+                        {"strings": {"mapping": {"type": "keyword"}, "match_mapping_type": "string"}}
                    ],
                    "properties": {
                        "content": {"type": "text"},
-                        "embedding": {
-                            "dimension": 768,
-                            "index": True,
-                            "type": "knn_vector",
-                        },
+                        "embedding": {"dimension": 768, "index": True, "type": "knn_vector"},
                    },
                },
                "max_chunk_bytes": DEFAULT_MAX_CHUNK_BYTES,
@@ -219,9 +202,7 @@ def test_run_time_params():
 def test_run_ignore_errors(caplog):
    mock_store = Mock(spec=OpenSearchDocumentStore)
    mock_store._bm25_retrieval.side_effect = Exception("Some error")
-    retriever = OpenSearchBM25Retriever(
-        document_store=mock_store, raise_on_failure=False
-    )
+    retriever = OpenSearchBM25Retriever(document_store=mock_store, raise_on_failure=False)
    res = retriever.run(query="some query")
    assert len(res) == 1
    assert res["documents"] == []
diff --git a/integrations/opensearch/tests/test_document_store.py b/integrations/opensearch/tests/test_document_store.py
index ca3f4b944..287c24f63 100644
--- a/integrations/opensearch/tests/test_document_store.py
+++ b/integrations/opensearch/tests/test_document_store.py
@@ -13,9 +13,7 @@
 from haystack.utils.auth import Secret
 from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore
 from haystack_integrations.document_stores.opensearch.auth import AWSAuth
-from haystack_integrations.document_stores.opensearch.document_store import (
-    DEFAULT_MAX_CHUNK_BYTES,
-)
+from haystack_integrations.document_stores.opensearch.document_store import DEFAULT_MAX_CHUNK_BYTES
 from opensearchpy.exceptions import RequestError


@@ -30,21 +28,10 @@ def test_to_dict(_mock_opensearch_client):
        "hosts": "some hosts",
        "index": "default",
        "mappings": {
-            "dynamic_templates": [
-                {
-                    "strings": {
-                        "mapping": {"type": "keyword"},
-                        "match_mapping_type": "string",
-                    }
-                }
-            ],
+            "dynamic_templates": [{"strings": {"mapping": {"type": "keyword"}, "match_mapping_type": "string"}}],
            "properties": {
                "content": {"type": "text"},
-                "embedding": {
-                    "dimension": 768,
-                    "index": True,
-                    "type": "knn_vector",
-                },
+                "embedding": {"dimension": 768, "index": True, "type": "knn_vector"},
            },
        },
        "max_chunk_bytes": DEFAULT_MAX_CHUNK_BYTES,
@@ -115,9 +102,7 @@ def test_init_is_lazy(_mock_opensearch_client):

 @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch")
 def test_get_default_mappings(_mock_opensearch_client):
-    store = OpenSearchDocumentStore(
-        hosts="testhost", embedding_dim=1536, method={"name": "hnsw"}
-    )
+    store = OpenSearchDocumentStore(hosts="testhost", embedding_dim=1536, method={"name": "hnsw"})
    assert store._mappings["properties"]["embedding"] == {
        "type": "knn_vector",
        "index": True,
@@ -134,9 +119,7 @@ def mock_boto3_session(self):

    @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch")
    def test_init_with_basic_auth(self, _mock_opensearch_client):
-        document_store = OpenSearchDocumentStore(
-            hosts="testhost", http_auth=("user", "pw")
-        )
+        document_store = OpenSearchDocumentStore(hosts="testhost", http_auth=("user", "pw"))
        assert document_store.client
        _mock_opensearch_client.assert_called_once()
        assert _mock_opensearch_client.call_args[1]["http_auth"] == ("user", "pw")
@@ -180,9 +163,7 @@ def test_from_dict_basic_auth(self, _mock_opensearch_client):
        assert _mock_opensearch_client.call_args[1]["http_auth"] == ["user", "pw"]

    @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch")
-    def test_from_dict_aws_auth(
-        self, _mock_opensearch_client, monkeypatch: pytest.MonkeyPatch
-    ):
+    def test_from_dict_aws_auth(self, _mock_opensearch_client, monkeypatch: pytest.MonkeyPatch):
        monkeypatch.setenv("AWS_DEFAULT_REGION", "dummy-region")
        document_store = OpenSearchDocumentStore.from_dict(
            {
@@ -206,9 +187,7 @@ def test_from_dict_aws_auth(

    @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch")
    def test_to_dict_basic_auth(self, _mock_opensearch_client):
-        document_store = OpenSearchDocumentStore(
-            hosts="some hosts", http_auth=("user", "pw")
-        )
+        document_store = OpenSearchDocumentStore(hosts="some hosts", http_auth=("user", "pw"))
        res = document_store.to_dict()
        assert res == {
            "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore",
@@ -218,20 +197,11 @@ def test_to_dict_basic_auth(self, _mock_opensearch_client):
            "index": "default",
            "mappings": {
                "dynamic_templates": [
-                    {
-                        "strings": {
-                            "mapping": {"type": "keyword"},
-                            "match_mapping_type": "string",
-                        }
-                    }
+                    {"strings": {"mapping": {"type": "keyword"}, "match_mapping_type": "string"}}
                ],
                "properties": {
                    "content": {"type": "text"},
-                    "embedding": {
-                        "dimension": 768,
-                        "index": True,
-                        "type": "knn_vector",
-                    },
+                    "embedding": {"dimension": 768, "index": True, "type": "knn_vector"},
                },
            },
            "max_chunk_bytes": DEFAULT_MAX_CHUNK_BYTES,
@@ -247,13 +217,9 @@ def test_to_dict_basic_auth(self, _mock_opensearch_client):
        }

    @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch")
-    def test_to_dict_aws_auth(
-        self, _mock_opensearch_client, monkeypatch: pytest.MonkeyPatch
-    ):
+    def test_to_dict_aws_auth(self, _mock_opensearch_client, monkeypatch: pytest.MonkeyPatch):
        monkeypatch.setenv("AWS_DEFAULT_REGION", "dummy-region")
-        document_store = OpenSearchDocumentStore(
-            hosts="some hosts", http_auth=AWSAuth()
-        )
+        document_store = OpenSearchDocumentStore(hosts="some hosts", http_auth=AWSAuth())
        res = document_store.to_dict()
        assert res == {
            "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore",
@@ -263,20 +229,11 @@ def test_to_dict_aws_auth(
            "index": "default",
            "mappings": {
                "dynamic_templates": [
-                    {
-                        "strings": {
-                            "mapping": {"type": "keyword"},
-                            "match_mapping_type": "string",
-                        }
-                    }
+                    {"strings": {"mapping": {"type": "keyword"}, "match_mapping_type": "string"}}
                ],
                "properties": {
                    "content": {"type": "text"},
-                    "embedding": {
-                        "dimension": 768,
-                        "index": True,
-                        "type": "knn_vector",
-                    },
+                    "embedding": {"dimension": 768, "index": True, "type": "knn_vector"},
                },
            },
            "max_chunk_bytes": DEFAULT_MAX_CHUNK_BYTES,
@@ -287,31 +244,15 @@ def test_to_dict_aws_auth(
            "http_auth": {
                "type": "haystack_integrations.document_stores.opensearch.auth.AWSAuth",
                "init_parameters": {
-                    "aws_access_key_id": {
-                        "type": "env_var",
-                        "env_vars": ["AWS_ACCESS_KEY_ID"],
-                        "strict": False,
-                    },
+                    "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
                    "aws_secret_access_key": {
                        "type": "env_var",
                        "env_vars": ["AWS_SECRET_ACCESS_KEY"],
                        "strict": False,
                    },
-                    "aws_session_token": {
-                        "type": "env_var",
-                        "env_vars": ["AWS_SESSION_TOKEN"],
-                        "strict": False,
-                    },
-                    "aws_region_name": {
-                        "type": "env_var",
-                        "env_vars": ["AWS_DEFAULT_REGION"],
-                        "strict": False,
-                    },
-                    "aws_profile_name": {
-                        "type": "env_var",
-                        "env_vars": ["AWS_PROFILE"],
-                        "strict": False,
-                    },
+                    "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
+                    "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
+                    "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
                    "aws_service": "es",
                },
            },
@@ -369,13 +310,9 @@ def document_store_readonly(self, request):
            method={"space_type": "cosinesimil", "engine": "nmslib", "name": "hnsw"},
            create_index=False,
        )
-        store.client.cluster.put_settings(
-            body={"transient": {"action.auto_create_index": False}}
-        )
+        store.client.cluster.put_settings(body={"transient": {"action.auto_create_index": False}})
        yield store
-        store.client.cluster.put_settings(
-            body={"transient": {"action.auto_create_index": True}}
-        )
+        store.client.cluster.put_settings(body={"transient": {"action.auto_create_index": True}})
        store.client.indices.delete(index=index, params={"ignore": [400, 404]})

    @pytest.fixture
@@ -399,9 +336,7 @@ def document_store_embedding_dim_4(self, request):
        yield store
        store.client.indices.delete(index=index, params={"ignore": [400, 404]})

-    def assert_documents_are_equal(
-        self, received: List[Document], expected: List[Document]
-    ):
+    def assert_documents_are_equal(self, received: List[Document], expected: List[Document]):
        """
        The OpenSearchDocumentStore.filter_documents() method returns Documents with their score set.
        We don't want to compare the score, so we set it to None before comparing the documents.
@@ -432,18 +367,14 @@ def test_write_documents(self, document_store: OpenSearchDocumentStore):
        with pytest.raises(DuplicateDocumentError):
            document_store.write_documents(docs, DuplicatePolicy.FAIL)

-    def test_write_documents_readonly(
-        self, document_store_readonly: OpenSearchDocumentStore
-    ):
+    def test_write_documents_readonly(self, document_store_readonly: OpenSearchDocumentStore):
        docs = [Document(id="1")]
        with pytest.raises(DocumentStoreError, match="index_not_found_exception"):
            document_store_readonly.write_documents(docs)

    def test_create_index(self, document_store_readonly: OpenSearchDocumentStore):
        document_store_readonly.create_index()
-        assert document_store_readonly.client.indices.exists(
-            index=document_store_readonly._index
-        )
+        assert document_store_readonly.client.indices.exists(index=document_store_readonly._index)

    def test_bm25_retrieval(self, document_store: OpenSearchDocumentStore):
        document_store.write_documents(
@@ -496,9 +427,7 @@ def test_bm25_retrieval_pagination(self, document_store: OpenSearchDocumentStore
        assert len(res) == 11
        assert all("programming" in doc.content for doc in res)

-    def test_bm25_retrieval_all_terms_must_match(
-        self, document_store: OpenSearchDocumentStore
-    ):
+    def test_bm25_retrieval_all_terms_must_match(self, document_store: OpenSearchDocumentStore):
        document_store.write_documents(
            [
                Document(content="Haskell is a functional programming language"),
@@ -515,15 +444,11 @@ def test_bm25_retrieval_all_terms_must_match(
            ]
        )

-        res = document_store._bm25_retrieval(
-            "functional Haskell", top_k=3, all_terms_must_match=True
-        )
+        res = document_store._bm25_retrieval("functional Haskell", top_k=3, all_terms_must_match=True)
        assert len(res) == 1
        assert "Haskell is a functional programming language" in res[0].content

-    def test_bm25_retrieval_all_terms_must_match_false(
-        self, document_store: OpenSearchDocumentStore
-    ):
+    def test_bm25_retrieval_all_terms_must_match_false(self, document_store: OpenSearchDocumentStore):
        document_store.write_documents(
            [
                Document(content="Haskell is a functional programming language"),
@@ -540,9 +465,7 @@ def test_bm25_retrieval_all_terms_must_match_false(
            ]
        )

-        res = document_store._bm25_retrieval(
-            "functional Haskell", top_k=10, all_terms_must_match=False
-        )
+        res = document_store._bm25_retrieval("functional Haskell", top_k=10, all_terms_must_match=False)
        assert len(res) == 5
        assert "functional" in res[0].content
        assert "functional" in res[1].content
@@ -550,9 +473,7 @@ def test_bm25_retrieval_all_terms_must_match_false(
        assert "functional" in res[3].content
        assert "functional" in res[4].content

-    def test_bm25_retrieval_with_fuzziness(
-        self, document_store: OpenSearchDocumentStore
-    ):
+    def test_bm25_retrieval_with_fuzziness(self, document_store: OpenSearchDocumentStore):
        document_store.write_documents(
            [
                Document(content="Haskell is a functional programming language"),
@@ -652,9 +573,7 @@ def test_bm25_retrieval_with_filters(self, document_store: OpenSearchDocumentSto
        retrieved_ids = sorted([doc.id for doc in res])
        assert retrieved_ids == ["1", "2", "3", "4", "5"]

-    def test_bm25_retrieval_with_legacy_filters(
-        self, document_store: OpenSearchDocumentStore
-    ):
+    def test_bm25_retrieval_with_legacy_filters(self, document_store: OpenSearchDocumentStore):
        document_store.write_documents(
            [
                Document(
@@ -724,9 +643,7 @@ def test_bm25_retrieval_with_legacy_filters(
        retrieved_ids = sorted([doc.id for doc in res])
        assert retrieved_ids == ["1", "2", "3", "4", "5"]

-    def test_bm25_retrieval_with_custom_query(
-        self, document_store: OpenSearchDocumentStore
-    ):
+    def test_bm25_retrieval_with_custom_query(self, document_store: OpenSearchDocumentStore):
        document_store.write_documents(
            [
                Document(
@@ -790,18 +707,8 @@ def test_bm25_retrieval_with_custom_query(
        custom_query = {
            "query": {
                "function_score": {
-                    "query": {
-                        "bool": {
-                            "must": {"match": {"content": "$query"}},
-                            "filter": "$filters",
-                        }
-                    },
-                    "field_value_factor": {
-                        "field": "likes",
-                        "factor": 0.1,
-                        "modifier": "log1p",
-                        "missing": 0,
-                    },
+                    "query": {"bool": {"must": {"match": {"content": "$query"}}, "filter": "$filters"}},
+                    "field_value_factor": {"field": "likes", "factor": 0.1, "modifier": "log1p", "missing": 0},
                }
            }
        }
@@ -817,15 +724,11 @@ def test_bm25_retrieval_with_custom_query(
        assert "2" == res[1].id
        assert "3" == res[2].id

-    def test_embedding_retrieval(
-        self, document_store_embedding_dim_4: OpenSearchDocumentStore
-    ):
+    def test_embedding_retrieval(self, document_store_embedding_dim_4: OpenSearchDocumentStore):
        docs = [
            Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]),
            Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]),
-            Document(
-                content="Not very similar document", embedding=[0.0, 0.8, 0.3, 0.9]
-            ),
+            Document(content="Not very similar document", embedding=[0.0, 0.8, 0.3, 0.9]),
        ]
        document_store_embedding_dim_4.write_documents(docs)
        results = document_store_embedding_dim_4._embedding_retrieval(
@@ -835,9 +738,7 @@ def test_embedding_retrieval(
        assert results[0].content == "Most similar document"
        assert results[1].content == "2nd best document"

-    def test_embedding_retrieval_with_filters(
-        self, document_store_embedding_dim_4: OpenSearchDocumentStore
-    ):
+    def test_embedding_retrieval_with_filters(self, document_store_embedding_dim_4: OpenSearchDocumentStore):
        docs = [
            Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]),
            Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]),
@@ -858,9 +759,7 @@ def test_embedding_retrieval_with_filters(
        assert len(results) == 1
        assert results[0].content == "Not very similar document with meta field"

-    def test_embedding_retrieval_with_legacy_filters(
-        self, document_store_embedding_dim_4: OpenSearchDocumentStore
-    ):
+    def test_embedding_retrieval_with_legacy_filters(self, document_store_embedding_dim_4: OpenSearchDocumentStore):
        docs = [
            Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]),
            Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]),
@@ -881,17 +780,13 @@ def test_embedding_retrieval_with_legacy_filters(
        assert len(results) == 1
        assert results[0].content == "Not very similar document with meta field"

-    def test_embedding_retrieval_pagination(
-        self, document_store_embedding_dim_4: OpenSearchDocumentStore
-    ):
+    def test_embedding_retrieval_pagination(self, document_store_embedding_dim_4: OpenSearchDocumentStore):
        """
        Test that handling of pagination works as expected when the matching documents are > 10.
        """
        docs = [
-            Document(
-                content=f"Document {i}", embedding=[random.random() for _ in range(4)]
-            )  # noqa: S311
+            Document(content=f"Document {i}", embedding=[random.random() for _ in range(4)])  # noqa: S311
            for i in range(20)
        ]
@@ -901,9 +796,7 @@ def test_embedding_retrieval_pagination(
        )
        assert len(results) == 11

-    def test_embedding_retrieval_with_custom_query(
-        self, document_store_embedding_dim_4: OpenSearchDocumentStore
-    ):
+    def test_embedding_retrieval_with_custom_query(self, document_store_embedding_dim_4: OpenSearchDocumentStore):
        docs = [
            Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]),
            Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]),
@@ -917,21 +810,13 @@ def test_embedding_retrieval_with_custom_query(
        custom_query = {
            "query": {
-                "bool": {
-                    "must": [
-                        {"knn": {"embedding": {"vector": "$query_embedding", "k": 3}}}
-                    ],
-                    "filter": "$filters",
-                }
+                "bool": {"must": [{"knn": {"embedding": {"vector": "$query_embedding", "k": 3}}}], "filter": "$filters"}
            }
        }
        filters = {"field": "meta_field", "operator": "==", "value": "custom_value"}
        results = document_store_embedding_dim_4._embedding_retrieval(
-            query_embedding=[0.1, 0.1, 0.1, 0.1],
-            top_k=1,
-            filters=filters,
-            custom_query=custom_query,
+            query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=1, filters=filters, custom_query=custom_query
        )
        assert len(results) == 1
        assert results[0].content == "Not very similar document with meta field"
@@ -946,9 +831,7 @@ def test_embedding_retrieval_query_documents_different_embedding_sizes(
        document_store_embedding_dim_4.write_documents(docs)

        with pytest.raises(RequestError):
-            document_store_embedding_dim_4._embedding_retrieval(
-                query_embedding=[0.1, 0.1]
-            )
+            document_store_embedding_dim_4._embedding_retrieval(query_embedding=[0.1, 0.1])

    def test_write_documents_different_embedding_sizes_fail(
        self, document_store_embedding_dim_4: OpenSearchDocumentStore
@@ -965,9 +848,7 @@
            document_store_embedding_dim_4.write_documents(docs)

    @patch("haystack_integrations.document_stores.opensearch.document_store.bulk")
-    def test_write_documents_with_badly_formatted_bulk_errors(
-        self, mock_bulk, document_store
-    ):
+    def test_write_documents_with_badly_formatted_bulk_errors(self, mock_bulk, document_store):
        error = {"some_key": "some_value"}
        mock_bulk.return_value = ([], [error])
@@ -1010,9 +891,7 @@ def test_embedding_retrieval_but_dont_return_embeddings_for_embedding_retrieval(
        docs = [
            Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]),
            Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]),
-            Document(
-                content="Not very similar document", embedding=[0.0, 0.8, 0.3, 0.9]
-            ),
+            Document(content="Not very similar document", embedding=[0.0, 0.8, 0.3, 0.9]),
        ]
        document_store_no_embbding_returned.write_documents(docs)
        results = document_store_no_embbding_returned._embedding_retrieval(
@@ -1027,13 +906,9 @@ def test_embedding_retrieval_but_dont_return_embeddings_for_bm25_retrieval(
        docs = [
            Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]),
            Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]),
-            Document(
-                content="Not very similar document", embedding=[0.0, 0.8, 0.3, 0.9]
-            ),
+            Document(content="Not very similar document", embedding=[0.0, 0.8, 0.3, 0.9]),
        ]
        document_store_no_embbding_returned.write_documents(docs)
-        results = document_store_no_embbding_returned._bm25_retrieval(
-            "document", top_k=2
-        )
+        results = document_store_no_embbding_returned._bm25_retrieval("document", top_k=2)
        assert len(results) == 2
        assert results[0].embedding is None
diff --git a/integrations/opensearch/tests/test_embedding_retriever.py b/integrations/opensearch/tests/test_embedding_retriever.py
index 91449a0e3..e52a099c8 100644
--- a/integrations/opensearch/tests/test_embedding_retriever.py
+++ b/integrations/opensearch/tests/test_embedding_retriever.py
@@ -6,13 +6,9 @@
 import pytest
 from haystack.dataclasses import Document
 from haystack.document_stores.types import FilterPolicy
-from haystack_integrations.components.retrievers.opensearch import (
-    OpenSearchEmbeddingRetriever,
-)
+from haystack_integrations.components.retrievers.opensearch import OpenSearchEmbeddingRetriever
 from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore
-from haystack_integrations.document_stores.opensearch.document_store import (
-    DEFAULT_MAX_CHUNK_BYTES,
-)
+from haystack_integrations.document_stores.opensearch.document_store import DEFAULT_MAX_CHUNK_BYTES


 def test_init_default():
@@ -23,9 +19,7 @@ def test_init_default():
    assert retriever._top_k == 10
    assert retriever._filter_policy == FilterPolicy.REPLACE

-    retriever = OpenSearchEmbeddingRetriever(
-        document_store=mock_store, filter_policy="replace"
-    )
+    retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filter_policy="replace")
    assert retriever._filter_policy == FilterPolicy.REPLACE

    with pytest.raises(ValueError):
@@ -35,9 +29,7 @@ def test_init_default():
 @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch")
 def test_to_dict(_mock_opensearch_client):
    document_store = OpenSearchDocumentStore(hosts="some fake host")
-    retriever = OpenSearchEmbeddingRetriever(
-        document_store=document_store, custom_query={"some": "custom query"}
-    )
+    retriever = OpenSearchEmbeddingRetriever(document_store=document_store, custom_query={"some": "custom query"})
    res = retriever.to_dict()
    type_s = "haystack_integrations.components.retrievers.opensearch.embedding_retriever.OpenSearchEmbeddingRetriever"
    assert res == {
@@ -138,9 +130,7 @@ def test_from_dict(_mock_opensearch_client):

 def test_run():
    mock_store = Mock(spec=OpenSearchDocumentStore)
-    mock_store._embedding_retrieval.return_value = [
-        Document(content="Test doc", embedding=[0.1, 0.2])
-    ]
+    mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])]
    retriever = OpenSearchEmbeddingRetriever(document_store=mock_store)
    res = retriever.run(query_embedding=[0.5, 0.7])
    mock_store._embedding_retrieval.assert_called_once_with(
@@ -157,14 +147,9 @@ def test_run():

 def test_run_init_params():
    mock_store = Mock(spec=OpenSearchDocumentStore)
-    mock_store._embedding_retrieval.return_value = [
-        Document(content="Test doc", embedding=[0.1, 0.2])
-    ]
+    mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])]
    retriever = OpenSearchEmbeddingRetriever(
-        document_store=mock_store,
-        filters={"from": "init"},
-        top_k=11,
-        custom_query="custom_query",
+        document_store=mock_store, filters={"from": "init"}, top_k=11, custom_query="custom_query"
    )
    res = retriever.run(query_embedding=[0.5, 0.7])
    mock_store._embedding_retrieval.assert_called_once_with(
@@ -181,12 +166,8 @@ def test_run_init_params():

 def test_run_time_params():
    mock_store = Mock(spec=OpenSearchDocumentStore)
-    mock_store._embedding_retrieval.return_value = [
-        Document(content="Test doc", embedding=[0.1, 0.2])
-    ]
-    retriever = OpenSearchEmbeddingRetriever(
-        document_store=mock_store, filters={"from": "init"}, top_k=11
-    )
+    mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])]
+    retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filters={"from": "init"}, top_k=11)
    res = retriever.run(query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9)
    mock_store._embedding_retrieval.assert_called_once_with(
        query_embedding=[0.5, 0.7],
@@ -203,9 +184,7 @@ def test_run_time_params():
 def test_run_ignore_errors(caplog):
    mock_store = Mock(spec=OpenSearchDocumentStore)
    mock_store._embedding_retrieval.side_effect = Exception("Some error")
-    retriever = OpenSearchEmbeddingRetriever(
-        document_store=mock_store, raise_on_failure=False
-    )
+    retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, raise_on_failure=False)
    res = retriever.run(query_embedding=[0.5, 0.7])
    assert len(res) == 1
    assert res["documents"] == []
diff --git a/integrations/opensearch/tests/test_filters.py b/integrations/opensearch/tests/test_filters.py
index d5607e9d1..d333dc584 100644
--- a/integrations/opensearch/tests/test_filters.py
+++ b/integrations/opensearch/tests/test_filters.py
@@ -3,10 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import pytest
 from haystack.errors import FilterError
-from haystack_integrations.document_stores.opensearch.filters import (
-    _normalize_ranges,
-    normalize_filters,
-)
+from haystack_integrations.document_stores.opensearch.filters import _normalize_ranges, normalize_filters

 filters_data = [
    (
@@ -17,16 +14,8 @@
                {
                    "operator": "OR",
                    "conditions": [
-                        {
-                            "field": "meta.genre",
-                            "operator": "in",
-                            "value": ["economy", "politics"],
-                        },
-                        {
-                            "field": "meta.publisher",
-                            "operator": "==",
-                            "value": "nytimes",
-                        },
+                        {"field": "meta.genre", "operator": "in", "value": ["economy", "politics"]},
+                        {"field": "meta.publisher", "operator": "==", "value": "nytimes"},
                    ],
                },
                {"field": "meta.date", "operator": ">=", "value": "2015-01-01"},
@@ -75,22 +64,8 @@
        {
            "bool": {
                "should": [
-                    {
-                        "bool": {
-                            "must": [
-                                {"term": {"Type": "News Paper"}},
-                                {"range": {"Date": {"lt": "2020-01-01"}}},
-                            ]
-                        }
-                    },
-                    {
-                        "bool": {
-                            "must": [
-                                {"term": {"Type": "Blog Post"}},
-                                {"range": {"Date": {"gte": "2019-01-01"}}},
-                            ]
-                        }
-                    },
+                    {"bool": {"must": [{"term": {"Type": "News Paper"}}, {"range": {"Date": {"lt": "2020-01-01"}}}]}},
+                    {"bool": {"must": [{"term": {"Type": "Blog Post"}}, {"range": {"Date": {"gte": "2019-01-01"}}}]}},
                ]
            }
        },
@@ -106,16 +81,8 @@
                {
                    "operator": "OR",
                    "conditions": [
-                        {
-                            "field": "meta.genre",
-                            "operator": "in",
-                            "value": ["economy", "politics"],
-                        },
-                        {
-                            "field": "meta.publisher",
-                            "operator": "==",
-                            "value": "nytimes",
-                        },
+                        {"field": "meta.genre", "operator": "in", "value": ["economy", "politics"]},
+                        {"field": "meta.publisher", "operator": "==", "value": "nytimes"},
                    ],
                },
            ],
@@ -139,26 +106,8 @@
        },
    ),
    (
-        {
-            "operator": "AND",
-            "conditions": [
-                {"field": "text", "operator": "==", "value": "A Foo Document 1"}
-            ],
-        },
-        {
-            "bool": {
-                "must": [
-                    {
-                        "match": {
-                            "text": {
-                                "query": "A Foo Document 1",
-                                "minimum_should_match": "100%",
-                            }
-                        }
-                    }
-                ]
-            }
-        },
+        {"operator": "AND", "conditions": [{"field": "text", "operator": "==", "value": "A Foo Document 1"}]},
+        {"bool": {"must": [{"match": {"text": {"query": "A Foo Document 1", "minimum_should_match": "100%"}}}]}},
    ),
    (
        {
@@ -177,14 +126,7 @@
        {
            "bool": {
                "should": [
-                    {
-                        "bool": {
-                            "should": [
-                                {"term": {"name": "name_0"}},
-                                {"term": {"name": "name_1"}},
-                            ]
-                        }
-                    },
+                    {"bool": {"should": [{"term": {"name": "name_0"}}, {"term": {"name": "name_1"}}]}},
                    {"range": {"number": {"lt": 1.0}}},
                ]
            }
        },
@@ -199,14 +141,7 @@
                {"field": "meta.name", "operator": "in", "value": ["name_0", "name_1"]},
            ],
        },
-        {
-            "bool": {
-                "must": [
-                    {"terms": {"name": ["name_0", "name_1"]}},
-                    {"range": {"number": {"lte": 2, "gte": 0}}},
-                ]
-            }
-        },
+        {"bool": {"must": [{"terms": {"name": ["name_0", "name_1"]}}, {"range": {"number": {"lte": 2, "gte": 0}}}]}},
    ),
    (
        {
@@ -226,11 +161,7 @@
                {"field": "meta.name", "operator": "==", "value": "name_1"},
            ],
        },
-        {
-            "bool": {
-                "should": [{"term": {"name": "name_0"}}, {"term": {"name": "name_1"}}]
-            }
-        },
+        {"bool": {"should": [{"term": {"name": "name_0"}}, {"term": {"name": "name_1"}}]}},
    ),
    (
        {
@@ -240,20 +171,7 @@
                {"field": "meta.name", "operator": "==", "value": "name_0"},
            ],
        },
-        {
-            "bool": {
-                "must_not": [
-                    {
-                        "bool": {
-                            "must": [
-                                {"term": {"number": 100}},
-                                {"term": {"name": "name_0"}},
-                            ]
-                        }
-                    }
-                ]
-            }
-        },
+        {"bool": {"must_not": [{"bool": {"must": [{"term": {"number": 100}}, {"term": {"name": "name_0"}}]}}]}},
    ),
 ]
@@ -280,27 +198,15 @@ def test_normalize_filters_malformed():
    # Missing comparison field
    with pytest.raises(FilterError):
-        normalize_filters(
-            {"operator": "AND", "conditions": [{"operator": "==", "value": "article"}]}
-        )
+        normalize_filters({"operator": "AND", "conditions": [{"operator": "==", "value": "article"}]})

    # Missing comparison operator
    with pytest.raises(FilterError):
-        normalize_filters(
-            {
-                "operator": "AND",
-                "conditions": [{"field": "meta.type", "operator": "=="}],
-            }
-        )
+        normalize_filters({"operator": "AND", "conditions": [{"field": "meta.type", "operator": "=="}]})

    # Missing comparison value
    with pytest.raises(FilterError):
-        normalize_filters(
-            {
-                "operator": "AND",
-                "conditions": [{"field": "meta.type", "value": "article"}],
-            }
-        )
+        normalize_filters({"operator": "AND", "conditions": [{"field": "meta.type", "value": "article"}]})


 def test_normalize_ranges():
diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py
index ac4855e33..02e56b34c 100644
--- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py
+++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py
@@ -2,17 +2,11 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-from .optimization import (
-    OptimumEmbedderOptimizationConfig,
-    OptimumEmbedderOptimizationMode,
-)
+from .optimization import OptimumEmbedderOptimizationConfig, OptimumEmbedderOptimizationMode
 from .optimum_document_embedder import OptimumDocumentEmbedder
 from .optimum_text_embedder import OptimumTextEmbedder
 from .pooling import OptimumEmbedderPooling
-from .quantization import (
-    OptimumEmbedderQuantizationConfig,
-    OptimumEmbedderQuantizationMode,
-)
+from .quantization import OptimumEmbedderQuantizationConfig, OptimumEmbedderQuantizationMode

 __all__ = [
    "OptimumDocumentEmbedder",
diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/_backend.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/_backend.py
index 2c4f7c10c..5a9e1cf1f 100644
--- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/_backend.py
+++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/_backend.py
@@ -7,12 +7,7 @@
 import numpy as np
 import torch
 from haystack.utils import Secret, deserialize_secrets_inplace
-from haystack.utils.hf import (
-    HFModelType,
-    check_valid_model,
-    deserialize_hf_model_kwargs,
-    serialize_hf_model_kwargs,
-)
+from haystack.utils.hf import HFModelType, check_valid_model, deserialize_hf_model_kwargs, serialize_hf_model_kwargs
 from huggingface_hub import hf_hub_download
 from sentence_transformers.models import Pooling as SentenceTransformerPoolingLayer
 from tqdm import tqdm
@@ -62,12 +57,8 @@ def serialize(self) -> Dict[str, Any]:
        assert isinstance(self.pooling_mode, OptimumEmbedderPooling)
        out["pooling_mode"] = str(self.pooling_mode)
        out["token"] = self.token.to_dict() if self.token else None
-        out["optimizer_settings"] = (
-            self.optimizer_settings.to_dict() if self.optimizer_settings else None
-        )
-        out["quantizer_settings"] = (
-            self.quantizer_settings.to_dict() if self.quantizer_settings else None
-        )
+        out["optimizer_settings"] = self.optimizer_settings.to_dict() if self.optimizer_settings else None
+        out["quantizer_settings"] = self.quantizer_settings.to_dict() if self.quantizer_settings else None

        out["model_kwargs"].pop("use_auth_token", None)
        serialize_hf_model_kwargs(out["model_kwargs"])
@@ -77,13 +68,9 @@ def serialize(self) -> Dict[str, Any]:
    def deserialize_inplace(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        data["pooling_mode"] = OptimumEmbedderPooling.from_str(data["pooling_mode"])
        if data["optimizer_settings"] is not None:
-            data["optimizer_settings"] = OptimumEmbedderOptimizationConfig.from_dict(
-                data["optimizer_settings"]
-            )
+            data["optimizer_settings"] = OptimumEmbedderOptimizationConfig.from_dict(data["optimizer_settings"])
        if data["quantizer_settings"] is not None:
-            data["quantizer_settings"] = OptimumEmbedderQuantizationConfig.from_dict(
-                data["quantizer_settings"]
-            )
+            data["quantizer_settings"] = OptimumEmbedderQuantizationConfig.from_dict(data["quantizer_settings"])

        deserialize_secrets_inplace(data, keys=["token"])
        deserialize_hf_model_kwargs(data["model_kwargs"])
@@ -98,9 +85,7 @@ def __init__(self, params: _EmbedderParams):
        if isinstance(params.pooling_mode, str):
            params.pooling_mode = OptimumEmbedderPooling.from_str(params.pooling_mode)
        elif params.pooling_mode is None:
-            params.pooling_mode = _pooling_from_model_config(
-                params.model, resolved_token
-            )
+            params.pooling_mode = _pooling_from_model_config(params.model, resolved_token)

        if params.pooling_mode is None:
            modes = {e.value: e for e in OptimumEmbedderPooling}
@@ -130,9 +115,7 @@ def __init__(self, params: _EmbedderParams):
    def warm_up(self):
        assert self.params.model_kwargs
        model_kwargs = copy.deepcopy(self.params.model_kwargs)
-        model = ORTModelForFeatureExtraction.from_pretrained(
-            **model_kwargs, export=True
-        )
+        model = ORTModelForFeatureExtraction.from_pretrained(**model_kwargs, export=True)

        # Model ID will be passed explicitly if optimization/quantization is enabled.
model_kwargs.pop("model_id", None) @@ -142,12 +125,9 @@ def warm_up(self): assert self.params.working_dir optimizer = ORTOptimizer.from_pretrained(model) save_dir = optimizer.optimize( - save_dir=self.params.working_dir, - optimization_config=self.params.optimizer_settings.to_optimum_config(), - ) - model = ORTModelForFeatureExtraction.from_pretrained( - model_id=save_dir, **model_kwargs + save_dir=self.params.working_dir, optimization_config=self.params.optimizer_settings.to_optimum_config() ) + model = ORTModelForFeatureExtraction.from_pretrained(model_id=save_dir, **model_kwargs) optimized_model = True if self.params.quantizer_settings: @@ -157,23 +137,17 @@ def warm_up(self): # since Optimum expects no more than one ONXX model in the working directory. There's # a file name parameter, but the optimizer only returns the working directory. working_dir = ( - Path(self.params.working_dir) - if not optimized_model - else Path(self.params.working_dir) / "quantized" + Path(self.params.working_dir) if not optimized_model else Path(self.params.working_dir) / "quantized" ) quantizer = ORTQuantizer.from_pretrained(model) save_dir = quantizer.quantize( - save_dir=working_dir, - quantization_config=self.params.quantizer_settings.to_optimum_config(), - ) - model = ORTModelForFeatureExtraction.from_pretrained( - model_id=save_dir, **model_kwargs + save_dir=working_dir, quantization_config=self.params.quantizer_settings.to_optimum_config() ) + model = ORTModelForFeatureExtraction.from_pretrained(model_id=save_dir, **model_kwargs) self.model = model self.tokenizer = AutoTokenizer.from_pretrained( - self.params.model, - token=self.params.token.resolve_value() if self.params.token else None, + self.params.model, token=self.params.token.resolve_value() if self.params.token else None ) # We need the width of the embeddings to initialize the pooling layer @@ -184,32 +158,22 @@ def warm_up(self): self.pooling_layer = SentenceTransformerPoolingLayer( width, - pooling_mode_cls_token=self.params.pooling_mode - == OptimumEmbedderPooling.CLS, - pooling_mode_max_tokens=self.params.pooling_mode - == OptimumEmbedderPooling.MAX, - pooling_mode_mean_tokens=self.params.pooling_mode - == OptimumEmbedderPooling.MEAN, - pooling_mode_mean_sqrt_len_tokens=self.params.pooling_mode - == OptimumEmbedderPooling.MEAN_SQRT_LEN, - pooling_mode_weightedmean_tokens=self.params.pooling_mode - == OptimumEmbedderPooling.WEIGHTED_MEAN, - pooling_mode_lasttoken=self.params.pooling_mode - == OptimumEmbedderPooling.LAST_TOKEN, + pooling_mode_cls_token=self.params.pooling_mode == OptimumEmbedderPooling.CLS, + pooling_mode_max_tokens=self.params.pooling_mode == OptimumEmbedderPooling.MAX, + pooling_mode_mean_tokens=self.params.pooling_mode == OptimumEmbedderPooling.MEAN, + pooling_mode_mean_sqrt_len_tokens=self.params.pooling_mode == OptimumEmbedderPooling.MEAN_SQRT_LEN, + pooling_mode_weightedmean_tokens=self.params.pooling_mode == OptimumEmbedderPooling.WEIGHTED_MEAN, + pooling_mode_lasttoken=self.params.pooling_mode == OptimumEmbedderPooling.LAST_TOKEN, ) - def _tokenize_and_generate_outputs( - self, texts: List[str] - ) -> Tuple[Dict[str, Any], BaseModelOutput]: + def _tokenize_and_generate_outputs(self, texts: List[str]) -> Tuple[Dict[str, Any], BaseModelOutput]: assert self.model is not None assert self.tokenizer is not None - tokenizer_outputs = self.tokenizer( - texts, padding=True, truncation=True, return_tensors="pt" - ).to(self.model.device) - model_inputs = { - k: v for k, v in tokenizer_outputs.items() if k in 
self.model.input_names - } + tokenizer_outputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to( + self.model.device + ) + model_inputs = {k: v for k, v in tokenizer_outputs.items() if k in self.model.input_names} model_outputs = self.model(**model_inputs) return tokenizer_outputs, model_outputs @@ -217,9 +181,7 @@ def _tokenize_and_generate_outputs( def parameters(self) -> _EmbedderParams: return self.params - def pool_embeddings( - self, model_output: torch.Tensor, attention_mask: torch.Tensor - ) -> torch.Tensor: + def pool_embeddings(self, model_output: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: assert self.pooling_layer is not None features = {"token_embeddings": model_output, "attention_mask": attention_mask} pooled_outputs = self.pooling_layer.forward(features) @@ -251,9 +213,7 @@ def embed_texts( ): batch = sentences_sorted[i : i + self.params.batch_size] tokenizer_output, model_output = self._tokenize_and_generate_outputs(batch) - sentence_embeddings = self.pool_embeddings( - model_output[0], tokenizer_output["attention_mask"].to(device) - ) + sentence_embeddings = self.pool_embeddings(model_output[0], tokenizer_output["attention_mask"].to(device)) all_embeddings.append(sentence_embeddings) embeddings = torch.cat(all_embeddings, dim=0) @@ -274,13 +234,9 @@ def embed_texts( return reordered_embeddings -def _pooling_from_model_config( - model: str, token: Optional[str] = None -) -> Optional[OptimumEmbedderPooling]: +def _pooling_from_model_config(model: str, token: Optional[str] = None) -> Optional[OptimumEmbedderPooling]: try: - pooling_config_path = hf_hub_download( - repo_id=model, token=token, filename="1_Pooling/config.json" - ) + pooling_config_path = hf_hub_download(repo_id=model, token=token, filename="1_Pooling/config.json") except Exception as e: msg = f"An error occurred while downloading the model config: {e}" raise ValueError(msg) from e @@ -289,11 +245,7 @@ def _pooling_from_model_config( pooling_config = json.load(f) # Filter only those keys that start with "pooling_mode" and are True - true_pooling_modes = [ - key - for key, value in pooling_config.items() - if key.startswith("pooling_mode") and value - ] + true_pooling_modes = [key for key, value in pooling_config.items() if key.startswith("pooling_mode") and value] # If exactly one True pooling mode is found, return it # If no True pooling modes or more than one True pooling mode is found, return None diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py index 4a4477cc1..27f533430 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py @@ -38,9 +38,7 @@ class OptimumDocumentEmbedder: def __init__( self, model: str = "sentence-transformers/all-mpnet-base-v2", - token: Optional[Secret] = Secret.from_env_var( - "HF_API_TOKEN", strict=False - ), # noqa: B008 + token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False), # noqa: B008 prefix: str = "", suffix: str = "", normalize_embeddings: bool = True, @@ -180,16 +178,12 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: texts_to_embed = [] for doc in documents: meta_values_to_embed = [ - str(doc.meta[key]) - for key in 
self.meta_fields_to_embed - if key in doc.meta and doc.meta[key] is not None + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None ] text_to_embed = ( self._backend.parameters.prefix - + self.embedding_separator.join( - [*meta_values_to_embed, doc.content or ""] - ) + + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self._backend.parameters.suffix ) @@ -214,11 +208,7 @@ def run(self, documents: List[Document]): if not self._initialized: msg = "The embedding model has not been loaded. Please call warm_up() before running." raise RuntimeError(msg) - if ( - not isinstance(documents, list) - or documents - and not isinstance(documents[0], Document) - ): + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): msg = ( "OptimumDocumentEmbedder expects a list of Documents as input." " In case you want to embed a string, please use the OptimumTextEmbedder." diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_text_embedder.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_text_embedder.py index e19f56a35..e3cffe183 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_text_embedder.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_text_embedder.py @@ -34,9 +34,7 @@ class OptimumTextEmbedder: def __init__( self, model: str = "sentence-transformers/all-mpnet-base-v2", - token: Optional[Secret] = Secret.from_env_var( - "HF_API_TOKEN", strict=False - ), # noqa: B008 + token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False), # noqa: B008 prefix: str = "", suffix: str = "", normalize_embeddings: bool = True, @@ -180,8 +178,6 @@ def run(self, text: str): ) raise TypeError(msg) - text_to_embed = ( - self._backend.parameters.prefix + text + self._backend.parameters.suffix - ) + text_to_embed = self._backend.parameters.prefix + text + self._backend.parameters.suffix embedding = self._backend.embed_texts(text_to_embed) return {"embedding": embedding} diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/quantization.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/quantization.py index 2bf78070e..d45369544 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/quantization.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/quantization.py @@ -66,21 +66,13 @@ def to_optimum_config(self) -> QuantizationConfig: Optimum configuration. 
""" if self.mode == OptimumEmbedderQuantizationMode.ARM64: - return AutoQuantizationConfig.arm64( - is_static=False, per_channel=self.per_channel - ) + return AutoQuantizationConfig.arm64(is_static=False, per_channel=self.per_channel) elif self.mode == OptimumEmbedderQuantizationMode.AVX2: - return AutoQuantizationConfig.avx2( - is_static=False, per_channel=self.per_channel - ) + return AutoQuantizationConfig.avx2(is_static=False, per_channel=self.per_channel) elif self.mode == OptimumEmbedderQuantizationMode.AVX512: - return AutoQuantizationConfig.avx512( - is_static=False, per_channel=self.per_channel - ) + return AutoQuantizationConfig.avx512(is_static=False, per_channel=self.per_channel) elif self.mode == OptimumEmbedderQuantizationMode.AVX512_VNNI: - return AutoQuantizationConfig.avx512_vnni( - is_static=False, per_channel=self.per_channel - ) + return AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=self.per_channel) else: msg = f"Unknown quantization mode '{self.mode}'" raise ValueError(msg) diff --git a/integrations/optimum/tests/test_optimum_document_embedder.py b/integrations/optimum/tests/test_optimum_document_embedder.py index eeaa3d84a..9288bb688 100644 --- a/integrations/optimum/tests/test_optimum_document_embedder.py +++ b/integrations/optimum/tests/test_optimum_document_embedder.py @@ -6,9 +6,7 @@ from haystack.dataclasses import Document from haystack.utils.auth import Secret from haystack_integrations.components.embedders.optimum import OptimumDocumentEmbedder -from haystack_integrations.components.embedders.optimum.pooling import ( - OptimumEmbedderPooling, -) +from haystack_integrations.components.embedders.optimum.pooling import OptimumEmbedderPooling from haystack_integrations.components.embedders.optimum.optimization import ( OptimumEmbedderOptimizationConfig, OptimumEmbedderOptimizationMode, @@ -39,26 +37,16 @@ def mock_get_pooling_mode(): class TestOptimumDocumentEmbedder: - def test_init_default( - self, monkeypatch, mock_check_valid_model, mock_get_pooling_mode - ): # noqa: ARG002 + def test_init_default(self, monkeypatch, mock_check_valid_model, mock_get_pooling_mode): # noqa: ARG002 monkeypatch.setenv("HF_API_TOKEN", "fake-api-token") embedder = OptimumDocumentEmbedder() - assert ( - embedder._backend.parameters.model - == "sentence-transformers/all-mpnet-base-v2" - ) - assert embedder._backend.parameters.token == Secret.from_env_var( - "HF_API_TOKEN", strict=False - ) + assert embedder._backend.parameters.model == "sentence-transformers/all-mpnet-base-v2" + assert embedder._backend.parameters.token == Secret.from_env_var("HF_API_TOKEN", strict=False) assert embedder._backend.parameters.prefix == "" assert embedder._backend.parameters.suffix == "" assert embedder._backend.parameters.normalize_embeddings is True - assert ( - embedder._backend.parameters.onnx_execution_provider - == "CPUExecutionProvider" - ) + assert embedder._backend.parameters.onnx_execution_provider == "CPUExecutionProvider" assert embedder._backend.parameters.pooling_mode == OptimumEmbedderPooling.MEAN assert embedder._backend.parameters.batch_size == 32 assert embedder._backend.parameters.progress_bar is True @@ -89,10 +77,7 @@ def test_init_with_parameters(self, mock_check_valid_model): # noqa: ARG002 quantizer_settings=None, ) - assert ( - embedder._backend.parameters.model - == "sentence-transformers/all-minilm-l6-v2" - ) + assert embedder._backend.parameters.model == "sentence-transformers/all-minilm-l6-v2" assert embedder._backend.parameters.token == 
Secret.from_token("fake-api-token") assert embedder._backend.parameters.prefix == "prefix" assert embedder._backend.parameters.suffix == "suffix" @@ -101,10 +86,7 @@ def test_init_with_parameters(self, mock_check_valid_model): # noqa: ARG002 assert embedder.meta_fields_to_embed == ["test_field"] assert embedder.embedding_separator == " | " assert embedder._backend.parameters.normalize_embeddings is False - assert ( - embedder._backend.parameters.onnx_execution_provider - == "CUDAExecutionProvider" - ) + assert embedder._backend.parameters.onnx_execution_provider == "CUDAExecutionProvider" assert embedder._backend.parameters.pooling_mode == OptimumEmbedderPooling.MAX assert embedder._backend.parameters.model_kwargs == { "trust_remote_code": True, @@ -116,9 +98,7 @@ def test_init_with_parameters(self, mock_check_valid_model): # noqa: ARG002 assert embedder._backend.parameters.optimizer_settings is None assert embedder._backend.parameters.quantizer_settings is None - def test_to_and_from_dict( - self, mock_check_valid_model, mock_get_pooling_mode - ): # noqa: ARG002 + def test_to_and_from_dict(self, mock_check_valid_model, mock_get_pooling_mode): # noqa: ARG002 component = OptimumDocumentEmbedder() data = component.to_dict() @@ -126,11 +106,7 @@ def test_to_and_from_dict( "type": "haystack_integrations.components.embedders.optimum.optimum_document_embedder.OptimumDocumentEmbedder", "init_parameters": { "model": "sentence-transformers/all-mpnet-base-v2", - "token": { - "env_vars": ["HF_API_TOKEN"], - "strict": False, - "type": "env_var", - }, + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, "prefix": "", "suffix": "", "batch_size": 32, @@ -151,20 +127,12 @@ def test_to_and_from_dict( } embedder = OptimumDocumentEmbedder.from_dict(data) - assert ( - embedder._backend.parameters.model - == "sentence-transformers/all-mpnet-base-v2" - ) - assert embedder._backend.parameters.token == Secret.from_env_var( - "HF_API_TOKEN", strict=False - ) + assert embedder._backend.parameters.model == "sentence-transformers/all-mpnet-base-v2" + assert embedder._backend.parameters.token == Secret.from_env_var("HF_API_TOKEN", strict=False) assert embedder._backend.parameters.prefix == "" assert embedder._backend.parameters.suffix == "" assert embedder._backend.parameters.normalize_embeddings is True - assert ( - embedder._backend.parameters.onnx_execution_provider - == "CPUExecutionProvider" - ) + assert embedder._backend.parameters.onnx_execution_provider == "CPUExecutionProvider" assert embedder._backend.parameters.pooling_mode == OptimumEmbedderPooling.MEAN assert embedder._backend.parameters.batch_size == 32 assert embedder._backend.parameters.progress_bar is True @@ -196,9 +164,7 @@ def test_to_and_from_dict_with_custom_init_parameters( pooling_mode="max", model_kwargs={"trust_remote_code": True}, working_dir="working_dir", - optimizer_settings=OptimumEmbedderOptimizationConfig( - OptimumEmbedderOptimizationMode.O1, for_gpu=True - ), + optimizer_settings=OptimumEmbedderOptimizationConfig(OptimumEmbedderOptimizationMode.O1, for_gpu=True), quantizer_settings=OptimumEmbedderQuantizationConfig( OptimumEmbedderQuantizationMode.ARM64, per_channel=True ), @@ -231,13 +197,8 @@ def test_to_and_from_dict_with_custom_init_parameters( } embedder = OptimumDocumentEmbedder.from_dict(data) - assert ( - embedder._backend.parameters.model - == "sentence-transformers/all-minilm-l6-v2" - ) - assert embedder._backend.parameters.token == Secret.from_env_var( - "ENV_VAR", strict=False - ) + assert 
embedder._backend.parameters.model == "sentence-transformers/all-minilm-l6-v2" + assert embedder._backend.parameters.token == Secret.from_env_var("ENV_VAR", strict=False) assert embedder._backend.parameters.prefix == "prefix" assert embedder._backend.parameters.suffix == "suffix" assert embedder._backend.parameters.batch_size == 64 @@ -245,10 +206,7 @@ def test_to_and_from_dict_with_custom_init_parameters( assert embedder.meta_fields_to_embed == ["test_field"] assert embedder.embedding_separator == " | " assert embedder._backend.parameters.normalize_embeddings is False - assert ( - embedder._backend.parameters.onnx_execution_provider - == "CUDAExecutionProvider" - ) + assert embedder._backend.parameters.onnx_execution_provider == "CUDAExecutionProvider" assert embedder._backend.parameters.pooling_mode == OptimumEmbedderPooling.MAX assert embedder._backend.parameters.model_kwargs == { "trust_remote_code": True, @@ -257,17 +215,11 @@ def test_to_and_from_dict_with_custom_init_parameters( "use_auth_token": None, } assert embedder._backend.parameters.working_dir == "working_dir" - assert ( - embedder._backend.parameters.optimizer_settings - == OptimumEmbedderOptimizationConfig( - OptimumEmbedderOptimizationMode.O1, for_gpu=True - ) + assert embedder._backend.parameters.optimizer_settings == OptimumEmbedderOptimizationConfig( + OptimumEmbedderOptimizationMode.O1, for_gpu=True ) - assert ( - embedder._backend.parameters.quantizer_settings - == OptimumEmbedderQuantizationConfig( - OptimumEmbedderQuantizationMode.ARM64, per_channel=True - ) + assert embedder._backend.parameters.quantizer_settings == OptimumEmbedderQuantizationConfig( + OptimumEmbedderQuantizationMode.ARM64, per_channel=True ) def test_initialize_with_invalid_model(self, mock_check_valid_model): @@ -275,14 +227,11 @@ def test_initialize_with_invalid_model(self, mock_check_valid_model): with pytest.raises(RepositoryNotFoundError): OptimumDocumentEmbedder(model="invalid_model_id") - def test_initialize_with_invalid_pooling_mode( - self, mock_check_valid_model - ): # noqa: ARG002 + def test_initialize_with_invalid_pooling_mode(self, mock_check_valid_model): # noqa: ARG002 mock_get_pooling_mode.side_effect = ValueError("Invalid pooling mode") with pytest.raises(ValueError): OptimumDocumentEmbedder( - model="sentence-transformers/all-mpnet-base-v2", - pooling_mode="Invalid_pooling_mode", + model="sentence-transformers/all-mpnet-base-v2", pooling_mode="Invalid_pooling_mode" ) def test_infer_pooling_mode_from_str(self): @@ -296,16 +245,11 @@ def test_infer_pooling_mode_from_str(self): pooling_mode=pooling_mode.value, ) - assert ( - embedder._backend.parameters.model - == "sentence-transformers/all-minilm-l6-v2" - ) + assert embedder._backend.parameters.model == "sentence-transformers/all-minilm-l6-v2" assert embedder._backend.parameters.pooling_mode == pooling_mode @pytest.mark.integration - def test_default_pooling_mode_when_config_not_found( - self, mock_check_valid_model - ): # noqa: ARG002 + def test_default_pooling_mode_when_config_not_found(self, mock_check_valid_model): # noqa: ARG002 with pytest.raises(ValueError): OptimumDocumentEmbedder( model="embedding_model_finetuned", @@ -319,21 +263,12 @@ def test_infer_pooling_mode_from_hf(self): pooling_mode=None, ) - assert ( - embedder._backend.parameters.model - == "sentence-transformers/all-minilm-l6-v2" - ) + assert embedder._backend.parameters.model == "sentence-transformers/all-minilm-l6-v2" assert embedder._backend.parameters.pooling_mode == OptimumEmbedderPooling.MEAN - def 
test_prepare_texts_to_embed_w_metadata( - self, mock_check_valid_model - ): # noqa: ARG002 + def test_prepare_texts_to_embed_w_metadata(self, mock_check_valid_model): # noqa: ARG002 documents = [ - Document( - content=f"document number {i}: content", - meta={"meta_field": f"meta_value {i}"}, - ) - for i in range(5) + Document(content=f"document number {i}: content", meta={"meta_field": f"meta_value {i}"}) for i in range(5) ] embedder = OptimumDocumentEmbedder( @@ -353,9 +288,7 @@ def test_prepare_texts_to_embed_w_metadata( "meta_value 4 | document number 4: content", ] - def test_prepare_texts_to_embed_w_suffix( - self, mock_check_valid_model - ): # noqa: ARG002 + def test_prepare_texts_to_embed_w_suffix(self, mock_check_valid_model): # noqa: ARG002 documents = [Document(content=f"document number {i}") for i in range(5)] embedder = OptimumDocumentEmbedder( @@ -376,24 +309,16 @@ def test_prepare_texts_to_embed_w_suffix( ] def test_run_wrong_input_format(self, mock_check_valid_model): # noqa: ARG002 - embedder = OptimumDocumentEmbedder( - model="sentence-transformers/all-mpnet-base-v2", pooling_mode="mean" - ) + embedder = OptimumDocumentEmbedder(model="sentence-transformers/all-mpnet-base-v2", pooling_mode="mean") embedder.warm_up() # wrong formats string_input = "text" list_integers_input = [1, 2, 3] - with pytest.raises( - TypeError, - match="OptimumDocumentEmbedder expects a list of Documents as input", - ): + with pytest.raises(TypeError, match="OptimumDocumentEmbedder expects a list of Documents as input"): embedder.run(documents=string_input) - with pytest.raises( - TypeError, - match="OptimumDocumentEmbedder expects a list of Documents as input", - ): + with pytest.raises(TypeError, match="OptimumDocumentEmbedder expects a list of Documents as input"): embedder.run(documents=list_integers_input) def test_run_on_empty_list(self, mock_check_valid_model): # noqa: ARG002 @@ -413,15 +338,10 @@ def test_run_on_empty_list(self, mock_check_valid_model): # noqa: ARG002 [ (None, None), ( - OptimumEmbedderOptimizationConfig( - OptimumEmbedderOptimizationMode.O1, for_gpu=False - ), - None, - ), - ( + OptimumEmbedderOptimizationConfig(OptimumEmbedderOptimizationMode.O1, for_gpu=False), None, - OptimumEmbedderQuantizationConfig(OptimumEmbedderQuantizationMode.AVX2), ), + (None, OptimumEmbedderQuantizationConfig(OptimumEmbedderQuantizationMode.AVX2)), # onxxruntime 1.17.x breaks support for quantizing optimized models. 
             # c.f https://discuss.huggingface.co/t/optimize-and-quantize-with-optimum/23675/12
             # (
@@ -433,13 +353,8 @@ def test_run_on_empty_list(self, mock_check_valid_model):  # noqa: ARG002
     def test_run(self, opt_config, quant_config):
         docs = [
             Document(content="I love cheese", meta={"topic": "Cuisine"}),
-            Document(
-                content="A transformer is a deep learning architecture",
-                meta={"topic": "ML"},
-            ),
-            Document(
-                content="Every planet we reach is dead", meta={"topic": "Monkeys"}
-            ),
+            Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
+            Document(content="Every planet we reach is dead", meta={"topic": "Monkeys"}),
         ]

         docs_copy = copy.deepcopy(docs)
diff --git a/integrations/optimum/tests/test_optimum_text_embedder.py b/integrations/optimum/tests/test_optimum_text_embedder.py
index 7b48d5010..ad0e7d800 100644
--- a/integrations/optimum/tests/test_optimum_text_embedder.py
+++ b/integrations/optimum/tests/test_optimum_text_embedder.py
@@ -3,9 +3,7 @@
 import pytest
 from haystack.utils.auth import Secret
 from haystack_integrations.components.embedders.optimum import OptimumTextEmbedder
-from haystack_integrations.components.embedders.optimum.pooling import (
-    OptimumEmbedderPooling,
-)
+from haystack_integrations.components.embedders.optimum.pooling import OptimumEmbedderPooling
 from haystack_integrations.components.embedders.optimum.optimization import (
     OptimumEmbedderOptimizationConfig,
     OptimumEmbedderOptimizationMode,
@@ -36,26 +34,16 @@ def mock_get_pooling_mode():


 class TestOptimumTextEmbedder:
-    def test_init_default(
-        self, monkeypatch, mock_check_valid_model, mock_get_pooling_mode
-    ):  # noqa: ARG002
+    def test_init_default(self, monkeypatch, mock_check_valid_model, mock_get_pooling_mode):  # noqa: ARG002
         monkeypatch.setenv("HF_API_TOKEN", "fake-api-token")
         embedder = OptimumTextEmbedder()

-        assert (
-            embedder._backend.parameters.model
-            == "sentence-transformers/all-mpnet-base-v2"
-        )
-        assert embedder._backend.parameters.token == Secret.from_env_var(
-            "HF_API_TOKEN", strict=False
-        )
+        assert embedder._backend.parameters.model == "sentence-transformers/all-mpnet-base-v2"
+        assert embedder._backend.parameters.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
         assert embedder._backend.parameters.prefix == ""
         assert embedder._backend.parameters.suffix == ""
         assert embedder._backend.parameters.normalize_embeddings is True
-        assert (
-            embedder._backend.parameters.onnx_execution_provider
-            == "CPUExecutionProvider"
-        )
+        assert embedder._backend.parameters.onnx_execution_provider == "CPUExecutionProvider"
         assert embedder._backend.parameters.pooling_mode == OptimumEmbedderPooling.MEAN
         assert embedder._backend.parameters.model_kwargs == {
             "model_id": "sentence-transformers/all-mpnet-base-v2",
@@ -78,18 +66,12 @@ def test_init_with_parameters(self, mock_check_valid_model):  # noqa: ARG002
             quantizer_settings=None,
         )

-        assert (
-            embedder._backend.parameters.model
-            == "sentence-transformers/all-minilm-l6-v2"
-        )
+        assert embedder._backend.parameters.model == "sentence-transformers/all-minilm-l6-v2"
         assert embedder._backend.parameters.token == Secret.from_token("fake-api-token")
         assert embedder._backend.parameters.prefix == "prefix"
         assert embedder._backend.parameters.suffix == "suffix"
         assert embedder._backend.parameters.normalize_embeddings is False
-        assert (
-            embedder._backend.parameters.onnx_execution_provider
-            == "CUDAExecutionProvider"
-        )
+        assert embedder._backend.parameters.onnx_execution_provider == "CUDAExecutionProvider"
         assert embedder._backend.parameters.pooling_mode == OptimumEmbedderPooling.MAX
         assert embedder._backend.parameters.model_kwargs == {
             "trust_remote_code": True,
@@ -101,9 +83,7 @@ def test_init_with_parameters(self, mock_check_valid_model):  # noqa: ARG002
         assert embedder._backend.parameters.optimizer_settings is None
         assert embedder._backend.parameters.quantizer_settings is None

-    def test_to_and_from_dict(
-        self, mock_check_valid_model, mock_get_pooling_mode
-    ):  # noqa: ARG002
+    def test_to_and_from_dict(self, mock_check_valid_model, mock_get_pooling_mode):  # noqa: ARG002
         component = OptimumTextEmbedder()
         data = component.to_dict()
@@ -111,11 +91,7 @@ def test_to_and_from_dict(
             "type": "haystack_integrations.components.embedders.optimum.optimum_text_embedder.OptimumTextEmbedder",
             "init_parameters": {
                 "model": "sentence-transformers/all-mpnet-base-v2",
-                "token": {
-                    "env_vars": ["HF_API_TOKEN"],
-                    "strict": False,
-                    "type": "env_var",
-                },
+                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                 "prefix": "",
                 "suffix": "",
                 "normalize_embeddings": True,
@@ -132,20 +108,12 @@ def test_to_and_from_dict(
         }

         embedder = OptimumTextEmbedder.from_dict(data)
-        assert (
-            embedder._backend.parameters.model
-            == "sentence-transformers/all-mpnet-base-v2"
-        )
-        assert embedder._backend.parameters.token == Secret.from_env_var(
-            "HF_API_TOKEN", strict=False
-        )
+        assert embedder._backend.parameters.model == "sentence-transformers/all-mpnet-base-v2"
+        assert embedder._backend.parameters.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
         assert embedder._backend.parameters.prefix == ""
         assert embedder._backend.parameters.suffix == ""
         assert embedder._backend.parameters.normalize_embeddings is True
-        assert (
-            embedder._backend.parameters.onnx_execution_provider
-            == "CPUExecutionProvider"
-        )
+        assert embedder._backend.parameters.onnx_execution_provider == "CPUExecutionProvider"
         assert embedder._backend.parameters.pooling_mode == OptimumEmbedderPooling.MEAN
         assert embedder._backend.parameters.model_kwargs == {
             "model_id": "sentence-transformers/all-mpnet-base-v2",
@@ -156,9 +124,7 @@ def test_to_and_from_dict(
         assert embedder._backend.parameters.optimizer_settings is None
         assert embedder._backend.parameters.quantizer_settings is None

-    def test_to_and_from_dict_with_custom_init_parameters(
-        self, mock_check_valid_model
-    ):  # noqa: ARG002
+    def test_to_and_from_dict_with_custom_init_parameters(self, mock_check_valid_model):  # noqa: ARG002
         component = OptimumTextEmbedder(
             model="sentence-transformers/all-minilm-l6-v2",
             token=Secret.from_env_var("ENV_VAR", strict=False),
@@ -169,9 +135,7 @@ def test_to_and_from_dict_with_custom_init_parameters(
             pooling_mode="max",
             model_kwargs={"trust_remote_code": True},
             working_dir="working_dir",
-            optimizer_settings=OptimumEmbedderOptimizationConfig(
-                OptimumEmbedderOptimizationMode.O1, for_gpu=True
-            ),
+            optimizer_settings=OptimumEmbedderOptimizationConfig(OptimumEmbedderOptimizationMode.O1, for_gpu=True),
             quantizer_settings=OptimumEmbedderQuantizationConfig(
                 OptimumEmbedderQuantizationMode.ARM64, per_channel=True
             ),
@@ -200,20 +164,12 @@ def test_to_and_from_dict_with_custom_init_parameters(
         }

         embedder = OptimumTextEmbedder.from_dict(data)
-        assert (
-            embedder._backend.parameters.model
-            == "sentence-transformers/all-minilm-l6-v2"
-        )
-        assert embedder._backend.parameters.token == Secret.from_env_var(
-            "ENV_VAR", strict=False
-        )
+        assert embedder._backend.parameters.model == "sentence-transformers/all-minilm-l6-v2"
+        assert embedder._backend.parameters.token == Secret.from_env_var("ENV_VAR", strict=False)
         assert embedder._backend.parameters.prefix == "prefix"
         assert embedder._backend.parameters.suffix == "suffix"
         assert embedder._backend.parameters.normalize_embeddings is False
-        assert (
-            embedder._backend.parameters.onnx_execution_provider
-            == "CUDAExecutionProvider"
-        )
+        assert embedder._backend.parameters.onnx_execution_provider == "CUDAExecutionProvider"
         assert embedder._backend.parameters.pooling_mode == OptimumEmbedderPooling.MAX
         assert embedder._backend.parameters.model_kwargs == {
             "trust_remote_code": True,
@@ -222,17 +178,11 @@ def test_to_and_from_dict_with_custom_init_parameters(
             "use_auth_token": None,
         }
         assert embedder._backend.parameters.working_dir == "working_dir"
-        assert (
-            embedder._backend.parameters.optimizer_settings
-            == OptimumEmbedderOptimizationConfig(
-                OptimumEmbedderOptimizationMode.O1, for_gpu=True
-            )
+        assert embedder._backend.parameters.optimizer_settings == OptimumEmbedderOptimizationConfig(
+            OptimumEmbedderOptimizationMode.O1, for_gpu=True
         )
-        assert (
-            embedder._backend.parameters.quantizer_settings
-            == OptimumEmbedderQuantizationConfig(
-                OptimumEmbedderQuantizationMode.ARM64, per_channel=True
-            )
+        assert embedder._backend.parameters.quantizer_settings == OptimumEmbedderQuantizationConfig(
+            OptimumEmbedderQuantizationMode.ARM64, per_channel=True
         )

     def test_initialize_with_invalid_model(self, mock_check_valid_model):
@@ -240,15 +190,10 @@ def test_initialize_with_invalid_model(self, mock_check_valid_model):
         with pytest.raises(RepositoryNotFoundError):
             OptimumTextEmbedder(model="invalid_model_id", pooling_mode="max")

-    def test_initialize_with_invalid_pooling_mode(
-        self, mock_check_valid_model
-    ):  # noqa: ARG002
+    def test_initialize_with_invalid_pooling_mode(self, mock_check_valid_model):  # noqa: ARG002
         mock_get_pooling_mode.side_effect = ValueError("Invalid pooling mode")
         with pytest.raises(ValueError):
-            OptimumTextEmbedder(
-                model="sentence-transformers/all-mpnet-base-v2",
-                pooling_mode="Invalid_pooling_mode",
-            )
+            OptimumTextEmbedder(model="sentence-transformers/all-mpnet-base-v2", pooling_mode="Invalid_pooling_mode")

     def test_infer_pooling_mode_from_str(self):
         """
@@ -261,16 +206,11 @@ def test_infer_pooling_mode_from_str(self):
             pooling_mode=pooling_mode.value,
         )

-        assert (
-            embedder._backend.parameters.model
-            == "sentence-transformers/all-minilm-l6-v2"
-        )
+        assert embedder._backend.parameters.model == "sentence-transformers/all-minilm-l6-v2"
         assert embedder._backend.parameters.pooling_mode == pooling_mode

     @pytest.mark.integration
-    def test_default_pooling_mode_when_config_not_found(
-        self, mock_check_valid_model
-    ):  # noqa: ARG002
+    def test_default_pooling_mode_when_config_not_found(self, mock_check_valid_model):  # noqa: ARG002
         with pytest.raises(ValueError):
             OptimumTextEmbedder(
                 model="embedding_model_finetuned",
@@ -284,10 +224,7 @@ def test_infer_pooling_mode_from_hf(self):
             pooling_mode=None,
         )

-        assert (
-            embedder._backend.parameters.model
-            == "sentence-transformers/all-minilm-l6-v2"
-        )
+        assert embedder._backend.parameters.model == "sentence-transformers/all-minilm-l6-v2"
         assert embedder._backend.parameters.pooling_mode == OptimumEmbedderPooling.MEAN

     def test_run_wrong_input_format(self, mock_check_valid_model):  # noqa: ARG002
@@ -300,9 +237,7 @@ def test_run_wrong_input_format(self, mock_check_valid_model):  # noqa: ARG002

         list_integers_input = [1, 2, 3]

-        with pytest.raises(
-            TypeError, match="OptimumTextEmbedder expects a string as an input"
-        ):
+        with pytest.raises(TypeError, match="OptimumTextEmbedder expects a string as an input"):
             embedder.run(text=list_integers_input)

     @pytest.mark.integration
diff --git a/integrations/pgvector/examples/embedding_retrieval.py b/integrations/pgvector/examples/embedding_retrieval.py
index bd57d5598..37ea88929 100644
--- a/integrations/pgvector/examples/embedding_retrieval.py
+++ b/integrations/pgvector/examples/embedding_retrieval.py
@@ -14,15 +14,10 @@
 from haystack import Pipeline
 from haystack.components.converters import MarkdownToDocument
-from haystack.components.embedders import (
-    SentenceTransformersDocumentEmbedder,
-    SentenceTransformersTextEmbedder,
-)
+from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
 from haystack.components.preprocessors import DocumentSplitter
 from haystack.components.writers import DocumentWriter
-from haystack_integrations.components.retrievers.pgvector import (
-    PgvectorEmbeddingRetriever,
-)
+from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever
 from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore

 # Set an environment variable `PG_CONN_STR` with the connection string to your PostgreSQL database.
@@ -43,9 +38,7 @@
 indexing = Pipeline()
 indexing.add_component("converter", MarkdownToDocument())
-indexing.add_component(
-    "splitter", DocumentSplitter(split_by="sentence", split_length=2)
-)
+indexing.add_component("splitter", DocumentSplitter(split_by="sentence", split_length=2))
 indexing.add_component("embedder", SentenceTransformersDocumentEmbedder())
 indexing.add_component("writer", DocumentWriter(document_store))
 indexing.connect("converter", "splitter")
@@ -57,9 +50,7 @@
 # Create the querying Pipeline and try a query
 querying = Pipeline()
 querying.add_component("embedder", SentenceTransformersTextEmbedder())
-querying.add_component(
-    "retriever", PgvectorEmbeddingRetriever(document_store=document_store, top_k=3)
-)
+querying.add_component("retriever", PgvectorEmbeddingRetriever(document_store=document_store, top_k=3))
 querying.connect("embedder", "retriever")

 results = querying.run({"embedder": {"text": "What is a cross-encoder?"}})
diff --git a/integrations/pgvector/examples/hybrid_retrieval.py b/integrations/pgvector/examples/hybrid_retrieval.py
index e10ccd128..cee98fe08 100644
--- a/integrations/pgvector/examples/hybrid_retrieval.py
+++ b/integrations/pgvector/examples/hybrid_retrieval.py
@@ -14,17 +14,11 @@
 from haystack import Pipeline
 from haystack.components.converters import MarkdownToDocument
-from haystack.components.embedders import (
-    SentenceTransformersDocumentEmbedder,
-    SentenceTransformersTextEmbedder,
-)
+from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
 from haystack.components.joiners import DocumentJoiner
 from haystack.components.preprocessors import DocumentSplitter
 from haystack.components.writers import DocumentWriter
-from haystack_integrations.components.retrievers.pgvector import (
-    PgvectorEmbeddingRetriever,
-    PgvectorKeywordRetriever,
-)
+from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever, PgvectorKeywordRetriever
 from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore

 # Set an environment variable `PG_CONN_STR` with the connection string to your PostgreSQL database.
@@ -45,9 +39,7 @@
 indexing = Pipeline()
 indexing.add_component("converter", MarkdownToDocument())
-indexing.add_component(
-    "splitter", DocumentSplitter(split_by="sentence", split_length=2)
-)
+indexing.add_component("splitter", DocumentSplitter(split_by="sentence", split_length=2))
 indexing.add_component("document_embedder", SentenceTransformersDocumentEmbedder())
 indexing.add_component("writer", DocumentWriter(document_store))
 indexing.connect("converter", "splitter")
@@ -59,13 +51,8 @@
 # Create the querying Pipeline and try a query
 querying = Pipeline()
 querying.add_component("text_embedder", SentenceTransformersTextEmbedder())
-querying.add_component(
-    "retriever", PgvectorEmbeddingRetriever(document_store=document_store, top_k=3)
-)
-querying.add_component(
-    "keyword_retriever",
-    PgvectorKeywordRetriever(document_store=document_store, top_k=3),
-)
+querying.add_component("retriever", PgvectorEmbeddingRetriever(document_store=document_store, top_k=3))
+querying.add_component("keyword_retriever", PgvectorKeywordRetriever(document_store=document_store, top_k=3))
 querying.add_component(
     "joiner",
     DocumentJoiner(join_mode="reciprocal_rank_fusion", top_k=3),
 )
@@ -75,9 +62,7 @@
 querying.connect("retriever", "joiner")

 query = "cross-encoder"
-results = querying.run(
-    {"text_embedder": {"text": query}, "keyword_retriever": {"query": query}}
-)
+results = querying.run({"text_embedder": {"text": query}, "keyword_retriever": {"query": query}})

 for doc in results["joiner"]["documents"]:
     print(doc)
diff --git a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py
index 2cc12f156..22aab1a73 100644
--- a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py
+++ b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py
@@ -8,9 +8,7 @@
 from haystack.document_stores.types import FilterPolicy
 from haystack.document_stores.types.filter_policy import apply_filter_policy
 from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore
-from haystack_integrations.document_stores.pgvector.document_store import (
-    VALID_VECTOR_FUNCTIONS,
-)
+from haystack_integrations.document_stores.pgvector.document_store import VALID_VECTOR_FUNCTIONS


 @component
@@ -65,9 +63,7 @@ def __init__(
         document_store: PgvectorDocumentStore,
         filters: Optional[Dict[str, Any]] = None,
         top_k: int = 10,
-        vector_function: Optional[
-            Literal["cosine_similarity", "inner_product", "l2_distance"]
-        ] = None,
+        vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None,
         filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
     ):
         """
@@ -99,9 +95,7 @@ def __init__(
         self.top_k = top_k
         self.vector_function = vector_function or document_store.vector_function
         self.filter_policy = (
-            filter_policy
-            if isinstance(filter_policy, FilterPolicy)
-            else FilterPolicy.from_str(filter_policy)
+            filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
         )

     def to_dict(self) -> Dict[str, Any]:
@@ -131,15 +125,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "PgvectorEmbeddingRetriever":
             Deserialized component.
""" doc_store_params = data["init_parameters"]["document_store"] - data["init_parameters"]["document_store"] = PgvectorDocumentStore.from_dict( - doc_store_params - ) + data["init_parameters"]["document_store"] = PgvectorDocumentStore.from_dict(doc_store_params) # Pipelines serialized with old versions of the component might not # have the filter_policy field. if filter_policy := data["init_parameters"].get("filter_policy"): - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str( - filter_policy - ) + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) @@ -148,9 +138,7 @@ def run( query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None, - vector_function: Optional[ - Literal["cosine_similarity", "inner_product", "l2_distance"] - ] = None, + vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None, ): """ Retrieve documents from the `PgvectorDocumentStore`, based on their embeddings. diff --git a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/keyword_retriever.py b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/keyword_retriever.py index ec9adf95e..636471c31 100644 --- a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/keyword_retriever.py +++ b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/keyword_retriever.py @@ -70,9 +70,7 @@ def __init__( self.filters = filters or {} self.top_k = top_k self.filter_policy = ( - filter_policy - if isinstance(filter_policy, FilterPolicy) - else FilterPolicy.from_str(filter_policy) + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) ) def to_dict(self) -> Dict[str, Any]: @@ -101,15 +99,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "PgvectorKeywordRetriever": Deserialized component. """ doc_store_params = data["init_parameters"]["document_store"] - data["init_parameters"]["document_store"] = PgvectorDocumentStore.from_dict( - doc_store_params - ) + data["init_parameters"]["document_store"] = PgvectorDocumentStore.from_dict(doc_store_params) # Pipelines serialized with old versions of the component might not # have the filter_policy field. 
         if filter_policy := data["init_parameters"].get("filter_policy"):
-            data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(
-                filter_policy
-            )
+            data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy)
         return default_from_dict(cls, data)

     @component.output_types(documents=List[Document])
diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py
index 66c2b7d9d..ae4878aba 100644
--- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py
+++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py
@@ -82,13 +82,9 @@ def __init__(
         table_name: str = "haystack_documents",
         language: str = "english",
         embedding_dimension: int = 768,
-        vector_function: Literal[
-            "cosine_similarity", "inner_product", "l2_distance"
-        ] = "cosine_similarity",
+        vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] = "cosine_similarity",
         recreate_table: bool = False,
-        search_strategy: Literal[
-            "exact_nearest_neighbor", "hnsw"
-        ] = "exact_nearest_neighbor",
+        search_strategy: Literal["exact_nearest_neighbor", "hnsw"] = "exact_nearest_neighbor",
         hnsw_recreate_index_if_exists: bool = False,
         hnsw_index_creation_kwargs: Optional[Dict[str, int]] = None,
         hnsw_index_name: str = "haystack_hnsw_index",
@@ -181,9 +177,7 @@ def _create_connection(self):
         connection = connect(conn_str)
         connection.autocommit = True
         connection.execute("CREATE EXTENSION IF NOT EXISTS vector")
-        register_vector(
-            connection
-        )  # Note: this must be called before creating the cursors.
+        register_vector(connection)  # Note: this must be called before creating the cursors.

         self._connection = connection
         self._cursor = self._connection.cursor()
@@ -237,11 +231,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "PgvectorDocumentStore":
         return default_from_dict(cls, data)

     def _execute_sql(
-        self,
-        sql_query: Query,
-        params: Optional[tuple] = None,
-        error_msg: str = "",
-        cursor: Optional[Cursor] = None,
+        self, sql_query: Query, params: Optional[tuple] = None, error_msg: str = "", cursor: Optional[Cursor] = None
     ):
         """
         Internal method to execute SQL statements and handle exceptions.
@@ -255,9 +245,7 @@
         params = params or ()
         cursor = cursor or self.cursor

-        sql_query_str = (
-            sql_query.as_string(cursor) if not isinstance(sql_query, str) else sql_query
-        )
+        sql_query_str = sql_query.as_string(cursor) if not isinstance(sql_query, str) else sql_query
         logger.debug("SQL query: %s\nParameters: %s", sql_query_str, params)

         try:
@@ -275,13 +263,10 @@ def _create_table_if_not_exists(self):
         """

         create_sql = SQL(CREATE_TABLE_STATEMENT).format(
-            table_name=Identifier(self.table_name),
-            embedding_dimension=SQLLiteral(self.embedding_dimension),
+            table_name=Identifier(self.table_name), embedding_dimension=SQLLiteral(self.embedding_dimension)
         )

-        self._execute_sql(
-            create_sql, error_msg="Could not create table in PgvectorDocumentStore"
-        )
+        self._execute_sql(create_sql, error_msg="Could not create table in PgvectorDocumentStore")

     def delete_table(self):
         """
@@ -289,14 +274,9 @@
         The name of the table (`table_name`) is defined when initializing the `PgvectorDocumentStore`.
""" - delete_sql = SQL("DROP TABLE IF EXISTS {table_name}").format( - table_name=Identifier(self.table_name) - ) + delete_sql = SQL("DROP TABLE IF EXISTS {table_name}").format(table_name=Identifier(self.table_name)) - self._execute_sql( - delete_sql, - error_msg=f"Could not delete table {self.table_name} in PgvectorDocumentStore", - ) + self._execute_sql(delete_sql, error_msg=f"Could not delete table {self.table_name} in PgvectorDocumentStore") def _create_keyword_index_if_not_exists(self): """ @@ -319,9 +299,7 @@ def _create_keyword_index_if_not_exists(self): ) if not index_exists: - self._execute_sql( - sql_create_index, error_msg="Could not create keyword index on table" - ) + self._execute_sql(sql_create_index, error_msg="Could not create keyword index on table") def _handle_hnsw(self): """ @@ -330,12 +308,10 @@ def _handle_hnsw(self): """ if self.hnsw_ef_search: - sql_set_hnsw_ef_search = SQL( - "SET hnsw.ef_search = {hnsw_ef_search}" - ).format(hnsw_ef_search=SQLLiteral(self.hnsw_ef_search)) - self._execute_sql( - sql_set_hnsw_ef_search, error_msg="Could not set hnsw.ef_search" + sql_set_hnsw_ef_search = SQL("SET hnsw.ef_search = {hnsw_ef_search}").format( + hnsw_ef_search=SQLLiteral(self.hnsw_ef_search) ) + self._execute_sql(sql_set_hnsw_ef_search, error_msg="Could not set hnsw.ef_search") index_exists = bool( self._execute_sql( @@ -353,9 +329,7 @@ def _handle_hnsw(self): ) return - sql_drop_index = SQL("DROP INDEX IF EXISTS {index_name}").format( - index_name=Identifier(self.hnsw_index_name) - ) + sql_drop_index = SQL("DROP INDEX IF EXISTS {index_name}").format(index_name=Identifier(self.hnsw_index_name)) self._execute_sql(sql_drop_index, error_msg="Could not drop HNSW index") self._create_hnsw_index() @@ -372,18 +346,13 @@ def _create_hnsw_index(self): if key in HNSW_INDEX_CREATION_VALID_KWARGS } - sql_create_index = SQL( - "CREATE INDEX {index_name} ON {table_name} USING hnsw (embedding {ops}) " - ).format( - index_name=Identifier(self.hnsw_index_name), - table_name=Identifier(self.table_name), - ops=SQL(pg_ops), + sql_create_index = SQL("CREATE INDEX {index_name} ON {table_name} USING hnsw (embedding {ops}) ").format( + index_name=Identifier(self.hnsw_index_name), table_name=Identifier(self.table_name), ops=SQL(pg_ops) ) if actual_hnsw_index_creation_kwargs: actual_hnsw_index_creation_kwargs_str = ", ".join( - f"{key} = {value}" - for key, value in actual_hnsw_index_creation_kwargs.items() + f"{key} = {value}" for key, value in actual_hnsw_index_creation_kwargs.items() ) sql_add_creation_kwargs = SQL("WITH ({creation_kwargs_str})").format( creation_kwargs_str=SQL(actual_hnsw_index_creation_kwargs_str) @@ -397,18 +366,14 @@ def count_documents(self) -> int: Returns how many documents are present in the document store. """ - sql_count = SQL("SELECT COUNT(*) FROM {table_name}").format( - table_name=Identifier(self.table_name) - ) + sql_count = SQL("SELECT COUNT(*) FROM {table_name}").format(table_name=Identifier(self.table_name)) - count = self._execute_sql( - sql_count, error_msg="Could not count documents in PgvectorDocumentStore" - ).fetchone()[0] + count = self._execute_sql(sql_count, error_msg="Could not count documents in PgvectorDocumentStore").fetchone()[ + 0 + ] return count - def filter_documents( - self, filters: Optional[Dict[str, Any]] = None - ) -> List[Document]: + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: """ Returns the documents that match the filters provided. 
@@ -426,15 +391,11 @@ def filter_documents(
         if "operator" not in filters and "conditions" not in filters:
             filters = convert(filters)

-        sql_filter = SQL("SELECT * FROM {table_name}").format(
-            table_name=Identifier(self.table_name)
-        )
+        sql_filter = SQL("SELECT * FROM {table_name}").format(table_name=Identifier(self.table_name))

         params = ()
         if filters:
-            sql_where_clause, params = _convert_filters_to_where_clause_and_params(
-                filters
-            )
+            sql_where_clause, params = _convert_filters_to_where_clause_and_params(filters)
             sql_filter += sql_where_clause

         result = self._execute_sql(
@@ -448,9 +409,7 @@ def filter_documents(
         docs = self._from_pg_to_haystack_documents(records)
         return docs

-    def write_documents(
-        self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE
-    ) -> int:
+    def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int:
         """
         Writes documents to the document store.
@@ -463,9 +422,7 @@
         if len(documents) > 0:
             if not isinstance(documents[0], Document):
-                msg = (
-                    "param 'documents' must contain a list of objects of type Document"
-                )
+                msg = "param 'documents' must contain a list of objects of type Document"
                 raise ValueError(msg)

         if policy == DuplicatePolicy.NONE:
@@ -473,9 +430,7 @@

         db_documents = self._from_haystack_to_pg_documents(documents)

-        sql_insert = SQL(INSERT_STATEMENT).format(
-            table_name=Identifier(self.table_name)
-        )
+        sql_insert = SQL(INSERT_STATEMENT).format(table_name=Identifier(self.table_name))

         if policy == DuplicatePolicy.OVERWRITE:
             sql_insert += SQL(UPDATE_STATEMENT)
@@ -484,11 +439,7 @@

         sql_insert += SQL(" RETURNING id")

-        sql_query_str = (
-            sql_insert.as_string(self.cursor)
-            if not isinstance(sql_insert, str)
-            else sql_insert
-        )
+        sql_query_str = sql_insert.as_string(self.cursor) if not isinstance(sql_insert, str) else sql_insert
         logger.debug("SQL query: %s\nParameters: %s", sql_query_str, db_documents)

         try:
@@ -516,9 +467,7 @@
         return written_docs

     @staticmethod
-    def _from_haystack_to_pg_documents(
-        documents: List[Document],
-    ) -> List[Dict[str, Any]]:
+    def _from_haystack_to_pg_documents(documents: List[Document]) -> List[Dict[str, Any]]:
         """
         Internal method to convert a list of Haystack Documents to a list of dictionaries that
         can be used to insert documents into the PgvectorDocumentStore.
@@ -526,22 +475,14 @@ def _from_haystack_to_pg_documents(
         db_documents = []
         for document in documents:
-            db_document = {
-                k: v
-                for k, v in document.to_dict(flatten=False).items()
-                if k not in ["score", "blob"]
-            }
+            db_document = {k: v for k, v in document.to_dict(flatten=False).items() if k not in ["score", "blob"]}

             blob = document.blob
             db_document["blob_data"] = blob.data if blob else None
             db_document["blob_meta"] = Jsonb(blob.meta) if blob and blob.meta else None
-            db_document["blob_mime_type"] = (
-                blob.mime_type if blob and blob.mime_type else None
-            )
+            db_document["blob_mime_type"] = blob.mime_type if blob and blob.mime_type else None

-            db_document["dataframe"] = (
-                Jsonb(db_document["dataframe"]) if db_document["dataframe"] else None
-            )
+            db_document["dataframe"] = Jsonb(db_document["dataframe"]) if db_document["dataframe"] else None
             db_document["meta"] = Jsonb(db_document["meta"])

             if "sparse_embedding" in db_document:
@@ -559,9 +500,7 @@
         return db_documents

     @staticmethod
-    def _from_pg_to_haystack_documents(
-        documents: List[Dict[str, Any]]
-    ) -> List[Document]:
+    def _from_pg_to_haystack_documents(documents: List[Dict[str, Any]]) -> List[Document]:
         """
         Internal method to convert a list of dictionaries from pgvector to a list of Haystack Documents.
         """
@@ -581,9 +520,7 @@
             haystack_document = Document.from_dict(haystack_dict)

             if blob_data:
-                blob = ByteStream(
-                    data=blob_data, meta=blob_meta, mime_type=blob_mime_type
-                )
+                blob = ByteStream(data=blob_data, meta=blob_meta, mime_type=blob_mime_type)
                 haystack_document.blob = blob

             haystack_documents.append(haystack_document)
@@ -602,17 +539,11 @@
         document_ids_str = ", ".join(f"'{document_id}'" for document_id in document_ids)

-        delete_sql = SQL(
-            "DELETE FROM {table_name} WHERE id IN ({document_ids_str})"
-        ).format(
-            table_name=Identifier(self.table_name),
-            document_ids_str=SQL(document_ids_str),
+        delete_sql = SQL("DELETE FROM {table_name} WHERE id IN ({document_ids_str})").format(
+            table_name=Identifier(self.table_name), document_ids_str=SQL(document_ids_str)
         )

-        self._execute_sql(
-            delete_sql,
-            error_msg="Could not delete documents from PgvectorDocumentStore",
-        )
+        self._execute_sql(delete_sql, error_msg="Could not delete documents from PgvectorDocumentStore")

     def _keyword_retrieval(
         self,
@@ -643,16 +574,11 @@
         where_params = ()
         sql_where_clause = SQL("")
         if filters:
-            (
-                sql_where_clause,
-                where_params,
-            ) = _convert_filters_to_where_clause_and_params(
+            sql_where_clause, where_params = _convert_filters_to_where_clause_and_params(
                 filters=filters, operator="AND"
             )

-        sql_sort = SQL(" ORDER BY score DESC LIMIT {top_k}").format(
-            top_k=SQLLiteral(top_k)
-        )
+        sql_sort = SQL(" ORDER BY score DESC LIMIT {top_k}").format(top_k=SQLLiteral(top_k))

         sql_query = sql_select + sql_where_clause + sql_sort
@@ -673,9 +599,7 @@
         *,
         filters: Optional[Dict[str, Any]] = None,
         top_k: int = 10,
-        vector_function: Optional[
-            Literal["cosine_similarity", "inner_product", "l2_distance"]
-        ] = None,
+        vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None,
     ) -> List[Document]:
         """
         Retrieves documents that are most similar to the query embedding using a vector similarity metric.
@@ -703,21 +627,15 @@
             raise ValueError(msg)

         # the vector must be a string with this format: "'[3,1,2]'"
-        query_embedding_for_postgres = (
-            f"'[{','.join(str(el) for el in query_embedding)}]'"
-        )
+        query_embedding_for_postgres = f"'[{','.join(str(el) for el in query_embedding)}]'"

         # to compute the scores, we use the approach described in pgvector README:
         # https://github.com/pgvector/pgvector?tab=readme-ov-file#distances
         # cosine_similarity and inner_product are modified from the result of the operator
         if vector_function == "cosine_similarity":
-            score_definition = (
-                f"1 - (embedding <=> {query_embedding_for_postgres}) AS score"
-            )
+            score_definition = f"1 - (embedding <=> {query_embedding_for_postgres}) AS score"
         elif vector_function == "inner_product":
-            score_definition = (
-                f"(embedding <#> {query_embedding_for_postgres}) * -1 AS score"
-            )
+            score_definition = f"(embedding <#> {query_embedding_for_postgres}) * -1 AS score"
         elif vector_function == "l2_distance":
             score_definition = f"embedding <-> {query_embedding_for_postgres} AS score"
@@ -729,9 +647,7 @@
         sql_where_clause = SQL("")
         params = ()
         if filters:
-            sql_where_clause, params = _convert_filters_to_where_clause_and_params(
-                filters
-            )
+            sql_where_clause, params = _convert_filters_to_where_clause_and_params(filters)

         # we always want to return the most similar documents first
         # so when using l2_distance, the sort order must be ASC
diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py
index 65e159551..d3604cfb3 100644
--- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py
+++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py
@@ -217,9 +217,7 @@ def _less_than_equal(field: str, value: Any) -> Tuple[str, Any]:

 def _not_in(field: str, value: Any) -> Tuple[str, List]:
     if not isinstance(value, list):
-        msg = (
-            f"{field}'s value must be a list when using 'not in' comparator in Pinecone"
-        )
+        msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone"
         raise FilterError(msg)

     return f"{field} IS NULL OR {field} != ALL(%s)", [value]
diff --git a/integrations/pgvector/tests/conftest.py b/integrations/pgvector/tests/conftest.py
index dd1f806c8..b53589763 100644
--- a/integrations/pgvector/tests/conftest.py
+++ b/integrations/pgvector/tests/conftest.py
@@ -29,9 +29,7 @@ def document_store(request):

 @pytest.fixture
 def patches_for_unit_tests():
-    with patch(
-        "haystack_integrations.document_stores.pgvector.document_store.connect"
-    ) as mock_connect, patch(
+    with patch("haystack_integrations.document_stores.pgvector.document_store.connect") as mock_connect, patch(
         "haystack_integrations.document_stores.pgvector.document_store.register_vector"
     ) as mock_register, patch(
         "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore.delete_table"
     ) as mock_delete, patch(
@@ -46,9 +44,7 @@

 @pytest.fixture
-def mock_store(
-    patches_for_unit_tests, monkeypatch
-):  # noqa: ARG001 patches are not explicitly called but necessary
+def mock_store(patches_for_unit_tests, monkeypatch):  # noqa: ARG001 patches are not explicitly called but necessary
     monkeypatch.setenv("PG_CONN_STR", "some-connection-string")
     table_name = "haystack"
     embedding_dimension = 768
diff --git a/integrations/pgvector/tests/test_document_store.py b/integrations/pgvector/tests/test_document_store.py
index de1d30a7a..eca8190ee 100644
--- a/integrations/pgvector/tests/test_document_store.py
+++ b/integrations/pgvector/tests/test_document_store.py
@@ -9,11 +9,7 @@
 from haystack.dataclasses.document import ByteStream, Document
 from haystack.document_stores.errors import DuplicateDocumentError
 from haystack.document_stores.types import DuplicatePolicy
-from haystack.testing.document_store import (
-    CountDocumentsTest,
-    DeleteDocumentsTest,
-    WriteDocumentsTest,
-)
+from haystack.testing.document_store import CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest
 from haystack.utils import Secret
 from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore
 from pandas import DataFrame
@@ -28,9 +24,7 @@ def test_write_documents(self, document_store: PgvectorDocumentStore):
         document_store.write_documents(docs, DuplicatePolicy.FAIL)

     def test_write_blob(self, document_store: PgvectorDocumentStore):
-        bytestream = ByteStream(
-            b"test", meta={"meta_key": "meta_value"}, mime_type="mime_type"
-        )
+        bytestream = ByteStream(b"test", meta={"meta_key": "meta_value"}, mime_type="mime_type")
         docs = [Document(id="1", blob=bytestream)]
         document_store.write_documents(docs)
@@ -70,10 +64,7 @@ def test_init(monkeypatch):
     assert document_store.recreate_table
     assert document_store.search_strategy == "hnsw"
     assert document_store.hnsw_recreate_index_if_exists
-    assert document_store.hnsw_index_creation_kwargs == {
-        "m": 32,
-        "ef_construction": 128,
-    }
+    assert document_store.hnsw_index_creation_kwargs == {"m": 32, "ef_construction": 128}
     assert document_store.hnsw_index_name == "my_hnsw_index"
     assert document_store.hnsw_ef_search == 50
     assert document_store.keyword_index_name == "my_keyword_index"
@@ -99,11 +90,7 @@ def test_to_dict(monkeypatch):
     assert document_store.to_dict() == {
         "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore",
         "init_parameters": {
-            "connection_string": {
-                "env_vars": ["PG_CONN_STR"],
-                "strict": True,
-                "type": "env_var",
-            },
+            "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"},
             "table_name": "my_table",
             "embedding_dimension": 512,
             "vector_function": "l2_distance",
@@ -137,11 +124,7 @@ def test_from_haystack_to_pg_documents():
         ),
         Document(
             id="3",
-            blob=ByteStream(
-                b"test",
-                meta={"blob_meta_key": "blob_meta_value"},
-                mime_type="mime_type",
-            ),
+            blob=ByteStream(b"test", meta={"blob_meta_key": "blob_meta_value"}, mime_type="mime_type"),
             meta={"meta_key": "meta_value"},
             embedding=[0.7, 0.8, 0.9],
             score=0.7,
@@ -168,10 +151,7 @@ def test_from_haystack_to_pg_documents():
     assert pg_docs[1]["id"] == "2"
     assert pg_docs[1]["content"] is None
-    assert (
-        pg_docs[1]["dataframe"].obj
-        == DataFrame({"col1": [1, 2], "col2": [3, 4]}).to_json()
-    )
+    assert pg_docs[1]["dataframe"].obj == DataFrame({"col1": [1, 2], "col2": [3, 4]}).to_json()
     assert pg_docs[1]["blob_data"] is None
     assert pg_docs[1]["blob_meta"] is None
     assert pg_docs[1]["blob_mime_type"] is None
@@ -237,9 +217,7 @@ def test_from_pg_to_haystack_documents():
     assert haystack_docs[1].id == "2"
     assert haystack_docs[1].content is None
-    assert haystack_docs[1].dataframe.equals(
-        DataFrame({"col1": [1, 2], "col2": [3, 4]})
-    )
+    assert haystack_docs[1].dataframe.equals(DataFrame({"col1": [1, 2], "col2": [3, 4]}))
     assert haystack_docs[1].blob is None
     assert haystack_docs[1].meta == {"meta_key": "meta_value"}
     assert haystack_docs[1].embedding == [0.4, 0.5, 0.6]
diff --git a/integrations/pgvector/tests/test_embedding_retrieval.py b/integrations/pgvector/tests/test_embedding_retrieval.py
index 7a521d422..2c384f57c 100644
--- a/integrations/pgvector/tests/test_embedding_retrieval.py
+++ b/integrations/pgvector/tests/test_embedding_retrieval.py
@@ -33,140 +33,78 @@ def document_store_w_hnsw_index(self, request):
         store.delete_table()

-    @pytest.mark.parametrize(
-        "document_store",
-        ["document_store", "document_store_w_hnsw_index"],
-        indirect=True,
-    )
-    def test_embedding_retrieval_cosine_similarity(
-        self, document_store: PgvectorDocumentStore
-    ):
+    @pytest.mark.parametrize("document_store", ["document_store", "document_store_w_hnsw_index"], indirect=True)
+    def test_embedding_retrieval_cosine_similarity(self, document_store: PgvectorDocumentStore):
         query_embedding = [0.1] * 768
         most_similar_embedding = [0.8] * 768
         second_best_embedding = [0.8] * 700 + [0.1] * 3 + [0.2] * 65
         another_embedding = rand(768).tolist()

         docs = [
-            Document(
-                content="Most similar document (cosine sim)",
-                embedding=most_similar_embedding,
-            ),
-            Document(
-                content="2nd best document (cosine sim)",
-                embedding=second_best_embedding,
-            ),
-            Document(
-                content="Not very similar document (cosine sim)",
-                embedding=another_embedding,
-            ),
+            Document(content="Most similar document (cosine sim)", embedding=most_similar_embedding),
+            Document(content="2nd best document (cosine sim)", embedding=second_best_embedding),
+            Document(content="Not very similar document (cosine sim)", embedding=another_embedding),
         ]

         document_store.write_documents(docs)

         results = document_store._embedding_retrieval(
-            query_embedding=query_embedding,
-            top_k=2,
-            filters={},
-            vector_function="cosine_similarity",
+            query_embedding=query_embedding, top_k=2, filters={}, vector_function="cosine_similarity"
         )
         assert len(results) == 2
         assert results[0].content == "Most similar document (cosine sim)"
         assert results[1].content == "2nd best document (cosine sim)"
         assert results[0].score > results[1].score

-    @pytest.mark.parametrize(
-        "document_store",
-        ["document_store", "document_store_w_hnsw_index"],
-        indirect=True,
-    )
-    def test_embedding_retrieval_inner_product(
-        self, document_store: PgvectorDocumentStore
-    ):
+    @pytest.mark.parametrize("document_store", ["document_store", "document_store_w_hnsw_index"], indirect=True)
+    def test_embedding_retrieval_inner_product(self, document_store: PgvectorDocumentStore):
         query_embedding = [0.1] * 768
         most_similar_embedding = [0.8] * 768
         second_best_embedding = [0.8] * 700 + [0.1] * 3 + [0.2] * 65
         another_embedding = rand(768).tolist()

         docs = [
-            Document(
-                content="Most similar document (inner product)",
-                embedding=most_similar_embedding,
-            ),
-            Document(
-                content="2nd best document (inner product)",
-                embedding=second_best_embedding,
-            ),
-            Document(
-                content="Not very similar document (inner product)",
-                embedding=another_embedding,
-            ),
+            Document(content="Most similar document (inner product)", embedding=most_similar_embedding),
+            Document(content="2nd best document (inner product)", embedding=second_best_embedding),
+            Document(content="Not very similar document (inner product)", embedding=another_embedding),
         ]

         document_store.write_documents(docs)

         results = document_store._embedding_retrieval(
-            query_embedding=query_embedding,
-            top_k=2,
-            filters={},
-            vector_function="inner_product",
+            query_embedding=query_embedding, top_k=2, filters={}, vector_function="inner_product"
         )
         assert len(results) == 2
         assert results[0].content == "Most similar document (inner product)"
         assert results[1].content == "2nd best document (inner product)"
         assert results[0].score > results[1].score

-    @pytest.mark.parametrize(
-        "document_store",
-        ["document_store", "document_store_w_hnsw_index"],
-        indirect=True,
-    )
-    def test_embedding_retrieval_l2_distance(
-        self, document_store: PgvectorDocumentStore
-    ):
+    @pytest.mark.parametrize("document_store", ["document_store", "document_store_w_hnsw_index"], indirect=True)
+    def test_embedding_retrieval_l2_distance(self, document_store: PgvectorDocumentStore):
         query_embedding = [0.1] * 768
         most_similar_embedding = [0.1] * 765 + [0.15] * 3
         second_best_embedding = [0.1] * 700 + [0.1] * 3 + [0.2] * 65
         another_embedding = rand(768).tolist()

         docs = [
-            Document(
-                content="Most similar document (l2 dist)",
-                embedding=most_similar_embedding,
-            ),
-            Document(
-                content="2nd best document (l2 dist)", embedding=second_best_embedding
-            ),
-            Document(
-                content="Not very similar document (l2 dist)",
-                embedding=another_embedding,
-            ),
+            Document(content="Most similar document (l2 dist)", embedding=most_similar_embedding),
+            Document(content="2nd best document (l2 dist)", embedding=second_best_embedding),
+            Document(content="Not very similar document (l2 dist)", embedding=another_embedding),
         ]

         document_store.write_documents(docs)

         results = document_store._embedding_retrieval(
-            query_embedding=query_embedding,
-            top_k=2,
-            filters={},
-            vector_function="l2_distance",
+            query_embedding=query_embedding, top_k=2, filters={}, vector_function="l2_distance"
         )
         assert len(results) == 2
         assert results[0].content == "Most similar document (l2 dist)"
         assert results[1].content == "2nd best document (l2 dist)"
         assert results[0].score < results[1].score

-    @pytest.mark.parametrize(
-        "document_store",
-        ["document_store", "document_store_w_hnsw_index"],
-        indirect=True,
-    )
-    def test_embedding_retrieval_with_filters(
-        self, document_store: PgvectorDocumentStore
-    ):
-        docs = [
-            Document(content=f"Document {i}", embedding=rand(768).tolist())
-            for i in range(10)
-        ]
+    @pytest.mark.parametrize("document_store", ["document_store", "document_store_w_hnsw_index"], indirect=True)
+    def test_embedding_retrieval_with_filters(self, document_store: PgvectorDocumentStore):
+        docs = [Document(content=f"Document {i}", embedding=rand(768).tolist()) for i in range(10)]

         for i in range(10):
             docs[i].meta["meta_field"] = "custom_value" if i % 2 == 0 else "other_value"
@@ -174,15 +112,9 @@ def test_embedding_retrieval_with_filters(
         document_store.write_documents(docs)

         query_embedding = [0.1] * 768
-        filters = {
-            "field": "meta.meta_field",
-            "operator": "==",
-            "value": "custom_value",
-        }
+        filters = {"field": "meta.meta_field", "operator": "==", "value": "custom_value"}

-        results = document_store._embedding_retrieval(
-            query_embedding=query_embedding, top_k=3, filters=filters
-        )
+        results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=3, filters=filters)
         assert len(results) == 3
         for result in results:
             assert result.meta["meta_field"] == "custom_value"
@@ -193,9 +125,7 @@ def test_empty_query_embedding(self, document_store: PgvectorDocumentStore):
         with pytest.raises(ValueError):
             document_store._embedding_retrieval(query_embedding=query_embedding)

-    def test_query_embedding_wrong_dimension(
-        self, document_store: PgvectorDocumentStore
-    ):
+    def test_query_embedding_wrong_dimension(self, document_store: PgvectorDocumentStore):
         query_embedding = [0.1] * 4
         with pytest.raises(ValueError):
             document_store._embedding_retrieval(query_embedding=query_embedding)
diff --git a/integrations/pgvector/tests/test_filters.py b/integrations/pgvector/tests/test_filters.py
index 49be9cc42..bda10e3c0 100644
--- a/integrations/pgvector/tests/test_filters.py
+++ b/integrations/pgvector/tests/test_filters.py
@@ -17,9 +17,7 @@

 @pytest.mark.integration
 class TestFilters(FilterDocumentsTest):
-    def assert_documents_are_equal(
-        self, received: List[Document], expected: List[Document]
-    ):
+    def assert_documents_are_equal(self, received: List[Document], expected: List[Document]):
         """
         This overrides the default assert_documents_are_equal from FilterDocumentsTest.
         It is needed because the embeddings are not exactly the same when they are retrieved from Postgres.
@@ -39,8 +37,7 @@ def assert_documents_are_equal(
         assert received_doc == expected_doc

     @pytest.mark.skip(reason="NOT operator is not supported in PgvectorDocumentStore")
-    def test_not_operator(self, document_store, filterable_docs):
-        ...
+    def test_not_operator(self, document_store, filterable_docs): ...

     def test_complex_filter(self, document_store, filterable_docs):
         document_store.write_documents(filterable_docs)
@@ -58,11 +55,7 @@ def test_complex_filter(self, document_store, filterable_docs):
                     "operator": "AND",
                     "conditions": [
                         {"field": "meta.page", "operator": "==", "value": "90"},
-                        {
-                            "field": "meta.chapter",
-                            "operator": "==",
-                            "value": "conclusion",
-                        },
+                        {"field": "meta.chapter", "operator": "==", "value": "conclusion"},
                     ],
                 },
             ],
@@ -76,43 +69,23 @@ def test_complex_filter(self, document_store, filterable_docs):
                 d
                 for d in filterable_docs
                 if (d.meta.get("number") == 100 and d.meta.get("chapter") == "intro")
-                or (
-                    d.meta.get("page") == "90" and d.meta.get("chapter") == "conclusion"
-                )
+                or (d.meta.get("page") == "90" and d.meta.get("chapter") == "conclusion")
             ],
         )


 def test_treat_meta_field():
-    assert (
-        _treat_meta_field(field="meta.number", value=9) == "(meta->>'number')::integer"
-    )
-    assert (
-        _treat_meta_field(field="meta.number", value=[1, 2, 3])
-        == "(meta->>'number')::integer"
-    )
+    assert _treat_meta_field(field="meta.number", value=9) == "(meta->>'number')::integer"
+    assert _treat_meta_field(field="meta.number", value=[1, 2, 3]) == "(meta->>'number')::integer"
     assert _treat_meta_field(field="meta.name", value="my_name") == "meta->>'name'"
     assert _treat_meta_field(field="meta.name", value=["my_name"]) == "meta->>'name'"
-    assert (
-        _treat_meta_field(field="meta.number", value=1.1) == "(meta->>'number')::real"
-    )
-    assert (
-        _treat_meta_field(field="meta.number", value=[1.1, 2.2, 3.3])
-        == "(meta->>'number')::real"
-    )
-    assert (
-        _treat_meta_field(field="meta.bool", value=True) == "(meta->>'bool')::boolean"
-    )
-    assert (
-        _treat_meta_field(field="meta.bool", value=[True, False, True])
-        == "(meta->>'bool')::boolean"
-    )
+    assert _treat_meta_field(field="meta.number", value=1.1) == "(meta->>'number')::real"
+    assert _treat_meta_field(field="meta.number", value=[1.1, 2.2, 3.3]) == "(meta->>'number')::real"
+    assert _treat_meta_field(field="meta.bool", value=True) == "(meta->>'bool')::boolean"
+    assert _treat_meta_field(field="meta.bool", value=[True, False, True]) == "(meta->>'bool')::boolean"

     # do not cast the field if its value is not one of the known types, an empty list or None
-    assert (
-        _treat_meta_field(field="meta.other", value={"a": 3, "b": "example"})
-        == "meta->>'other'"
-    )
+    assert _treat_meta_field(field="meta.other", value={"a": 3, "b": "example"}) == "meta->>'other'"
     assert _treat_meta_field(field="meta.empty_list", value=[]) == "meta->>'empty_list'"
     assert _treat_meta_field(field="meta.name", value=None) == "meta->>'name'"
@@ -172,22 +145,14 @@ def test_logical_condition_nested():
                 "operator": "OR",
                 "conditions": [
                     {"field": "meta.domain", "operator": "!=", "value": "science"},
-                    {
-                        "field": "meta.chapter",
-                        "operator": "in",
-                        "value": ["intro", "conclusion"],
-                    },
+                    {"field": "meta.chapter", "operator": "in", "value": ["intro", "conclusion"]},
                 ],
             },
             {
                 "operator": "OR",
                 "conditions": [
                     {"field": "meta.number", "operator": ">=", "value": 90},
-                    {
-                        "field": "meta.author",
-                        "operator": "not in",
-                        "value": ["John", "Jane"],
-                    },
+                    {"field": "meta.author", "operator": "not in", "value": ["John", "Jane"]},
                 ],
             },
         ],
@@ -209,9 +174,7 @@ def test_convert_filters_to_where_clause_and_params():
         ],
     }
     where_clause, params = _convert_filters_to_where_clause_and_params(filters)
-    assert where_clause == SQL(" WHERE ") + SQL(
-        "((meta->>'number')::integer = %s AND meta->>'chapter' = %s)"
-    )
+    assert where_clause == SQL(" WHERE ") + SQL("((meta->>'number')::integer = %s AND meta->>'chapter' = %s)")
     assert params == (100, "intro")


@@ -224,7 +187,5 @@ def test_convert_filters_to_where_clause_and_params_handle_null():
         ],
     }
     where_clause, params = _convert_filters_to_where_clause_and_params(filters)
-    assert where_clause == SQL(" WHERE ") + SQL(
-        "(meta->>'number' IS NULL AND meta->>'chapter' = %s)"
-    )
+    assert where_clause == SQL(" WHERE ") + SQL("(meta->>'number' IS NULL AND meta->>'chapter' = %s)")
     assert params == ("intro",)
diff --git a/integrations/pgvector/tests/test_keyword_retrieval.py b/integrations/pgvector/tests/test_keyword_retrieval.py
index 15096dc71..4a5614165 100644
--- a/integrations/pgvector/tests/test_keyword_retrieval.py
+++ b/integrations/pgvector/tests/test_keyword_retrieval.py
@@ -7,9 +7,7 @@
 class TestKeywordRetrieval:
     def test_keyword_retrieval(self, document_store: PgvectorDocumentStore):
         docs = [
-            Document(
-                content="The quick brown fox chased the dog", embedding=[0.1] * 768
-            ),
+            Document(content="The quick brown fox chased the dog", embedding=[0.1] * 768),
             Document(content="The fox was brown", embedding=[0.1] * 768),
             Document(content="The lazy dog", embedding=[0.1] * 768),
             Document(content="fox fox fox", embedding=[0.1] * 768),
@@ -25,39 +23,23 @@ def test_keyword_retrieval(self, document_store: PgvectorDocumentStore):
         assert results[0].id == docs[-1].id
         assert results[0].score > results[1].score

-    def test_keyword_retrieval_with_filters(
-        self, document_store: PgvectorDocumentStore
-    ):
+    def test_keyword_retrieval_with_filters(self, document_store: PgvectorDocumentStore):
         docs = [
             Document(
                 content="The quick brown fox chased the dog",
                 embedding=([0.1] * 768),
                 meta={"meta_field": "right_value"},
             ),
-            Document(
-                content="The fox was brown",
-                embedding=([0.1] * 768),
-                meta={"meta_field": "right_value"},
-            ),
-            Document(
-                content="The lazy dog",
-                embedding=([0.1] * 768),
-                meta={"meta_field": "right_value"},
-            ),
-            Document(
-                content="fox fox fox",
-                embedding=([0.1] * 768),
-                meta={"meta_field": "wrong_value"},
-            ),
+            Document(content="The fox was brown", embedding=([0.1] * 768), meta={"meta_field": "right_value"}),
+            Document(content="The lazy dog", embedding=([0.1] * 768), meta={"meta_field": "right_value"}),
+            Document(content="fox fox fox", embedding=([0.1] * 768), meta={"meta_field": "wrong_value"}),
         ]

         document_store.write_documents(docs)

         filters = {"field": "meta.meta_field", "operator": "==", "value": "right_value"}

-        results = document_store._keyword_retrieval(
-            query="fox", top_k=3, filters=filters
-        )
+        results = document_store._keyword_retrieval(query="fox", top_k=3, filters=filters)
         assert len(results) == 2
         for doc in results:
             assert "fox" in doc.content
diff --git a/integrations/pgvector/tests/test_retrievers.py b/integrations/pgvector/tests/test_retrievers.py
index a9afa7dd7..031c735fd 100644
--- a/integrations/pgvector/tests/test_retrievers.py
+++ b/integrations/pgvector/tests/test_retrievers.py
@@ -7,10 +7,7 @@
 from haystack.dataclasses import Document
 from haystack.document_stores.types import FilterPolicy
 from haystack.utils.auth import EnvVarSecret
-from haystack_integrations.components.retrievers.pgvector import (
-    PgvectorEmbeddingRetriever,
-    PgvectorKeywordRetriever,
-)
+from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever, PgvectorKeywordRetriever
 from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore
@@ -23,22 +20,15 @@ def test_init_default(self, mock_store):
         assert retriever.filter_policy == FilterPolicy.REPLACE
         assert retriever.vector_function == mock_store.vector_function

-        retriever = PgvectorEmbeddingRetriever(
-            document_store=mock_store, filter_policy="merge"
-        )
+        retriever = PgvectorEmbeddingRetriever(document_store=mock_store, filter_policy="merge")
         assert retriever.filter_policy == FilterPolicy.MERGE

         with pytest.raises(ValueError):
-            PgvectorEmbeddingRetriever(
-                document_store=mock_store, filter_policy="invalid"
-            )
+            PgvectorEmbeddingRetriever(document_store=mock_store, filter_policy="invalid")

     def test_init(self, mock_store):
         retriever = PgvectorEmbeddingRetriever(
-            document_store=mock_store,
-            filters={"field": "value"},
-            top_k=5,
-            vector_function="l2_distance",
+            document_store=mock_store, filters={"field": "value"}, top_k=5, vector_function="l2_distance"
         )
         assert retriever.document_store == mock_store
         assert retriever.filters == {"field": "value"}
@@ -48,10 +38,7 @@ def test_init(self, mock_store):

     def test_to_dict(self, mock_store):
         retriever = PgvectorEmbeddingRetriever(
-            document_store=mock_store,
-            filters={"field": "value"},
-            top_k=5,
-            vector_function="l2_distance",
+            document_store=mock_store, filters={"field": "value"}, top_k=5, vector_function="l2_distance"
         )
         res = retriever.to_dict()
         t = "haystack_integrations.components.retrievers.pgvector.embedding_retriever.PgvectorEmbeddingRetriever"
@@ -61,11 +48,7 @@ def test_to_dict(self, mock_store):
             "document_store": {
                 "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore",
                 "init_parameters": {
-                    "connection_string": {
-                        "env_vars": ["PG_CONN_STR"],
-                        "strict": True,
-                        "type": "env_var",
-                    },
+                    "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"},
                     "table_name": "haystack",
                     "embedding_dimension": 768,
                     "vector_function": "cosine_similarity",
@@ -96,11 +79,7 @@ def test_from_dict(self, monkeypatch):
             "document_store": {
                 "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore",
                 "init_parameters": {
-                    "connection_string": {
-                        "env_vars": ["PG_CONN_STR"],
-                        "strict": True,
-                        "type": "env_var",
-                    },
+                    "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"},
                     "table_name": "haystack_test_to_dict",
                     "embedding_dimension": 768,
                     "vector_function": "cosine_similarity",
@@ -146,16 +125,11 @@ def test_run(self):
         doc = Document(content="Test doc", embedding=[0.1, 0.2])
         mock_store._embedding_retrieval.return_value = [doc]

-        retriever = PgvectorEmbeddingRetriever(
-            document_store=mock_store, vector_function="l2_distance"
-        )
+        retriever = PgvectorEmbeddingRetriever(document_store=mock_store, vector_function="l2_distance")
         res = retriever.run(query_embedding=[0.3, 0.5])

         mock_store._embedding_retrieval.assert_called_once_with(
-            query_embedding=[0.3, 0.5],
-            filters={},
-            top_k=10,
-            vector_function="l2_distance",
+            query_embedding=[0.3, 0.5], filters={}, top_k=10, vector_function="l2_distance"
         )

         assert res == {"documents": [doc]}
@@ -168,28 +142,21 @@ def test_init_default(self, mock_store):
         assert retriever.filters == {}
         assert retriever.top_k == 10

-        retriever = PgvectorKeywordRetriever(
-            document_store=mock_store, filter_policy="merge"
-        )
+        retriever = PgvectorKeywordRetriever(document_store=mock_store, filter_policy="merge")
         assert retriever.filter_policy == FilterPolicy.MERGE

         with pytest.raises(ValueError):
             PgvectorKeywordRetriever(document_store=mock_store, filter_policy="invalid")

     def test_init(self, mock_store):
-        retriever = PgvectorKeywordRetriever(
-            document_store=mock_store, filters={"field": "value"}, top_k=5
-        )
+        retriever = PgvectorKeywordRetriever(document_store=mock_store, filters={"field": "value"}, top_k=5)
         assert retriever.document_store == mock_store
         assert retriever.filters == {"field": "value"}
         assert retriever.top_k == 5

     def test_init_with_filter_policy(self, mock_store):
         retriever = PgvectorKeywordRetriever(
-            document_store=mock_store,
-            filters={"field": "value"},
-            top_k=5,
-            filter_policy=FilterPolicy.MERGE,
+            document_store=mock_store, filters={"field": "value"}, top_k=5, filter_policy=FilterPolicy.MERGE
         )
         assert retriever.document_store == mock_store
         assert retriever.filters == {"field": "value"}
@@ -197,9 +164,7 @@ def test_init_with_filter_policy(self, mock_store):
         assert retriever.filter_policy == FilterPolicy.MERGE

     def test_to_dict(self, mock_store):
-        retriever = PgvectorKeywordRetriever(
-            document_store=mock_store, filters={"field": "value"}, top_k=5
-        )
+        retriever = PgvectorKeywordRetriever(document_store=mock_store, filters={"field": "value"}, top_k=5)
         res = retriever.to_dict()
         t = "haystack_integrations.components.retrievers.pgvector.keyword_retriever.PgvectorKeywordRetriever"
         assert res == {
@@ -208,11 +173,7 @@ def test_to_dict(self, mock_store):
             "document_store": {
                 "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore",
                 "init_parameters": {
-                    "connection_string": {
-                        "env_vars": ["PG_CONN_STR"],
-                        "strict": True,
-                        "type": "env_var",
-                    },
+                    "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"},
                     "table_name": "haystack",
                     "embedding_dimension": 768,
                     "vector_function": "cosine_similarity",
@@ -242,11 +203,7 @@ def test_from_dict(self, monkeypatch):
             "document_store": {
                 "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore",
                 "init_parameters": {
-                    "connection_string": {
-                        "env_vars": ["PG_CONN_STR"],
-                        "strict": True,
-                        "type": "env_var",
-                    },
+                    "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"},
                     "table_name": "haystack_test_to_dict",
                     "embedding_dimension": 768,
                     "vector_function": "cosine_similarity",
@@ -294,11 +251,7 @@ def test_from_dict_without_filter_policy(self, monkeypatch):
             "document_store": {
                 "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore",
                 "init_parameters": {
-                    "connection_string": {
-                        "env_vars": ["PG_CONN_STR"],
-                        "strict": True,
-                        "type": "env_var",
-                    },
+                    "connection_string":
{"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, "table_name": "haystack_test_to_dict", "embedding_dimension": 768, "vector_function": "cosine_similarity", @@ -344,9 +297,7 @@ def test_run(self): retriever = PgvectorKeywordRetriever(document_store=mock_store) res = retriever.run(query="test query") - mock_store._keyword_retrieval.assert_called_once_with( - query="test query", filters={}, top_k=10 - ) + mock_store._keyword_retrieval.assert_called_once_with(query="test query", filters={}, top_k=10) assert res == {"documents": [doc]} @@ -356,9 +307,7 @@ def test_run_with_filters(self): mock_store._keyword_retrieval.return_value = [doc] retriever = PgvectorKeywordRetriever( - document_store=mock_store, - filter_policy=FilterPolicy.MERGE, - filters={"field": "value"}, + document_store=mock_store, filter_policy=FilterPolicy.MERGE, filters={"field": "value"} ) res = retriever.run(query="test query", filters={"field2": "value2"}) diff --git a/integrations/pinecone/examples/example.py b/integrations/pinecone/examples/example.py index 9fa4d8541..5f7d92ce5 100644 --- a/integrations/pinecone/examples/example.py +++ b/integrations/pinecone/examples/example.py @@ -12,17 +12,12 @@ from haystack import Pipeline from haystack.components.converters import MarkdownToDocument -from haystack.components.embedders import ( - SentenceTransformersDocumentEmbedder, - SentenceTransformersTextEmbedder, -) +from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder from haystack.components.preprocessors import DocumentSplitter from haystack.components.writers import DocumentWriter from haystack.utils import Secret -from haystack_integrations.components.retrievers.pinecone import ( - PineconeEmbeddingRetriever, -) +from haystack_integrations.components.retrievers.pinecone import PineconeEmbeddingRetriever from haystack_integrations.document_stores.pinecone import PineconeDocumentStore file_paths = glob.glob("neural-search-pills/pills/*.md") @@ -37,9 +32,7 @@ indexing = Pipeline() indexing.add_component("converter", MarkdownToDocument()) -indexing.add_component( - "splitter", DocumentSplitter(split_by="sentence", split_length=2) -) +indexing.add_component("splitter", DocumentSplitter(split_by="sentence", split_length=2)) indexing.add_component("embedder", SentenceTransformersDocumentEmbedder()) indexing.add_component("writer", DocumentWriter(document_store)) indexing.connect("converter", "splitter") @@ -53,9 +46,7 @@ querying = Pipeline() querying.add_component("embedder", SentenceTransformersTextEmbedder()) -querying.add_component( - "retriever", PineconeEmbeddingRetriever(document_store=document_store, top_k=3) -) +querying.add_component("retriever", PineconeEmbeddingRetriever(document_store=document_store, top_k=3)) querying.connect("embedder", "retriever") results = querying.run({"embedder": {"text": "What is Question Answering?"}}) diff --git a/integrations/pinecone/src/haystack_integrations/components/retrievers/pinecone/embedding_retriever.py b/integrations/pinecone/src/haystack_integrations/components/retrievers/pinecone/embedding_retriever.py index 2d1dc2311..76f781f97 100644 --- a/integrations/pinecone/src/haystack_integrations/components/retrievers/pinecone/embedding_retriever.py +++ b/integrations/pinecone/src/haystack_integrations/components/retrievers/pinecone/embedding_retriever.py @@ -75,9 +75,7 @@ def __init__( self.filters = filters or {} self.top_k = top_k self.filter_policy = ( - filter_policy - if isinstance(filter_policy, 
FilterPolicy) - else FilterPolicy.from_str(filter_policy) + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) ) def to_dict(self) -> Dict[str, Any]: @@ -109,9 +107,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "PineconeEmbeddingRetriever": # Pipelines serialized with old versions of the component might not # have the filter_policy field. if filter_policy := data["init_parameters"].get("filter_policy"): - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str( - filter_policy - ) + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) diff --git a/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py index 7c7500bd2..1fd3adf40 100644 --- a/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py +++ b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py @@ -84,16 +84,9 @@ def index(self): client = Pinecone(api_key=self.api_key.resolve_value(), source_tag="haystack") if self.index_name not in client.list_indexes().names(): - logger.info( - f"Index {self.index_name} does not exist. Creating a new index." - ) + logger.info(f"Index {self.index_name} does not exist. Creating a new index.") pinecone_spec = self._convert_dict_spec_to_pinecone_object(self.spec) - client.create_index( - name=self.index_name, - dimension=self.dimension, - spec=pinecone_spec, - metric=self.metric, - ) + client.create_index(name=self.index_name, dimension=self.dimension, spec=pinecone_spec, metric=self.metric) else: logger.info( f"Connecting to existing index {self.index_name}. `dimension`, `spec`, and `metric` will be ignored." @@ -164,16 +157,12 @@ def count_documents(self) -> int: Returns how many documents are present in the document store. """ try: - count = self.index.describe_index_stats()["namespaces"][self.namespace][ - "vector_count" - ] + count = self.index.describe_index_stats()["namespaces"][self.namespace]["vector_count"] except KeyError: count = 0 return count - def write_documents( - self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE - ) -> int: + def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int: """ Writes Documents to Pinecone. @@ -195,18 +184,12 @@ def write_documents( documents_for_pinecone = self._convert_documents_to_pinecone_format(documents) - result = self.index.upsert( - vectors=documents_for_pinecone, - namespace=self.namespace, - batch_size=self.batch_size, - ) + result = self.index.upsert(vectors=documents_for_pinecone, namespace=self.namespace, batch_size=self.batch_size) written_docs = result["upserted_count"] return written_docs - def filter_documents( - self, filters: Optional[Dict[str, Any]] = None - ) -> List[Document]: + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: """ Returns the documents that match the filters provided. 
@@ -219,9 +202,7 @@ def filter_documents( # Pinecone only performs vector similarity search # here we are querying with a dummy vector and the max compatible top_k - documents = self._embedding_retrieval( - query_embedding=self._dummy_vector, filters=filters, top_k=TOP_K_LIMIT - ) + documents = self._embedding_retrieval(query_embedding=self._dummy_vector, filters=filters, top_k=TOP_K_LIMIT) # when simply filtering, we don't want to return any scores # furthermore, we are querying with a dummy vector, so the scores are meaningless @@ -285,9 +266,7 @@ def _embedding_retrieval( return self._convert_query_result_to_documents(result) - def _convert_query_result_to_documents( - self, query_result: Dict[str, Any] - ) -> List[Document]: + def _convert_query_result_to_documents(self, query_result: Dict[str, Any]) -> List[Document]: pinecone_docs = query_result["matches"] documents = [] for pinecone_doc in pinecone_docs: @@ -316,9 +295,7 @@ def _convert_query_result_to_documents( return documents - def _convert_documents_to_pinecone_format( - self, documents: List[Document] - ) -> List[Dict[str, Any]]: + def _convert_documents_to_pinecone_format(self, documents: List[Document]) -> List[Dict[str, Any]]: documents_for_pinecone = [] for document in documents: embedding = copy(document.embedding) @@ -328,11 +305,7 @@ def _convert_documents_to_pinecone_format( "A dummy embedding will be used, but this can affect the search results. " ) embedding = self._dummy_vector - doc_for_pinecone = { - "id": document.id, - "values": embedding, - "metadata": dict(document.meta), - } + doc_for_pinecone = {"id": document.id, "values": embedding, "metadata": dict(document.meta)} # we save content/dataframe as metadata if document.content is not None: @@ -346,10 +319,7 @@ def _convert_documents_to_pinecone_format( "objects in Pinecone is not supported. " "The content of the `blob` field will be ignored." ) - if ( - hasattr(document, "sparse_embedding") - and document.sparse_embedding is not None - ): + if hasattr(document, "sparse_embedding") and document.sparse_embedding is not None: logger.warning( "Document %s has the `sparse_embedding` field set," "but storing sparse embeddings in Pinecone is not currently supported." 
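Aside: the `filter_documents` hunk above leans on a trick worth spelling out. Pinecone only performs vector similarity search, so metadata-only filtering is emulated by querying with a constant placeholder vector at the maximum compatible `top_k` and then discarding the scores, which are meaningless for a dummy query. Below is a minimal standalone sketch of that idea, not the store's implementation: the `TOP_K_LIMIT` value and the -10.0 placeholder are assumptions, and `my-index` is a hypothetical index name.

    from pinecone import Pinecone

    TOP_K_LIMIT = 1_000  # assumed cap when metadata is returned; check current Pinecone limits

    def filter_only(index, filters: dict, dimension: int) -> list:
        """Emulate pure metadata filtering on a similarity-only API."""
        dummy_vector = [-10.0] * dimension  # placeholder; its direction is irrelevant
        result = index.query(
            vector=dummy_vector,
            top_k=TOP_K_LIMIT,
            filter=filters,
            include_metadata=True,
        )
        # The scores come from the dummy vector and carry no signal, so drop them.
        return [match["metadata"] for match in result["matches"]]

    # Usage sketch (hypothetical index):
    # index = Pinecone(api_key="...").Index("my-index")
    # docs = filter_only(index, {"file_name": {"$in": ["file1", "file2"]}}, dimension=768)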
diff --git a/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/filters.py b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/filters.py index bc642db1f..2ddb26d61 100644 --- a/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/filters.py +++ b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/filters.py @@ -141,9 +141,7 @@ def _less_than_equal(field: str, value: Any) -> Dict[str, Any]: def _not_in(field: str, value: Any) -> Dict[str, Any]: if not isinstance(value, list): - msg = ( - f"{field}'s value must be a list when using 'not in' comparator in Pinecone" - ) + msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone" raise FilterError(msg) supported_types = (int, float, str) diff --git a/integrations/pinecone/tests/test_document_store.py b/integrations/pinecone/tests/test_document_store.py index ca91900eb..90ce2ccff 100644 --- a/integrations/pinecone/tests/test_document_store.py +++ b/integrations/pinecone/tests/test_document_store.py @@ -5,11 +5,7 @@ import numpy as np import pytest from haystack import Document -from haystack.testing.document_store import ( - CountDocumentsTest, - DeleteDocumentsTest, - WriteDocumentsTest, -) +from haystack.testing.document_store import CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest from haystack.utils import Secret from pinecone import Pinecone, PodSpec, ServerlessSpec @@ -24,9 +20,7 @@ def test_init_is_lazy(_mock_client): @patch("haystack_integrations.document_stores.pinecone.document_store.Pinecone") def test_init(mock_pinecone): - mock_pinecone.return_value.Index.return_value.describe_index_stats.return_value = { - "dimension": 60 - } + mock_pinecone.return_value.Index.return_value.describe_index_stats.return_value = {"dimension": 60} document_store = PineconeDocumentStore( api_key=Secret.from_token("fake-api-key"), @@ -69,9 +63,7 @@ def test_init_api_key_in_environment_variable(mock_pinecone, monkeypatch): @patch("haystack_integrations.document_stores.pinecone.document_store.Pinecone") def test_to_from_dict(mock_pinecone, monkeypatch): - mock_pinecone.return_value.Index.return_value.describe_index_stats.return_value = { - "dimension": 60 - } + mock_pinecone.return_value.Index.return_value.describe_index_stats.return_value = {"dimension": 60} monkeypatch.setenv("PINECONE_API_KEY", "env-api-key") document_store = PineconeDocumentStore( index="my_index", @@ -105,17 +97,13 @@ def test_to_from_dict(mock_pinecone, monkeypatch): assert document_store.to_dict() == dict_output document_store = PineconeDocumentStore.from_dict(dict_output) - assert document_store.api_key == Secret.from_env_var( - "PINECONE_API_KEY", strict=True - ) + assert document_store.api_key == Secret.from_env_var("PINECONE_API_KEY", strict=True) assert document_store.index_name == "my_index" assert document_store.namespace == "test" assert document_store.batch_size == 50 assert document_store.dimension == 60 assert document_store.metric == "euclidean" - assert document_store.spec == { - "serverless": {"region": "us-east-1", "cloud": "aws"} - } + assert document_store.spec == {"serverless": {"region": "us-east-1", "cloud": "aws"}} def test_init_fails_wo_api_key(monkeypatch): @@ -128,27 +116,15 @@ def test_init_fails_wo_api_key(monkeypatch): def test_convert_dict_spec_to_pinecone_object_serverless(): dict_spec = {"serverless": {"region": "us-east-1", "cloud": "aws"}} - pinecone_object = PineconeDocumentStore._convert_dict_spec_to_pinecone_object( - dict_spec 
- ) + pinecone_object = PineconeDocumentStore._convert_dict_spec_to_pinecone_object(dict_spec) assert isinstance(pinecone_object, ServerlessSpec) assert pinecone_object.region == "us-east-1" assert pinecone_object.cloud == "aws" def test_convert_dict_spec_to_pinecone_object_pod(): - dict_spec = { - "pod": { - "replicas": 1, - "shards": 1, - "pods": 1, - "pod_type": "p1.x1", - "environment": "us-west1-gcp", - } - } - pinecone_object = PineconeDocumentStore._convert_dict_spec_to_pinecone_object( - dict_spec - ) + dict_spec = {"pod": {"replicas": 1, "shards": 1, "pods": 1, "pod_type": "p1.x1", "environment": "us-west1-gcp"}} + pinecone_object = PineconeDocumentStore._convert_dict_spec_to_pinecone_object(dict_spec) assert isinstance(pinecone_object, PodSpec) assert pinecone_object.replicas == 1 @@ -160,22 +136,14 @@ def test_convert_dict_spec_to_pinecone_object_pod(): def test_convert_dict_spec_to_pinecone_object_fail(): dict_spec = { - "strange_key": { - "replicas": 1, - "shards": 1, - "pods": 1, - "pod_type": "p1.x1", - "environment": "us-west1-gcp", - } + "strange_key": {"replicas": 1, "shards": 1, "pods": 1, "pod_type": "p1.x1", "environment": "us-west1-gcp"} } with pytest.raises(ValueError): PineconeDocumentStore._convert_dict_spec_to_pinecone_object(dict_spec) @pytest.mark.integration -@pytest.mark.skipif( - "PINECONE_API_KEY" not in os.environ, reason="PINECONE_API_KEY not set" -) +@pytest.mark.skipif("PINECONE_API_KEY" not in os.environ, reason="PINECONE_API_KEY not set") def test_serverless_index_creation_from_scratch(sleep_time): # we use a fixed index name to avoid hitting the limit of Pinecone's free tier (max 5 indexes) # the index name is defined in the test matrix of the GitHub Actions workflow @@ -215,42 +183,25 @@ def test_serverless_index_creation_from_scratch(sleep_time): @pytest.mark.integration -@pytest.mark.skipif( - "PINECONE_API_KEY" not in os.environ, reason="PINECONE_API_KEY not set" -) +@pytest.mark.skipif("PINECONE_API_KEY" not in os.environ, reason="PINECONE_API_KEY not set") class TestDocumentStore(CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest): def test_write_documents(self, document_store: PineconeDocumentStore): docs = [Document(id="1")] assert document_store.write_documents(docs) == 1 @pytest.mark.xfail( - run=True, - reason="Pinecone supports overwriting by default, but it takes a while for it to take effect", + run=True, reason="Pinecone supports overwriting by default, but it takes a while for it to take effect" ) - def test_write_documents_duplicate_overwrite( - self, document_store: PineconeDocumentStore - ): - ... + def test_write_documents_duplicate_overwrite(self, document_store: PineconeDocumentStore): ... @pytest.mark.skip(reason="Pinecone only supports UPSERT operations") - def test_write_documents_duplicate_fail( - self, document_store: PineconeDocumentStore - ): - ... + def test_write_documents_duplicate_fail(self, document_store: PineconeDocumentStore): ... @pytest.mark.skip(reason="Pinecone only supports UPSERT operations") - def test_write_documents_duplicate_skip( - self, document_store: PineconeDocumentStore - ): - ... + def test_write_documents_duplicate_skip(self, document_store: PineconeDocumentStore): ... - @pytest.mark.skip( - reason="Pinecone creates a namespace only when the first document is written" - ) - def test_delete_documents_empty_document_store( - self, document_store: PineconeDocumentStore - ): - ... 
+ @pytest.mark.skip(reason="Pinecone creates a namespace only when the first document is written") + def test_delete_documents_empty_document_store(self, document_store: PineconeDocumentStore): ... def test_embedding_retrieval(self, document_store: PineconeDocumentStore): query_embedding = [0.1] * 768 @@ -266,9 +217,7 @@ def test_embedding_retrieval(self, document_store: PineconeDocumentStore): document_store.write_documents(docs) - results = document_store._embedding_retrieval( - query_embedding=query_embedding, top_k=2, filters={} - ) + results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={}) assert len(results) == 2 assert results[0].content == "Most similar document" assert results[1].content == "2nd best document" diff --git a/integrations/pinecone/tests/test_embedding_retriever.py b/integrations/pinecone/tests/test_embedding_retriever.py index 62347ea48..99be75982 100644 --- a/integrations/pinecone/tests/test_embedding_retriever.py +++ b/integrations/pinecone/tests/test_embedding_retriever.py @@ -8,9 +8,7 @@ from haystack.document_stores.types import FilterPolicy from haystack.utils import Secret -from haystack_integrations.components.retrievers.pinecone import ( - PineconeEmbeddingRetriever, -) +from haystack_integrations.components.retrievers.pinecone import PineconeEmbeddingRetriever from haystack_integrations.document_stores.pinecone import PineconeDocumentStore @@ -22,9 +20,7 @@ def test_init_default(): assert retriever.top_k == 10 assert retriever.filter_policy == FilterPolicy.REPLACE - retriever = PineconeEmbeddingRetriever( - document_store=mock_store, filter_policy="replace" - ) + retriever = PineconeEmbeddingRetriever(document_store=mock_store, filter_policy="replace") assert retriever.filter_policy == FilterPolicy.REPLACE with pytest.raises(ValueError): @@ -34,9 +30,7 @@ def test_init_default(): @patch("haystack_integrations.document_stores.pinecone.document_store.Pinecone") def test_to_dict(mock_pinecone, monkeypatch): monkeypatch.setenv("PINECONE_API_KEY", "env-api-key") - mock_pinecone.return_value.Index.return_value.describe_index_stats.return_value = { - "dimension": 512 - } + mock_pinecone.return_value.Index.return_value.describe_index_stats.return_value = {"dimension": 512} document_store = PineconeDocumentStore( index="default", namespace="test-namespace", @@ -102,24 +96,18 @@ def test_from_dict(mock_pinecone, monkeypatch): }, } - mock_pinecone.return_value.Index.return_value.describe_index_stats.return_value = { - "dimension": 512 - } + mock_pinecone.return_value.Index.return_value.describe_index_stats.return_value = {"dimension": 512} monkeypatch.setenv("PINECONE_API_KEY", "test-key") retriever = PineconeEmbeddingRetriever.from_dict(data) document_store = retriever.document_store - assert document_store.api_key == Secret.from_env_var( - "PINECONE_API_KEY", strict=True - ) + assert document_store.api_key == Secret.from_env_var("PINECONE_API_KEY", strict=True) assert document_store.index_name == "default" assert document_store.namespace == "test-namespace" assert document_store.batch_size == 50 assert document_store.dimension == 512 assert document_store.metric == "cosine" - assert document_store.spec == { - "serverless": {"region": "us-east-1", "cloud": "aws"} - } + assert document_store.spec == {"serverless": {"region": "us-east-1", "cloud": "aws"}} assert retriever.filters == {} assert retriever.top_k == 10 @@ -154,24 +142,18 @@ def test_from_dict_no_filter_policy(mock_pinecone, monkeypatch): }, } - 
mock_pinecone.return_value.Index.return_value.describe_index_stats.return_value = { - "dimension": 512 - } + mock_pinecone.return_value.Index.return_value.describe_index_stats.return_value = {"dimension": 512} monkeypatch.setenv("PINECONE_API_KEY", "test-key") retriever = PineconeEmbeddingRetriever.from_dict(data) document_store = retriever.document_store - assert document_store.api_key == Secret.from_env_var( - "PINECONE_API_KEY", strict=True - ) + assert document_store.api_key == Secret.from_env_var("PINECONE_API_KEY", strict=True) assert document_store.index_name == "default" assert document_store.namespace == "test-namespace" assert document_store.batch_size == 50 assert document_store.dimension == 512 assert document_store.metric == "cosine" - assert document_store.spec == { - "serverless": {"region": "us-east-1", "cloud": "aws"} - } + assert document_store.spec == {"serverless": {"region": "us-east-1", "cloud": "aws"}} assert retriever.filters == {} assert retriever.top_k == 10 @@ -180,9 +162,7 @@ def test_from_dict_no_filter_policy(mock_pinecone, monkeypatch): def test_run(): mock_store = Mock(spec=PineconeDocumentStore) - mock_store._embedding_retrieval.return_value = [ - Document(content="Test doc", embedding=[0.1, 0.2]) - ] + mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] retriever = PineconeEmbeddingRetriever(document_store=mock_store) res = retriever.run(query_embedding=[0.5, 0.7]) mock_store._embedding_retrieval.assert_called_once_with( diff --git a/integrations/pinecone/tests/test_filters.py b/integrations/pinecone/tests/test_filters.py index bed4dce1f..40c9cdb10 100644 --- a/integrations/pinecone/tests/test_filters.py +++ b/integrations/pinecone/tests/test_filters.py @@ -9,13 +9,9 @@ @pytest.mark.integration -@pytest.mark.skipif( - "PINECONE_API_KEY" not in os.environ, reason="PINECONE_API_KEY not set" -) +@pytest.mark.skipif("PINECONE_API_KEY" not in os.environ, reason="PINECONE_API_KEY not set") class TestFilters(FilterDocumentsTest): - def assert_documents_are_equal( - self, received: List[Document], expected: List[Document] - ): + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): for doc in received: # Pinecone seems to convert integers to floats (undocumented behavior) # We convert them back to integers to compare them @@ -40,75 +36,47 @@ def assert_documents_are_equal( assert received_doc.embedding == pytest.approx(expected_doc.embedding) @pytest.mark.skip(reason="Pinecone does not support comparison with null values") - def test_comparison_equal_with_none(self, document_store, filterable_docs): - ... + def test_comparison_equal_with_none(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Pinecone does not support comparison with null values") - def test_comparison_not_equal_with_none(self, document_store, filterable_docs): - ... + def test_comparison_not_equal_with_none(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Pinecone does not support comparison with dates") - def test_comparison_greater_than_with_iso_date( - self, document_store, filterable_docs - ): - ... + def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Pinecone does not support comparison with null values") - def test_comparison_greater_than_with_none(self, document_store, filterable_docs): - ... + def test_comparison_greater_than_with_none(self, document_store, filterable_docs): ... 
@pytest.mark.skip(reason="Pinecone does not support comparison with dates") - def test_comparison_greater_than_equal_with_iso_date( - self, document_store, filterable_docs - ): - ... + def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Pinecone does not support comparison with null values") - def test_comparison_greater_than_equal_with_none( - self, document_store, filterable_docs - ): - ... + def test_comparison_greater_than_equal_with_none(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Pinecone does not support comparison with dates") - def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs): - ... + def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Pinecone does not support comparison with null values") - def test_comparison_less_than_with_none(self, document_store, filterable_docs): - ... + def test_comparison_less_than_with_none(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Pinecone does not support comparison with dates") - def test_comparison_less_than_equal_with_iso_date( - self, document_store, filterable_docs - ): - ... + def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Pinecone does not support comparison with null values") - def test_comparison_less_than_equal_with_none( - self, document_store, filterable_docs - ): - ... + def test_comparison_less_than_equal_with_none(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Pinecone does not support the 'not' operator") - def test_not_operator(self, document_store, filterable_docs): - ... + def test_not_operator(self, document_store, filterable_docs): ... # the following skipped tests should be reworked when Pinecone introduces a better handling of null values # see https://github.com/deepset-ai/haystack-core-integrations/issues/590 - @pytest.mark.skip( - reason="Pinecone does not include null values in the result of the $ne operator" - ) - def test_comparison_not_equal(self, document_store, filterable_docs): - ... + @pytest.mark.skip(reason="Pinecone does not include null values in the result of the $ne operator") + def test_comparison_not_equal(self, document_store, filterable_docs): ... - @pytest.mark.skip( - reason="Pinecone does not include null values in the result of the $ne operator" - ) - def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs): - ... + @pytest.mark.skip(reason="Pinecone does not include null values in the result of the $ne operator") + def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs): ... @pytest.mark.skip( reason="Pinecone has inconsistent behavior with respect to other Document Stores with the $or operator" ) - def test_or_operator(self, document_store, filterable_docs): - ... + def test_or_operator(self, document_store, filterable_docs): ... 
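Aside: many of the retriever hunks above, pgvector and pinecone alike, exercise the `FilterPolicy.REPLACE` vs. `FilterPolicy.MERGE` distinction. For readers skimming the diff, the semantics under test can be restated in a few lines. This is a hedged sketch, not the actual Haystack helper; the dict-union merge with "runtime values win on key collisions" is a convention inferred from `test_run_with_filters` above.

    from typing import Optional
    from haystack.document_stores.types import FilterPolicy

    def combine_filters(policy: FilterPolicy, init_filters: dict,
                        runtime_filters: Optional[dict]) -> dict:
        """Illustrative combination of init-time and run-time filters."""
        if not runtime_filters:
            return init_filters
        if policy == FilterPolicy.MERGE:
            # Assumed convention: runtime values win on key collisions.
            return {**init_filters, **runtime_filters}
        return runtime_filters  # REPLACE: runtime filters take over entirely

    # Mirrors test_run_with_filters:
    # combine_filters(FilterPolicy.MERGE, {"field": "value"}, {"field2": "value2"})
    # -> {"field": "value", "field2": "value2"}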
diff --git a/integrations/qdrant/examples/embedding_retrieval.py b/integrations/qdrant/examples/embedding_retrieval.py index fd8a68878..f009191e7 100644 --- a/integrations/qdrant/examples/embedding_retrieval.py +++ b/integrations/qdrant/examples/embedding_retrieval.py @@ -9,10 +9,7 @@ from haystack import Pipeline from haystack.components.converters import MarkdownToDocument -from haystack.components.embedders import ( - SentenceTransformersDocumentEmbedder, - SentenceTransformersTextEmbedder, -) +from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder from haystack.components.preprocessors import DocumentSplitter from haystack.components.writers import DocumentWriter from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever @@ -33,9 +30,7 @@ indexing = Pipeline() indexing.add_component("converter", MarkdownToDocument()) -indexing.add_component( - "splitter", DocumentSplitter(split_by="sentence", split_length=2) -) +indexing.add_component("splitter", DocumentSplitter(split_by="sentence", split_length=2)) indexing.add_component("embedder", SentenceTransformersDocumentEmbedder()) indexing.add_component("writer", DocumentWriter(document_store)) indexing.connect("converter", "splitter") @@ -47,9 +42,7 @@ # Create the querying Pipeline and try a query querying = Pipeline() querying.add_component("embedder", SentenceTransformersTextEmbedder()) -querying.add_component( - "retriever", QdrantEmbeddingRetriever(document_store=document_store, top_k=3) -) +querying.add_component("retriever", QdrantEmbeddingRetriever(document_store=document_store, top_k=3)) querying.connect("embedder", "retriever") results = querying.run({"embedder": {"text": "What is a cross-encoder?"}}) diff --git a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/__init__.py b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/__init__.py index 55f1fd550..ed6422bfe 100644 --- a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/__init__.py +++ b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/__init__.py @@ -2,14 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -from .retriever import ( - QdrantEmbeddingRetriever, - QdrantHybridRetriever, - QdrantSparseEmbeddingRetriever, -) +from .retriever import QdrantEmbeddingRetriever, QdrantHybridRetriever, QdrantSparseEmbeddingRetriever -__all__ = ( - "QdrantEmbeddingRetriever", - "QdrantSparseEmbeddingRetriever", - "QdrantHybridRetriever", -) +__all__ = ("QdrantEmbeddingRetriever", "QdrantSparseEmbeddingRetriever", "QdrantHybridRetriever") diff --git a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py index 833bd4a85..275a46f95 100644 --- a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py +++ b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py @@ -71,9 +71,7 @@ def __init__( self._scale_score = scale_score self._return_embedding = return_embedding self._filter_policy = ( - filter_policy - if isinstance(filter_policy, FilterPolicy) - else FilterPolicy.from_str(filter_policy) + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) ) self._score_threshold = score_threshold @@ -108,16 +106,12 @@ def from_dict(cls, data: Dict[str, Any]) -> "QdrantEmbeddingRetriever": 
:returns: Deserialized component. """ - document_store = QdrantDocumentStore.from_dict( - data["init_parameters"]["document_store"] - ) + document_store = QdrantDocumentStore.from_dict(data["init_parameters"]["document_store"]) data["init_parameters"]["document_store"] = document_store # Pipelines serialized with old versions of the component might not # have the filter_policy field. if filter_policy := data["init_parameters"].get("filter_policy"): - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str( - filter_policy - ) + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) @@ -221,9 +215,7 @@ def __init__( self._scale_score = scale_score self._return_embedding = return_embedding self._filter_policy = ( - filter_policy - if isinstance(filter_policy, FilterPolicy) - else FilterPolicy.from_str(filter_policy) + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) ) self._score_threshold = score_threshold @@ -258,16 +250,12 @@ def from_dict(cls, data: Dict[str, Any]) -> "QdrantSparseEmbeddingRetriever": :returns: Deserialized component. """ - document_store = QdrantDocumentStore.from_dict( - data["init_parameters"]["document_store"] - ) + document_store = QdrantDocumentStore.from_dict(data["init_parameters"]["document_store"]) data["init_parameters"]["document_store"] = document_store # Pipelines serialized with old versions of the component might not # have the filter_policy field. if filter_policy := data["init_parameters"].get("filter_policy"): - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str( - filter_policy - ) + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) @@ -379,9 +367,7 @@ def __init__( self._top_k = top_k self._return_embedding = return_embedding self._filter_policy = ( - filter_policy - if isinstance(filter_policy, FilterPolicy) - else FilterPolicy.from_str(filter_policy) + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) ) self._score_threshold = score_threshold @@ -412,16 +398,12 @@ def from_dict(cls, data: Dict[str, Any]) -> "QdrantHybridRetriever": :returns: Deserialized component. """ - document_store = QdrantDocumentStore.from_dict( - data["init_parameters"]["document_store"] - ) + document_store = QdrantDocumentStore.from_dict(data["init_parameters"]["document_store"]) data["init_parameters"]["document_store"] = document_store # Pipelines serialized with old versions of the component might not # have the filter_policy field. 
if filter_policy := data["init_parameters"].get("filter_policy"): - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str( - filter_policy - ) + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py index a43b84ac8..01645a999 100644 --- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py @@ -60,9 +60,7 @@ def convert_id(_id: str) -> str: QdrantPoint = Union[rest.ScoredPoint, rest.Record] -def convert_qdrant_point_to_haystack_document( - point: QdrantPoint, use_sparse_embeddings: bool -) -> Document: +def convert_qdrant_point_to_haystack_document(point: QdrantPoint, use_sparse_embeddings: bool) -> Document: payload = {**point.payload} payload["score"] = point.score if hasattr(point, "score") else None diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py index 58a265591..d55cbd71c 100644 --- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py @@ -319,19 +319,11 @@ def filter_documents( :param filters: The filters to apply to the document list. :returns: A list of documents that match the given filters. """ - if ( - filters - and not isinstance(filters, dict) - and not isinstance(filters, rest.Filter) - ): + if filters and not isinstance(filters, dict) and not isinstance(filters, rest.Filter): msg = "Filter must be a dictionary or an instance of `qdrant_client.http.models.Filter`" raise ValueError(msg) - if ( - filters - and not isinstance(filters, rest.Filter) - and "operator" not in filters - ): + if filters and not isinstance(filters, rest.Filter) and "operator" not in filters: filters = convert_legacy_filters(filters) return list( self.get_documents_generator( @@ -362,18 +354,11 @@ def write_documents( msg = f"DocumentStore.write_documents() expects a list of Documents but got an element of {type(doc)}." 
raise ValueError(msg) self._set_up_collection( - self.index, - self.embedding_dim, - False, - self.similarity, - self.use_sparse_embeddings, - self.sparse_idf, + self.index, self.embedding_dim, False, self.similarity, self.use_sparse_embeddings, self.sparse_idf ) if len(documents) == 0: - logger.warning( - "Calling QdrantDocumentStore.write_documents() with empty list" - ) + logger.warning("Calling QdrantDocumentStore.write_documents() with empty list") return document_objects = self._handle_duplicate_documents( @@ -382,12 +367,8 @@ def write_documents( policy=policy, ) - batched_documents = get_batches_from_generator( - document_objects, self.write_batch_size - ) - with tqdm( - total=len(document_objects), disable=not self.progress_bar - ) as progress_bar: + batched_documents = get_batches_from_generator(document_objects, self.write_batch_size) + with tqdm(total=len(document_objects), disable=not self.progress_bar) as progress_bar: for document_batch in batched_documents: batch = convert_haystack_documents_to_qdrant_points( document_batch, @@ -477,9 +458,7 @@ def get_documents_generator( with_vectors=True, ) stop_scrolling = next_offset is None or ( - isinstance(next_offset, grpc.PointId) - and next_offset.num == 0 - and next_offset.uuid == "" + isinstance(next_offset, grpc.PointId) and next_offset.num == 0 and next_offset.uuid == "" ) for record in records: @@ -516,9 +495,7 @@ def get_documents_by_id( for record in records: documents.append( - convert_qdrant_point_to_haystack_document( - record, use_sparse_embeddings=self.use_sparse_embeddings - ) + convert_qdrant_point_to_haystack_document(record, use_sparse_embeddings=self.use_sparse_embeddings) ) return documents @@ -575,9 +552,7 @@ def _query_by_sparse( score_threshold=score_threshold, ) results = [ - convert_qdrant_point_to_haystack_document( - point, use_sparse_embeddings=self.use_sparse_embeddings - ) + convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=self.use_sparse_embeddings) for point in points ] if scale_score: @@ -625,9 +600,7 @@ def _query_by_embedding( score_threshold=score_threshold, ) results = [ - convert_qdrant_point_to_haystack_document( - point, use_sparse_embeddings=self.use_sparse_embeddings - ) + convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=self.use_sparse_embeddings) for point in points ] if scale_score: @@ -717,17 +690,12 @@ def _query_hybrid( raise QdrantStoreError(msg) from e try: - points = reciprocal_rank_fusion( - responses=[dense_request_response, sparse_request_response], limit=top_k - ) + points = reciprocal_rank_fusion(responses=[dense_request_response, sparse_request_response], limit=top_k) except Exception as e: msg = "Error while applying Reciprocal Rank Fusion" raise QdrantStoreError(msg) from e - results = [ - convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=True) - for point in points - ] + results = [convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=True) for point in points] return results @@ -752,9 +720,7 @@ def get_distance(self, similarity: str) -> rest.Distance: ) raise QdrantStoreError(msg) from ke - def _create_payload_index( - self, collection_name: str, payload_fields_to_index: Optional[List[dict]] = None - ): + def _create_payload_index(self, collection_name: str, payload_fields_to_index: Optional[List[dict]] = None): """ Create payload index for the collection if payload_fields_to_index is provided See: https://qdrant.tech/documentation/concepts/indexing/#payload-index @@ -809,12 +775,7 @@ def 
_set_up_collection( # There is no need to verify the current configuration of that # collection. It might be just recreated again or does not exist yet. self.recreate_collection( - collection_name, - distance, - embedding_dim, - on_disk, - use_sparse_embeddings, - sparse_idf, + collection_name, distance, embedding_dim, on_disk, use_sparse_embeddings, sparse_idf ) # Create Payload index if payload_fields_to_index is provided self._create_payload_index(collection_name, payload_fields_to_index) @@ -847,12 +808,8 @@ def _set_up_collection( raise QdrantStoreError(msg) if self.use_sparse_embeddings: - current_distance = collection_info.config.params.vectors[ - DENSE_VECTORS_NAME - ].distance - current_vector_size = collection_info.config.params.vectors[ - DENSE_VECTORS_NAME - ].size + current_distance = collection_info.config.params.vectors[DENSE_VECTORS_NAME].distance + current_vector_size = collection_info.config.params.vectors[DENSE_VECTORS_NAME].size else: current_distance = collection_info.config.params.vectors.distance current_vector_size = collection_info.config.params.vectors.size @@ -907,9 +864,7 @@ def recreate_collection( use_sparse_embeddings = self.use_sparse_embeddings # dense vectors configuration - vectors_config = rest.VectorParams( - size=embedding_dim, on_disk=on_disk, distance=distance - ) + vectors_config = rest.VectorParams(size=embedding_dim, on_disk=on_disk, distance=distance) if use_sparse_embeddings: # in this case, we need to define named vectors @@ -930,9 +885,7 @@ def recreate_collection( self.client.create_collection( collection_name=collection_name, vectors_config=vectors_config, - sparse_vectors_config=sparse_vectors_config - if use_sparse_embeddings - else None, + sparse_vectors_config=sparse_vectors_config if use_sparse_embeddings else None, shard_number=self.shard_number, replication_factor=self.replication_factor, write_consistency_factor=self.write_consistency_factor, @@ -963,24 +916,18 @@ def _handle_duplicate_documents( index = index or self.index if policy in (DuplicatePolicy.SKIP, DuplicatePolicy.FAIL): documents = self._drop_duplicate_documents(documents, index) - documents_found = self.get_documents_by_id( - ids=[doc.id for doc in documents], index=index - ) + documents_found = self.get_documents_by_id(ids=[doc.id for doc in documents], index=index) ids_exist_in_db: List[str] = [doc.id for doc in documents_found] if len(ids_exist_in_db) > 0 and policy == DuplicatePolicy.FAIL: msg = f"Document with ids '{', '.join(ids_exist_in_db)} already exists in index = '{index}'." raise DuplicateDocumentError(msg) - documents = list( - filter(lambda doc: doc.id not in ids_exist_in_db, documents) - ) + documents = list(filter(lambda doc: doc.id not in ids_exist_in_db, documents)) return documents - def _drop_duplicate_documents( - self, documents: List[Document], index: Optional[str] = None - ) -> List[Document]: + def _drop_duplicate_documents(self, documents: List[Document], index: Optional[str] = None) -> List[Document]: """ Drop duplicate documents based on same hash ID. 
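Aside: the `_handle_duplicate_documents` hunk above compresses the SKIP/FAIL branching onto single lines, but the behaviour itself is simple to restate. A minimal sketch with illustrative names rather than the store's API; in the real method, within-batch duplicates are first dropped via `_drop_duplicate_documents`, and the ids already present come from `get_documents_by_id`.

    from typing import List, Set
    from haystack.dataclasses import Document
    from haystack.document_stores.errors import DuplicateDocumentError
    from haystack.document_stores.types import DuplicatePolicy

    def handle_duplicates(documents: List[Document], ids_in_store: Set[str],
                          policy: DuplicatePolicy) -> List[Document]:
        """Illustrative restatement of the SKIP/FAIL handling shown above."""
        if policy not in (DuplicatePolicy.SKIP, DuplicatePolicy.FAIL):
            return documents  # OVERWRITE/NONE: defer to the upsert semantics
        clashing = [doc.id for doc in documents if doc.id in ids_in_store]
        if clashing and policy == DuplicatePolicy.FAIL:
            msg = f"Document with ids '{', '.join(clashing)}' already exists."
            raise DuplicateDocumentError(msg)
        # SKIP (or FAIL with no clashes): keep only documents not yet stored.
        return [doc for doc in documents if doc.id not in ids_in_store]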
diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py index d76d8ac4b..69fd7cbbd 100644 --- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py @@ -9,8 +9,7 @@ def convert_filters_to_qdrant( - filter_term: Optional[Union[List[dict], dict, models.Filter]] = None, - is_parent_call: bool = True, + filter_term: Optional[Union[List[dict], dict, models.Filter]] = None, is_parent_call: bool = True ) -> Optional[Union[models.Filter, List[models.Filter], List[models.Condition]]]: """Converts Haystack filters to the format used by Qdrant. @@ -50,9 +49,7 @@ def convert_filters_to_qdrant( operator = item.get("operator") # Check for repeated similar operators on each level - same_operator_flag = ( - operator in current_level_operators and operator in LOGICAL_OPERATORS - ) + same_operator_flag = operator in current_level_operators and operator in LOGICAL_OPERATORS if not same_operator_flag: current_level_operators.append(operator) @@ -66,32 +63,19 @@ def convert_filters_to_qdrant( if operator in LOGICAL_OPERATORS: # Recursively process nested conditions - current_filter = ( - convert_filters_to_qdrant( - item.get("conditions", []), is_parent_call=False - ) - or [] - ) + current_filter = convert_filters_to_qdrant(item.get("conditions", []), is_parent_call=False) or [] # When same_operator_flag is set to True, # ensure each clause is appended as an independent list to avoid merging distinct clauses. if operator == "AND": - must_clauses = ( - [must_clauses, current_filter] - if same_operator_flag - else must_clauses + current_filter - ) + must_clauses = [must_clauses, current_filter] if same_operator_flag else must_clauses + current_filter elif operator == "OR": should_clauses = ( - [should_clauses, current_filter] - if same_operator_flag - else should_clauses + current_filter + [should_clauses, current_filter] if same_operator_flag else should_clauses + current_filter ) elif operator == "NOT": must_not_clauses = ( - [must_not_clauses, current_filter] - if same_operator_flag - else must_not_clauses + current_filter + [must_not_clauses, current_filter] if same_operator_flag else must_not_clauses + current_filter ) elif operator in COMPARISON_OPERATORS: @@ -101,9 +85,7 @@ def convert_filters_to_qdrant( msg = f"'field' or 'value' not found for '{operator}'" raise FilterError(msg) - parsed_conditions = _parse_comparison_operation( - comparison_operation=operator, key=field, value=value - ) + parsed_conditions = _parse_comparison_operation(comparison_operation=operator, key=field, value=value) # check if the parsed_conditions are models.Filter or models.Condition for condition in parsed_conditions: @@ -238,9 +220,7 @@ def _build_eq_condition(key: str, value: models.ValueVariants) -> models.Conditi return models.FieldCondition(key=key, match=models.MatchValue(value=value)) -def _build_in_condition( - key: str, value: List[models.ValueVariants] -) -> models.Condition: +def _build_in_condition(key: str, value: List[models.ValueVariants]) -> models.Condition: if not isinstance(value, list): msg = f"Value {value} is not a list" raise FilterError(msg) @@ -262,17 +242,13 @@ def _build_ne_condition(key: str, value: models.ValueVariants) -> models.Conditi ( models.FieldCondition(key=key, match=models.MatchText(text=value)) if isinstance(value, str) and " " not in value - else 
models.FieldCondition( - key=key, match=models.MatchValue(value=value) - ) + else models.FieldCondition(key=key, match=models.MatchValue(value=value)) ) ] ) -def _build_nin_condition( - key: str, value: List[models.ValueVariants] -) -> models.Condition: +def _build_nin_condition(key: str, value: List[models.ValueVariants]) -> models.Condition: if not isinstance(value, list): msg = f"Value {value} is not a list" raise FilterError(msg) diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/migrate_to_sparse.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/migrate_to_sparse.py index 802825226..1fabbfd9c 100644 --- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/migrate_to_sparse.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/migrate_to_sparse.py @@ -9,9 +9,7 @@ logger.setLevel(logging.INFO) -def migrate_to_sparse_embeddings_support( - old_document_store: QdrantDocumentStore, new_index: str -): +def migrate_to_sparse_embeddings_support(old_document_store: QdrantDocumentStore, new_index: str): """ Utility function to migrate an existing `QdrantDocumentStore` to a new one with support for sparse embeddings. @@ -125,7 +123,5 @@ def migrate_to_sparse_embeddings_support( # restore the original indexing threshold (re-enable indexing) client.update_collection( collection_name=new_index, - optimizer_config=models.OptimizersConfigDiff( - indexing_threshold=original_indexing_threshold - ), + optimizer_config=models.OptimizersConfigDiff(indexing_threshold=original_indexing_threshold), ) diff --git a/integrations/qdrant/tests/test_converters.py b/integrations/qdrant/tests/test_converters.py index c1d02c920..242c4cafe 100644 --- a/integrations/qdrant/tests/test_converters.py +++ b/integrations/qdrant/tests/test_converters.py @@ -13,6 +13,7 @@ def test_convert_id_is_deterministic(): def test_point_to_document_reverts_proper_structure_from_record_with_sparse(): + point = rest.Record( id="c7c62e8e-02b9-4ec6-9f88-46bd97b628b7", payload={ @@ -29,21 +30,17 @@ def test_point_to_document_reverts_proper_structure_from_record_with_sparse(): "text-sparse": {"indices": [7, 1024, 367], "values": [0.1, 0.98, 0.33]}, }, ) - document = convert_qdrant_point_to_haystack_document( - point, use_sparse_embeddings=True - ) + document = convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=True) assert "my-id" == document.id assert "Lorem ipsum" == document.content assert "text" == document.content_type - assert { - "indices": [7, 1024, 367], - "values": [0.1, 0.98, 0.33], - } == document.sparse_embedding.to_dict() + assert {"indices": [7, 1024, 367], "values": [0.1, 0.98, 0.33]} == document.sparse_embedding.to_dict() assert {"test_field": 1} == document.meta assert 0.0 == np.sum(np.array([1.0, 0.0, 0.0, 0.0]) - document.embedding) def test_point_to_document_reverts_proper_structure_from_record_without_sparse(): + point = rest.Record( id="c7c62e8e-02b9-4ec6-9f88-46bd97b628b7", payload={ @@ -57,9 +54,7 @@ def test_point_to_document_reverts_proper_structure_from_record_without_sparse() }, vector=[1.0, 0.0, 0.0, 0.0], ) - document = convert_qdrant_point_to_haystack_document( - point, use_sparse_embeddings=False - ) + document = convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=False) assert "my-id" == document.id assert "Lorem ipsum" == document.content assert "text" == document.content_type diff --git a/integrations/qdrant/tests/test_dict_converters.py 
b/integrations/qdrant/tests/test_dict_converters.py index 248ade0ec..3871dbff0 100644 --- a/integrations/qdrant/tests/test_dict_converters.py +++ b/integrations/qdrant/tests/test_dict_converters.py @@ -54,11 +54,7 @@ def test_from_dict(): { "type": "haystack_integrations.document_stores.qdrant.document_store.QdrantDocumentStore", "init_parameters": { - "api_key": { - "env_vars": ["ENV_VAR"], - "strict": False, - "type": "env_var", - }, + "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "location": ":memory:", "index": "test", "embedding_dim": 768, diff --git a/integrations/qdrant/tests/test_document_store.py b/integrations/qdrant/tests/test_document_store.py index ec7c2c5f7..c388a10cf 100644 --- a/integrations/qdrant/tests/test_document_store.py +++ b/integrations/qdrant/tests/test_document_store.py @@ -20,9 +20,7 @@ from qdrant_client.http import models as rest -class TestQdrantDocumentStore( - CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest -): +class TestQdrantDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest): @pytest.fixture def document_store(self) -> QdrantDocumentStore: return QdrantDocumentStore( @@ -34,15 +32,11 @@ def document_store(self) -> QdrantDocumentStore: ) def test_init_is_lazy(self): - with patch( - "haystack_integrations.document_stores.qdrant.document_store.qdrant_client" - ) as mocked_qdrant: + with patch("haystack_integrations.document_stores.qdrant.document_store.qdrant_client") as mocked_qdrant: QdrantDocumentStore(location=":memory:", use_sparse_embeddings=True) mocked_qdrant.assert_not_called() - def assert_documents_are_equal( - self, received: List[Document], expected: List[Document] - ): + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): """ Assert that two lists of Documents are equal. This is used in every test. 
@@ -78,32 +72,23 @@ def test_sparse_configuration(self): assert sparse_config[SPARSE_VECTORS_NAME].modifier == rest.Modifier.IDF def test_query_hybrid(self, generate_sparse_embedding): - document_store = QdrantDocumentStore( - location=":memory:", use_sparse_embeddings=True - ) + document_store = QdrantDocumentStore(location=":memory:", use_sparse_embeddings=True) docs = [] for i in range(20): docs.append( Document( - content=f"doc {i}", - sparse_embedding=generate_sparse_embedding(), - embedding=_random_embeddings(768), + content=f"doc {i}", sparse_embedding=generate_sparse_embedding(), embedding=_random_embeddings(768) ) ) document_store.write_documents(docs) - sparse_embedding = SparseEmbedding( - indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33] - ) + sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33]) embedding = [0.1] * 768 results: List[Document] = document_store._query_hybrid( - query_sparse_embedding=sparse_embedding, - query_embedding=embedding, - top_k=10, - return_embedding=True, + query_sparse_embedding=sparse_embedding, query_embedding=embedding, top_k=10, return_embedding=True ) assert len(results) == 10 @@ -112,53 +97,35 @@ def test_query_hybrid(self, generate_sparse_embedding): assert document.embedding def test_query_hybrid_fail_without_sparse_embedding(self, document_store): - sparse_embedding = SparseEmbedding( - indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33] - ) + sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33]) embedding = [0.1] * 768 with pytest.raises(QdrantStoreError): + document_store._query_hybrid( query_sparse_embedding=sparse_embedding, query_embedding=embedding, ) def test_query_hybrid_search_batch_failure(self): - document_store = QdrantDocumentStore( - location=":memory:", use_sparse_embeddings=True - ) + document_store = QdrantDocumentStore(location=":memory:", use_sparse_embeddings=True) - sparse_embedding = SparseEmbedding( - indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33] - ) + sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33]) embedding = [0.1] * 768 - with patch.object( - document_store.client, - "search_batch", - side_effect=Exception("search_batch error"), - ): + with patch.object(document_store.client, "search_batch", side_effect=Exception("search_batch error")): + with pytest.raises(QdrantStoreError): - document_store._query_hybrid( - query_sparse_embedding=sparse_embedding, query_embedding=embedding - ) + document_store._query_hybrid(query_sparse_embedding=sparse_embedding, query_embedding=embedding) - @patch( - "haystack_integrations.document_stores.qdrant.document_store.reciprocal_rank_fusion" - ) + @patch("haystack_integrations.document_stores.qdrant.document_store.reciprocal_rank_fusion") def test_query_hybrid_reciprocal_rank_fusion_failure(self, mocked_fusion): - document_store = QdrantDocumentStore( - location=":memory:", use_sparse_embeddings=True - ) + document_store = QdrantDocumentStore(location=":memory:", use_sparse_embeddings=True) - sparse_embedding = SparseEmbedding( - indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33] - ) + sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33]) embedding = [0.1] * 768 mocked_fusion.side_effect = Exception("reciprocal_rank_fusion error") with pytest.raises(QdrantStoreError): - document_store._query_hybrid( - query_sparse_embedding=sparse_embedding, query_embedding=embedding - ) + 
document_store._query_hybrid(query_sparse_embedding=sparse_embedding, query_embedding=embedding) diff --git a/integrations/qdrant/tests/test_filters.py b/integrations/qdrant/tests/test_filters.py index 98404f4cd..fd070bda9 100644 --- a/integrations/qdrant/tests/test_filters.py +++ b/integrations/qdrant/tests/test_filters.py @@ -18,34 +18,22 @@ def document_store(self) -> QdrantDocumentStore: wait_result_from_api=True, ) - def test_filter_documents_with_qdrant_filters( - self, document_store, filterable_docs - ): + def test_filter_documents_with_qdrant_filters(self, document_store, filterable_docs): document_store.write_documents(filterable_docs) result = document_store.filter_documents( filters=models.Filter( must_not=[ - models.FieldCondition( - key="meta.number", match=models.MatchValue(value=100) - ), - models.FieldCondition( - key="meta.name", match=models.MatchValue(value="name_0") - ), + models.FieldCondition(key="meta.number", match=models.MatchValue(value=100)), + models.FieldCondition(key="meta.name", match=models.MatchValue(value="name_0")), ] ) ) self.assert_documents_are_equal( result, - [ - d - for d in filterable_docs - if (d.meta.get("number") != 100 and d.meta.get("name") != "name_0") - ], + [d for d in filterable_docs if (d.meta.get("number") != 100 and d.meta.get("name") != "name_0")], ) - def assert_documents_are_equal( - self, received: List[Document], expected: List[Document] - ): + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): """ Assert that two lists of Documents are equal. This is used in every test. @@ -70,35 +58,22 @@ def test_not_operator(self, document_store, filterable_docs): ) self.assert_documents_are_equal( result, - [ - d - for d in filterable_docs - if (d.meta.get("number") != 100 and d.meta.get("name") != "name_0") - ], + [d for d in filterable_docs if (d.meta.get("number") != 100 and d.meta.get("name") != "name_0")], ) def test_filter_criteria(self, document_store): documents = [ Document( content="This is test document 1.", - meta={ - "file_name": "file1", - "classification": {"details": {"category1": 0.9, "category2": 0.3}}, - }, + meta={"file_name": "file1", "classification": {"details": {"category1": 0.9, "category2": 0.3}}}, ), Document( content="This is test document 2.", - meta={ - "file_name": "file2", - "classification": {"details": {"category1": 0.1, "category2": 0.7}}, - }, + meta={"file_name": "file2", "classification": {"details": {"category1": 0.1, "category2": 0.7}}}, ), Document( content="This is test document 3.", - meta={ - "file_name": "file3", - "classification": {"details": {"category1": 0.7, "category2": 0.9}}, - }, + meta={"file_name": "file3", "classification": {"details": {"category1": 0.7, "category2": 0.9}}}, ), ] @@ -106,24 +81,12 @@ def test_filter_criteria(self, document_store): filter_criteria = { "operator": "AND", "conditions": [ - { - "field": "meta.file_name", - "operator": "in", - "value": ["file1", "file2"], - }, + {"field": "meta.file_name", "operator": "in", "value": ["file1", "file2"]}, { "operator": "OR", "conditions": [ - { - "field": "meta.classification.details.category1", - "operator": ">=", - "value": 0.85, - }, - { - "field": "meta.classification.details.category2", - "operator": ">=", - "value": 0.85, - }, + {"field": "meta.classification.details.category1", "operator": ">=", "value": 0.85}, + {"field": "meta.classification.details.category2", "operator": ">=", "value": 0.85}, ], }, ], @@ -136,14 +99,8 @@ def test_filter_criteria(self, document_store): for d in 
documents if (d.meta.get("file_name") in ["file1", "file2"]) and ( - ( - d.meta.get("classification").get("details").get("category1") - >= 0.85 - ) - or ( - d.meta.get("classification").get("details").get("category2") - >= 0.85 - ) + (d.meta.get("classification").get("details").get("category1") >= 0.85) + or (d.meta.get("classification").get("details").get("category2") >= 0.85) ) ], ) @@ -154,39 +111,21 @@ def test_complex_filter_criteria(self, document_store): content="This is test document 1.", meta={ "file_name": "file1", - "classification": { - "details": { - "category1": 0.45, - "category2": 0.5, - "category3": 0.2, - } - }, + "classification": {"details": {"category1": 0.45, "category2": 0.5, "category3": 0.2}}, }, ), Document( content="This is test document 2.", meta={ "file_name": "file2", - "classification": { - "details": { - "category1": 0.95, - "category2": 0.85, - "category3": 0.4, - } - }, + "classification": {"details": {"category1": 0.95, "category2": 0.85, "category3": 0.4}}, }, ), Document( content="This is test document 3.", meta={ "file_name": "file3", - "classification": { - "details": { - "category1": 0.85, - "category2": 0.7, - "category3": 0.95, - } - }, + "classification": {"details": {"category1": 0.85, "category2": 0.7, "category3": 0.95}}, }, ), ] @@ -195,32 +134,16 @@ def test_complex_filter_criteria(self, document_store): filter_criteria = { "operator": "AND", "conditions": [ - { - "field": "meta.file_name", - "operator": "in", - "value": ["file1", "file2", "file3"], - }, + {"field": "meta.file_name", "operator": "in", "value": ["file1", "file2", "file3"]}, { "operator": "AND", "conditions": [ - { - "field": "meta.classification.details.category1", - "operator": ">=", - "value": 0.85, - }, + {"field": "meta.classification.details.category1", "operator": ">=", "value": 0.85}, { "operator": "OR", "conditions": [ - { - "field": "meta.classification.details.category2", - "operator": ">=", - "value": 0.8, - }, - { - "field": "meta.classification.details.category3", - "operator": ">=", - "value": 0.9, - }, + {"field": "meta.classification.details.category2", "operator": ">=", "value": 0.8}, + {"field": "meta.classification.details.category3", "operator": ">=", "value": 0.9}, ], }, ], @@ -235,19 +158,10 @@ def test_complex_filter_criteria(self, document_store): for d in documents if (d.meta.get("file_name") in ["file1", "file2", "file3"]) and ( - ( - d.meta.get("classification").get("details").get("category1") - >= 0.85 - ) + (d.meta.get("classification").get("details").get("category1") >= 0.85) and ( - ( - d.meta.get("classification").get("details").get("category2") - >= 0.8 - ) - or ( - d.meta.get("classification").get("details").get("category3") - >= 0.9 - ) + (d.meta.get("classification").get("details").get("category2") >= 0.8) + or (d.meta.get("classification").get("details").get("category3") >= 0.9) ) ) ], @@ -258,69 +172,46 @@ def test_complex_filter_criteria(self, document_store): def test_comparison_equal_with_none(self, document_store, filterable_docs): document_store.write_documents(filterable_docs) with pytest.raises(FilterError): - result = document_store.filter_documents( - filters={"field": "meta.number", "operator": "==", "value": None} - ) - self.assert_documents_are_equal( - result, [d for d in filterable_docs if d.meta.get("number") is None] - ) + result = document_store.filter_documents(filters={"field": "meta.number", "operator": "==", "value": None}) + self.assert_documents_are_equal(result, [d for d in filterable_docs if d.meta.get("number") is 
None]) def test_comparison_not_equal_with_none(self, document_store, filterable_docs): document_store.write_documents(filterable_docs) with pytest.raises(FilterError): - result = document_store.filter_documents( - filters={"field": "meta.number", "operator": "!=", "value": None} - ) - self.assert_documents_are_equal( - result, [d for d in filterable_docs if d.meta.get("number") is not None] - ) + result = document_store.filter_documents(filters={"field": "meta.number", "operator": "!=", "value": None}) + self.assert_documents_are_equal(result, [d for d in filterable_docs if d.meta.get("number") is not None]) def test_comparison_greater_than_with_none(self, document_store, filterable_docs): document_store.write_documents(filterable_docs) with pytest.raises(FilterError): - result = document_store.filter_documents( - filters={"field": "meta.number", "operator": ">", "value": None} - ) + result = document_store.filter_documents(filters={"field": "meta.number", "operator": ">", "value": None}) self.assert_documents_are_equal(result, []) - def test_comparison_greater_than_equal_with_none( - self, document_store, filterable_docs - ): + def test_comparison_greater_than_equal_with_none(self, document_store, filterable_docs): document_store.write_documents(filterable_docs) with pytest.raises(FilterError): - result = document_store.filter_documents( - filters={"field": "meta.number", "operator": ">=", "value": None} - ) + result = document_store.filter_documents(filters={"field": "meta.number", "operator": ">=", "value": None}) self.assert_documents_are_equal(result, []) def test_comparison_less_than_with_none(self, document_store, filterable_docs): document_store.write_documents(filterable_docs) with pytest.raises(FilterError): - result = document_store.filter_documents( - filters={"field": "meta.number", "operator": "<", "value": None} - ) + result = document_store.filter_documents(filters={"field": "meta.number", "operator": "<", "value": None}) self.assert_documents_are_equal(result, []) - def test_comparison_less_than_equal_with_none( - self, document_store, filterable_docs - ): + def test_comparison_less_than_equal_with_none(self, document_store, filterable_docs): document_store.write_documents(filterable_docs) with pytest.raises(FilterError): - result = document_store.filter_documents( - filters={"field": "meta.number", "operator": "<=", "value": None} - ) + result = document_store.filter_documents(filters={"field": "meta.number", "operator": "<=", "value": None}) self.assert_documents_are_equal(result, []) # ======== ========================== ======== @pytest.mark.skip(reason="Qdrant doesn't support comparison with dataframe") - def test_comparison_equal_with_dataframe(self, document_store, filterable_docs): - ... + def test_comparison_equal_with_dataframe(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Qdrant doesn't support comparison with dataframe") - def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs): - ... + def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Cannot distinguish errors yet") - def test_missing_top_level_operator_key(self, document_store, filterable_docs): - ... + def test_missing_top_level_operator_key(self, document_store, filterable_docs): ...
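The condition-tree filter syntax exercised in `test_filters.py` combines field comparisons under `AND`/`OR` operators. A minimal sketch of the shape used in `test_filter_criteria`, assuming an in-memory store; the meta fields are illustrative, not a required schema:

    from haystack import Document
    from haystack_integrations.document_stores.qdrant import QdrantDocumentStore

    store = QdrantDocumentStore(location=":memory:", index="test", embedding_dim=768)
    store.write_documents(
        [
            Document(content="doc 1", meta={"file_name": "file1", "score": 0.9}),
            Document(content="doc 2", meta={"file_name": "file2", "score": 0.1}),
        ]
    )

    # AND of a membership test and a numeric comparison, mirroring the test above.
    filters = {
        "operator": "AND",
        "conditions": [
            {"field": "meta.file_name", "operator": "in", "value": ["file1", "file2"]},
            {"field": "meta.score", "operator": ">=", "value": 0.85},
        ],
    }
    print([d.content for d in store.filter_documents(filters=filters)])  # -> ["doc 1"]
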
diff --git a/integrations/qdrant/tests/test_legacy_filters.py b/integrations/qdrant/tests/test_legacy_filters.py index a3216c92d..60f1fad2b 100644 --- a/integrations/qdrant/tests/test_legacy_filters.py +++ b/integrations/qdrant/tests/test_legacy_filters.py @@ -26,9 +26,7 @@ def document_store(self) -> QdrantDocumentStore: wait_result_from_api=True, ) - def assert_documents_are_equal( - self, received: List[Document], expected: List[Document] - ): + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): """ Assert that two lists of Documents are equal. This is used in every test. @@ -40,97 +38,56 @@ def assert_documents_are_equal( # Check that the sets are equal, meaning the content and IDs match regardless of order assert {doc.id for doc in received} == {doc.id for doc in expected} - def test_filter_simple_metadata_value( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_filter_simple_metadata_value(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) result = document_store.filter_documents(filters={"meta.page": "100"}) - self.assert_documents_are_equal( - result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"] - ) + self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_filter_document_dataframe( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): - ... + def test_filter_document_dataframe(self, document_store: DocumentStore, filterable_docs: List[Document]): ... - def test_eq_filter_explicit( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_eq_filter_explicit(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) result = document_store.filter_documents(filters={"meta.page": {"$eq": "100"}}) - self.assert_documents_are_equal( - result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"] - ) + self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) - def test_eq_filter_implicit( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_eq_filter_implicit(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) result = document_store.filter_documents(filters={"meta.page": "100"}) - self.assert_documents_are_equal( - result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"] - ) + self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_eq_filter_table( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): - ... + def test_eq_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_eq_filter_embedding( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): - ... + def test_eq_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... 
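The legacy-filter tests in this file cover the pre-2.x dialect; `test_eq_filter_explicit` and `test_eq_filter_implicit` above expect the `$eq` form and the shorthand to behave identically. A self-contained sketch under the same in-memory assumptions (the index name is arbitrary):

    from haystack import Document
    from haystack_integrations.document_stores.qdrant import QdrantDocumentStore

    store = QdrantDocumentStore(location=":memory:", index="legacy", embedding_dim=768)
    store.write_documents(
        [Document(content="a", meta={"page": "100"}), Document(content="b", meta={"page": "123"})]
    )

    explicit = store.filter_documents(filters={"meta.page": {"$eq": "100"}})
    implicit = store.filter_documents(filters={"meta.page": "100"})
    assert {d.id for d in explicit} == {d.id for d in implicit}  # both match only "a"
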
# LegacyFilterDocumentsNotEqualTest - def test_ne_filter( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_ne_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) result = document_store.filter_documents(filters={"meta.page": {"$ne": "100"}}) - self.assert_documents_are_equal( - result, [doc for doc in filterable_docs if doc.meta.get("page") != "100"] - ) + self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") != "100"]) @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_ne_filter_table( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): - ... + def test_ne_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_ne_filter_embedding( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): - ... + def test_ne_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... # LegacyFilterDocumentsInTest - def test_filter_simple_list_single_element( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_filter_simple_list_single_element(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) result = document_store.filter_documents(filters={"meta.page": ["100"]}) - self.assert_documents_are_equal( - result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"] - ) + self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) - def test_filter_simple_list_one_value( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_filter_simple_list_one_value(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) result = document_store.filter_documents(filters={"meta.page": ["100"]}) - self.assert_documents_are_equal( - result, [doc for doc in filterable_docs if doc.meta.get("page") in ["100"]] - ) + self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") in ["100"]]) - def test_filter_simple_list( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_filter_simple_list(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) result = document_store.filter_documents(filters={"meta.page": ["100", "123"]}) self.assert_documents_are_equal( @@ -138,224 +95,136 @@ def test_filter_simple_list( [doc for doc in filterable_docs if doc.meta.get("page") in ["100", "123"]], ) - def test_incorrect_filter_value( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_incorrect_filter_value(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) result = document_store.filter_documents(filters={"meta.page": ["nope"]}) self.assert_documents_are_equal(result, []) - def test_in_filter_explicit( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_in_filter_explicit(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) - result = document_store.filter_documents( - filters={"meta.page": 
{"$in": ["100", "123", "n.a."]}} - ) + result = document_store.filter_documents(filters={"meta.page": {"$in": ["100", "123", "n.a."]}}) self.assert_documents_are_equal( result, [doc for doc in filterable_docs if doc.meta.get("page") in ["100", "123"]], ) - def test_in_filter_implicit( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_in_filter_implicit(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) - result = document_store.filter_documents( - filters={"meta.page": ["100", "123", "n.a."]} - ) + result = document_store.filter_documents(filters={"meta.page": ["100", "123", "n.a."]}) self.assert_documents_are_equal( result, [doc for doc in filterable_docs if doc.meta.get("page") in ["100", "123"]], ) @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_in_filter_table( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): - ... + def test_in_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_in_filter_embedding( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): - ... + def test_in_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... # LegacyFilterDocumentsNotInTest @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_nin_filter_table( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): - ... + def test_nin_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_nin_filter_embedding( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): - ... + def test_nin_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... 
- def test_nin_filter( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_nin_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) - result = document_store.filter_documents( - filters={"meta.page": {"$nin": ["100", "123", "n.a."]}} - ) + result = document_store.filter_documents(filters={"meta.page": {"$nin": ["100", "123", "n.a."]}}) self.assert_documents_are_equal( result, - [ - doc - for doc in filterable_docs - if doc.meta.get("page") not in ["100", "123"] - ], + [doc for doc in filterable_docs if doc.meta.get("page") not in ["100", "123"]], ) # LegacyFilterDocumentsGreaterThanTest - def test_gt_filter( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_gt_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) result = document_store.filter_documents(filters={"meta.number": {"$gt": 0.0}}) self.assert_documents_are_equal( result, - [ - doc - for doc in filterable_docs - if "number" in doc.meta and doc.meta["number"] > 0 - ], + [doc for doc in filterable_docs if "number" in doc.meta and doc.meta["number"] > 0], ) - def test_gt_filter_non_numeric( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_gt_filter_non_numeric(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) with pytest.raises(FilterError): document_store.filter_documents(filters={"meta.page": {"$gt": "100"}}) @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_gt_filter_table( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): - ... + def test_gt_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_gt_filter_embedding( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): - ... + def test_gt_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... # LegacyFilterDocumentsGreaterThanEqualTest - def test_gte_filter( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_gte_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) result = document_store.filter_documents(filters={"meta.number": {"$gte": -2}}) self.assert_documents_are_equal( result, - [ - doc - for doc in filterable_docs - if "number" in doc.meta and doc.meta["number"] >= -2 - ], + [doc for doc in filterable_docs if "number" in doc.meta and doc.meta["number"] >= -2], ) - def test_gte_filter_non_numeric( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_gte_filter_non_numeric(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) with pytest.raises(FilterError): document_store.filter_documents(filters={"meta.page": {"$gte": "100"}}) @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_gte_filter_table( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): - ... + def test_gte_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... 
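The range operators only apply to numeric fields; the `*_non_numeric` variants above assert that string comparisons are rejected. A sketch continuing the same store, assuming `FilterError` is importable from `haystack.errors` as in Haystack 2.x:

    from haystack.errors import FilterError

    # Documents without a numeric "number" meta field are simply not matched.
    hits = store.filter_documents(filters={"meta.number": {"$gte": -2}})
    assert all(d.meta["number"] >= -2 for d in hits)

    # Comparing a string-valued field with "$gt" raises, as the tests expect.
    try:
        store.filter_documents(filters={"meta.page": {"$gt": "100"}})
    except FilterError:
        pass
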
@pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_gte_filter_embedding( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): - ... + def test_gte_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... # LegacyFilterDocumentsLessThanTest - def test_lt_filter( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_lt_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) result = document_store.filter_documents(filters={"meta.number": {"$lt": 0.0}}) self.assert_documents_are_equal( result, - [ - doc - for doc in filterable_docs - if doc.meta.get("number") is not None and doc.meta["number"] < 0 - ], + [doc for doc in filterable_docs if doc.meta.get("number") is not None and doc.meta["number"] < 0], ) - def test_lt_filter_non_numeric( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_lt_filter_non_numeric(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) with pytest.raises(FilterError): document_store.filter_documents(filters={"meta.page": {"$lt": "100"}}) @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_lt_filter_table( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): - ... + def test_lt_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_lt_filter_embedding( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): - ... + def test_lt_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... # LegacyFilterDocumentsLessThanEqualTest - def test_lte_filter( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_lte_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) result = document_store.filter_documents(filters={"meta.number": {"$lte": 2.0}}) self.assert_documents_are_equal( result, - [ - doc - for doc in filterable_docs - if doc.meta.get("number") is not None and doc.meta["number"] <= 2.0 - ], + [doc for doc in filterable_docs if doc.meta.get("number") is not None and doc.meta["number"] <= 2.0], ) - def test_lte_filter_non_numeric( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_lte_filter_non_numeric(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) with pytest.raises(FilterError): document_store.filter_documents(filters={"meta.page": {"$lte": "100"}}) @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_lte_filter_table( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): - ... + def test_lte_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_lte_filter_embedding( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): - ... + def test_lte_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... 
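The logical-operator tests that follow treat a multi-key comparison dict as an implicit AND, equivalent to the explicit `$and` list form. Continuing the same sketch:

    implicit = store.filter_documents(filters={"meta.number": {"$lte": 2.0, "$gte": 0.0}})
    explicit = store.filter_documents(filters={"meta.number": {"$and": [{"$lte": 2.0}, {"$gte": 0.0}]}})
    assert {d.id for d in implicit} == {d.id for d in explicit}
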
# LegacyFilterDocumentsSimpleLogicalTest - def test_filter_simple_or( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_filter_simple_or(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) filters = { "$or": { @@ -378,61 +247,43 @@ def test_filter_simple_implicit_and_with_multi_key_dict( self, document_store: DocumentStore, filterable_docs: List[Document] ): document_store.write_documents(filterable_docs) - result = document_store.filter_documents( - filters={"meta.number": {"$lte": 2.0, "$gte": 0.0}} - ) + result = document_store.filter_documents(filters={"meta.number": {"$lte": 2.0, "$gte": 0.0}}) self.assert_documents_are_equal( result, [ doc for doc in filterable_docs - if "number" in doc.meta - and doc.meta["number"] >= 0.0 - and doc.meta["number"] <= 2.0 + if "number" in doc.meta and doc.meta["number"] >= 0.0 and doc.meta["number"] <= 2.0 ], ) - def test_filter_simple_explicit_and_with_list( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_filter_simple_explicit_and_with_list(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) - result = document_store.filter_documents( - filters={"meta.number": {"$and": [{"$lte": 2}, {"$gte": 0}]}} - ) + result = document_store.filter_documents(filters={"meta.number": {"$and": [{"$lte": 2}, {"$gte": 0}]}}) self.assert_documents_are_equal( result, [ doc for doc in filterable_docs - if "number" in doc.meta - and doc.meta["number"] <= 2.0 - and doc.meta["number"] >= 0.0 + if "number" in doc.meta and doc.meta["number"] <= 2.0 and doc.meta["number"] >= 0.0 ], ) - def test_filter_simple_implicit_and( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_filter_simple_implicit_and(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) - result = document_store.filter_documents( - filters={"meta.number": {"$lte": 2.0, "$gte": 0}} - ) + result = document_store.filter_documents(filters={"meta.number": {"$lte": 2.0, "$gte": 0}}) self.assert_documents_are_equal( result, [ doc for doc in filterable_docs - if "number" in doc.meta - and doc.meta["number"] <= 2.0 - and doc.meta["number"] >= 0.0 + if "number" in doc.meta and doc.meta["number"] <= 2.0 and doc.meta["number"] >= 0.0 ], ) # LegacyFilterDocumentsNestedLogicalTest( - def test_filter_nested_implicit_and( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_filter_nested_implicit_and(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) filters_simplified = { "meta.number": {"$lte": 2, "$gte": 0}, @@ -453,9 +304,7 @@ def test_filter_nested_implicit_and( ], ) - def test_filter_nested_or( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_filter_nested_or(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) filters = { "$or": { @@ -476,9 +325,7 @@ def test_filter_nested_or( ], ) - def test_filter_nested_and_or_explicit( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_filter_nested_and_or_explicit(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) filters_simplified = { "$and": { @@ -505,9 +352,7 @@ def 
test_filter_nested_and_or_explicit( ], ) - def test_filter_nested_and_or_implicit( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_filter_nested_and_or_implicit(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) filters_simplified = { "meta.page": {"$eq": "123"}, @@ -532,9 +377,7 @@ def test_filter_nested_and_or_implicit( ], ) - def test_filter_nested_or_and( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): + def test_filter_nested_or_and(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) filters_simplified = { "$or": { @@ -553,10 +396,7 @@ def test_filter_nested_or_and( for doc in filterable_docs if ( (doc.meta.get("number") is not None and doc.meta["number"] < 1) - or ( - doc.meta.get("name") in ["name_0", "name_1"] - and (doc.meta.get("chapter") != "intro") - ) + or (doc.meta.get("name") in ["name_0", "name_1"] and (doc.meta.get("chapter") != "intro")) ) ], ) @@ -588,14 +428,8 @@ def test_filter_nested_multiple_identical_operators_same_level( doc for doc in filterable_docs if ( - ( - doc.meta.get("name") in ["name_0", "name_1"] - and doc.meta.get("page") == "100" - ) - or ( - doc.meta.get("chapter") in ["intro", "abstract"] - and doc.meta.get("page") == "123" - ) + (doc.meta.get("name") in ["name_0", "name_1"] and doc.meta.get("page") == "100") + or (doc.meta.get("chapter") in ["intro", "abstract"] and doc.meta.get("page") == "123") ) ], ) @@ -604,6 +438,4 @@ def test_no_filter_not_empty(self, document_store: DocumentStore): docs = [Document(content="test doc")] document_store.write_documents(docs) self.assert_documents_are_equal(document_store.filter_documents(), docs) - self.assert_documents_are_equal( - document_store.filter_documents(filters={}), docs - ) + self.assert_documents_are_equal(document_store.filter_documents(filters={}), docs) diff --git a/integrations/qdrant/tests/test_retriever.py b/integrations/qdrant/tests/test_retriever.py index 7718bae2d..a92f6917f 100644 --- a/integrations/qdrant/tests/test_retriever.py +++ b/integrations/qdrant/tests/test_retriever.py @@ -18,9 +18,7 @@ class TestQdrantRetriever(FilterableDocsFixtureMixin): def test_init_default(self): - document_store = QdrantDocumentStore( - location=":memory:", index="test", use_sparse_embeddings=False - ) + document_store = QdrantDocumentStore(location=":memory:", index="test", use_sparse_embeddings=False) retriever = QdrantEmbeddingRetriever(document_store=document_store) assert retriever._document_store == document_store assert retriever._filters is None @@ -29,20 +27,14 @@ def test_init_default(self): assert retriever._return_embedding is False assert retriever._score_threshold is None - retriever = QdrantEmbeddingRetriever( - document_store=document_store, filter_policy="replace" - ) + retriever = QdrantEmbeddingRetriever(document_store=document_store, filter_policy="replace") assert retriever._filter_policy == FilterPolicy.REPLACE with pytest.raises(ValueError): - QdrantEmbeddingRetriever( - document_store=document_store, filter_policy="invalid" - ) + QdrantEmbeddingRetriever(document_store=document_store, filter_policy="invalid") def test_to_dict(self): - document_store = QdrantDocumentStore( - location=":memory:", index="test", use_sparse_embeddings=False - ) + document_store = QdrantDocumentStore(location=":memory:", index="test", use_sparse_embeddings=False) retriever = 
QdrantEmbeddingRetriever(document_store=document_store) res = retriever.to_dict() assert res == { @@ -124,31 +116,23 @@ def test_from_dict(self): assert retriever._score_threshold is None def test_run(self, filterable_docs: List[Document]): - document_store = QdrantDocumentStore( - location=":memory:", index="Boi", use_sparse_embeddings=False - ) + document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=False) document_store.write_documents(filterable_docs) retriever = QdrantEmbeddingRetriever(document_store=document_store) - results: List[Document] = retriever.run( - query_embedding=_random_embeddings(768) - )["documents"] + results: List[Document] = retriever.run(query_embedding=_random_embeddings(768))["documents"] assert len(results) == 10 - results = retriever.run( - query_embedding=_random_embeddings(768), top_k=5, return_embedding=False - )["documents"] + results = retriever.run(query_embedding=_random_embeddings(768), top_k=5, return_embedding=False)["documents"] assert len(results) == 5 for document in results: assert document.embedding is None def test_run_filters(self, filterable_docs: List[Document]): - document_store = QdrantDocumentStore( - location=":memory:", index="Boi", use_sparse_embeddings=False - ) + document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=False) document_store.write_documents(filterable_docs) @@ -158,9 +142,7 @@ def test_run_filters(self, filterable_docs: List[Document]): filter_policy=FilterPolicy.MERGE, ) - results: List[Document] = retriever.run( - query_embedding=_random_embeddings(768) - )["documents"] + results: List[Document] = retriever.run(query_embedding=_random_embeddings(768))["documents"] assert len(results) == 3 results = retriever.run( @@ -176,11 +158,7 @@ def test_run_filters(self, filterable_docs: List[Document]): def test_run_with_score_threshold(self): document_store = QdrantDocumentStore( - embedding_dim=4, - location=":memory:", - similarity="cosine", - index="Boi", - use_sparse_embeddings=False, + embedding_dim=4, location=":memory:", similarity="cosine", index="Boi", use_sparse_embeddings=False ) document_store.write_documents( @@ -196,31 +174,22 @@ def test_run_with_score_threshold(self): retriever = QdrantEmbeddingRetriever(document_store=document_store) results = retriever.run( - query_embedding=[0.9, 0.9, 0.9, 0.9], - top_k=5, - return_embedding=False, - score_threshold=0.5, + query_embedding=[0.9, 0.9, 0.9, 0.9], top_k=5, return_embedding=False, score_threshold=0.5 )["documents"] assert len(results) == 2 def test_run_with_sparse_activated(self, filterable_docs: List[Document]): - document_store = QdrantDocumentStore( - location=":memory:", index="Boi", use_sparse_embeddings=True - ) + document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=True) document_store.write_documents(filterable_docs) retriever = QdrantEmbeddingRetriever(document_store=document_store) - results: List[Document] = retriever.run( - query_embedding=_random_embeddings(768) - )["documents"] + results: List[Document] = retriever.run(query_embedding=_random_embeddings(768))["documents"] assert len(results) == 10 - results = retriever.run( - query_embedding=_random_embeddings(768), top_k=5, return_embedding=False - )["documents"] + results = retriever.run(query_embedding=_random_embeddings(768), top_k=5, return_embedding=False)["documents"] assert len(results) == 5 @@ -239,15 +208,11 @@ def test_init_default(self): assert retriever._return_embedding 
is False assert retriever._score_threshold is None - retriever = QdrantSparseEmbeddingRetriever( - document_store=document_store, filter_policy="replace" - ) + retriever = QdrantSparseEmbeddingRetriever(document_store=document_store, filter_policy="replace") assert retriever._filter_policy == FilterPolicy.REPLACE with pytest.raises(ValueError): - QdrantSparseEmbeddingRetriever( - document_store=document_store, filter_policy="invalid" - ) + QdrantSparseEmbeddingRetriever(document_store=document_store, filter_policy="invalid") def test_to_dict(self): document_store = QdrantDocumentStore(location=":memory:", index="test") @@ -357,9 +322,7 @@ def test_from_dict_no_filter_policy(self): assert retriever._score_threshold is None def test_run(self, filterable_docs: List[Document], generate_sparse_embedding): - document_store = QdrantDocumentStore( - location=":memory:", index="Boi", use_sparse_embeddings=True - ) + document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=True) # Add fake sparse embedding to documents for doc in filterable_docs: @@ -367,18 +330,12 @@ def test_run(self, filterable_docs: List[Document], generate_sparse_embedding): document_store.write_documents(filterable_docs) retriever = QdrantSparseEmbeddingRetriever(document_store=document_store) - sparse_embedding = SparseEmbedding( - indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33] - ) + sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33]) - results: List[Document] = retriever.run( - query_sparse_embedding=sparse_embedding - )["documents"] + results: List[Document] = retriever.run(query_sparse_embedding=sparse_embedding)["documents"] assert len(results) == 10 - results = retriever.run( - query_sparse_embedding=sparse_embedding, top_k=5, return_embedding=True - )["documents"] + results = retriever.run(query_sparse_embedding=sparse_embedding, top_k=5, return_embedding=True)["documents"] assert len(results) == 5 for document in results: @@ -387,9 +344,7 @@ def test_run(self, filterable_docs: List[Document], generate_sparse_embedding): class TestQdrantHybridRetriever: def test_init_default(self): - document_store = QdrantDocumentStore( - location=":memory:", index="test", use_sparse_embeddings=True - ) + document_store = QdrantDocumentStore(location=":memory:", index="test", use_sparse_embeddings=True) retriever = QdrantHybridRetriever(document_store=document_store) assert retriever._document_store == document_store @@ -399,21 +354,15 @@ def test_init_default(self): assert retriever._return_embedding is False assert retriever._score_threshold is None - retriever = QdrantHybridRetriever( - document_store=document_store, filter_policy="replace" - ) + retriever = QdrantHybridRetriever(document_store=document_store, filter_policy="replace") assert retriever._filter_policy == FilterPolicy.REPLACE with pytest.raises(ValueError): - QdrantHybridRetriever( - document_store=document_store, filter_policy="invalid" - ) + QdrantHybridRetriever(document_store=document_store, filter_policy="invalid") def test_to_dict(self): document_store = QdrantDocumentStore(location=":memory:", index="test") - retriever = QdrantHybridRetriever( - document_store=document_store, top_k=5, return_embedding=True - ) + retriever = QdrantHybridRetriever(document_store=document_store, top_k=5, return_embedding=True) res = retriever.to_dict() assert res == { "type": "haystack_integrations.components.retrievers.qdrant.retriever.QdrantHybridRetriever", @@ -515,21 +464,14 @@ def 
test_from_dict_no_filter_policy(self): def test_run(self): mock_store = Mock(spec=QdrantDocumentStore) - sparse_embedding = SparseEmbedding( - indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33] - ) + sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33]) mock_store._query_hybrid.return_value = [ - Document( - content="Test doc", - embedding=[0.1, 0.2], - sparse_embedding=sparse_embedding, - ) + Document(content="Test doc", embedding=[0.1, 0.2], sparse_embedding=sparse_embedding) ] retriever = QdrantHybridRetriever(document_store=mock_store) res = retriever.run( - query_embedding=[0.5, 0.7], - query_sparse_embedding=SparseEmbedding(indices=[0, 5], values=[0.1, 0.7]), + query_embedding=[0.5, 0.7], query_sparse_embedding=SparseEmbedding(indices=[0, 5], values=[0.1, 0.7]) ) call_args = mock_store._query_hybrid.call_args diff --git a/integrations/ragas/example/example.py b/integrations/ragas/example/example.py index 52cec2305..ba75bdc7e 100644 --- a/integrations/ragas/example/example.py +++ b/integrations/ragas/example/example.py @@ -2,10 +2,7 @@ from haystack import Pipeline -from haystack_integrations.components.evaluators.ragas import ( - RagasEvaluator, - RagasMetric, -) +from haystack_integrations.components.evaluators.ragas import RagasEvaluator, RagasMetric QUESTIONS = [ "Which is the most popular global sport?", @@ -35,11 +32,7 @@ ) evaluator_aspect = RagasEvaluator( metric=RagasMetric.ASPECT_CRITIQUE, - metric_params={ - "name": "custom", - "definition": "Is this answer problematic for children?", - "strictness": 3, - }, + metric_params={"name": "custom", "definition": "Is this answer problematic for children?", "strictness": 3}, ) pipeline.add_component("evaluator_context", evaluator_context) pipeline.add_component("evaluator_aspect", evaluator_aspect) @@ -48,16 +41,8 @@ # Ragas class' documentation for more details. results = pipeline.run( { - "evaluator_context": { - "questions": QUESTIONS, - "contexts": CONTEXTS, - "ground_truths": GROUND_TRUTHS, - }, - "evaluator_aspect": { - "questions": QUESTIONS, - "contexts": CONTEXTS, - "responses": RESPONSES, - }, + "evaluator_context": {"questions": QUESTIONS, "contexts": CONTEXTS, "ground_truths": GROUND_TRUTHS}, + "evaluator_aspect": {"questions": QUESTIONS, "contexts": CONTEXTS, "responses": RESPONSES}, } ) diff --git a/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/evaluator.py b/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/evaluator.py index e9f87c966..c44c446e6 100644 --- a/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/evaluator.py +++ b/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/evaluator.py @@ -64,9 +64,7 @@ def __init__( Refer to the `RagasMetric` class for more details on required parameters. """ - self.metric = ( - metric if isinstance(metric, RagasMetric) else RagasMetric.from_str(metric) - ) + self.metric = metric if isinstance(metric, RagasMetric) else RagasMetric.from_str(metric) self.metric_params = metric_params self.descriptor = METRIC_DESCRIPTORS[self.metric] @@ -84,9 +82,7 @@ def _init_metric(self): if self.metric_params is None: msg = f"Ragas metric '{self.metric}' expected init parameters but got none" raise ValueError(msg) - elif not all( - k in self.descriptor.init_parameters for k in self.metric_params.keys() - ): + elif not all(k in self.descriptor.init_parameters for k in self.metric_params.keys()): msg = ( f"Invalid init parameters for Ragas metric '{self.metric}'. 
" f"Expected: {self.descriptor.init_parameters}" @@ -122,9 +118,7 @@ def run(self, **inputs) -> Dict[str, Any]: - `name` - The name of the metric. - `score` - The score of the metric. """ - InputConverters.validate_input_parameters( - self.metric, self.descriptor.input_parameters, inputs - ) + InputConverters.validate_input_parameters(self.metric, self.descriptor.input_parameters, inputs) converted_inputs: List[Dict[str, str]] = list(self.descriptor.input_converter(**inputs)) # type: ignore dataset = Dataset.from_list(converted_inputs) @@ -132,10 +126,7 @@ def run(self, **inputs) -> Dict[str, Any]: OutputConverters.validate_outputs(results) converted_results = [ - [result.to_dict()] - for result in self.descriptor.output_converter( - results, self.metric, self.metric_params - ) + [result.to_dict()] for result in self.descriptor.output_converter(results, self.metric, self.metric_params) ] return {"results": converted_results} diff --git a/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/metrics.py b/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/metrics.py index d66e2756e..5d6ed16bc 100644 --- a/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/metrics.py +++ b/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/metrics.py @@ -129,9 +129,7 @@ class MetricDescriptor: backend: Type[Metric] input_parameters: Dict[str, Type] input_converter: Callable[[Any], Iterable[Dict[str, str]]] - output_converter: Callable[ - [Result, RagasMetric, Optional[Dict[str, Any]]], List[MetricResult] - ] + output_converter: Callable[[Result, RagasMetric, Optional[Dict[str, Any]]], List[MetricResult]] init_parameters: Optional[List[str]] = None @classmethod @@ -141,9 +139,7 @@ def new( backend: Type[Metric], input_converter: Callable[[Any], Iterable[Dict[str, str]]], output_converter: Optional[ - Callable[ - [Result, RagasMetric, Optional[Dict[str, Any]]], List[MetricResult] - ] + Callable[[Result, RagasMetric, Optional[Dict[str, Any]]], List[MetricResult]] ] = None, *, init_parameters: Optional[List[str]] = None, @@ -153,10 +149,7 @@ def new( for name, param in input_converter_signature.parameters.items(): if name in ("cls", "self"): continue - elif param.kind not in ( - inspect.Parameter.KEYWORD_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - ): + elif param.kind not in (inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD): continue input_parameters[name] = param.annotation @@ -165,9 +158,7 @@ def new( backend=backend, input_parameters=input_parameters, input_converter=input_converter, - output_converter=output_converter - if output_converter is not None - else OutputConverters.default, + output_converter=output_converter if output_converter is not None else OutputConverters.default, init_parameters=init_parameters, ) @@ -190,9 +181,7 @@ def _validate_input_elements(**kwargs) -> None: f"got '{type(collection).__name__}' instead" ) raise ValueError(msg) - elif not all(isinstance(x, str) for x in collection) and not all( - isinstance(x, list) for x in collection - ): + elif not all(isinstance(x, str) for x in collection) and not all(isinstance(x, list) for x in collection): msg = f"Ragas evaluator expects inputs to be of type 'str' or 'list' in '{k}'" raise ValueError(msg) @@ -216,9 +205,7 @@ def validate_input_parameters( def question_context_response( questions: List[str], contexts: List[List[str]], responses: List[str] ) -> Iterable[Dict[str, Union[str, List[str]]]]: - 
InputConverters._validate_input_elements( - questions=questions, contexts=contexts, responses=responses - ) + InputConverters._validate_input_elements(questions=questions, contexts=contexts, responses=responses) for q, c, r in zip(questions, contexts, responses): # type: ignore yield {"question": q, "contexts": c, "answer": r} @@ -228,9 +215,7 @@ def question_context_ground_truth( contexts: List[List[str]], ground_truths: List[str], ) -> Iterable[Dict[str, Union[str, List[str]]]]: - InputConverters._validate_input_elements( - questions=questions, contexts=contexts, ground_truths=ground_truths - ) + InputConverters._validate_input_elements(questions=questions, contexts=contexts, ground_truths=ground_truths) for q, c, gt in zip(questions, contexts, ground_truths): # type: ignore yield {"question": q, "contexts": c, "ground_truth": gt} @@ -248,9 +233,7 @@ def response_ground_truth( responses: List[str], ground_truths: List[str], ) -> Iterable[Dict[str, str]]: - InputConverters._validate_input_elements( - responses=responses, ground_truths=ground_truths - ) + InputConverters._validate_input_elements(responses=responses, ground_truths=ground_truths) for r, gt in zip(responses, ground_truths): # type: ignore yield {"answer": r, "ground_truth": gt} @@ -260,9 +243,7 @@ def question_response_ground_truth( responses: List[str], ground_truths: List[str], ) -> Iterable[Dict[str, str]]: - InputConverters._validate_input_elements( - questions=questions, ground_truths=ground_truths, responses=responses - ) + InputConverters._validate_input_elements(questions=questions, ground_truths=ground_truths, responses=responses) for q, r, gt in zip(questions, responses, ground_truths): # type: ignore yield {"question": q, "answer": r, "ground_truth": gt} @@ -281,30 +262,21 @@ def validate_outputs(outputs: Result) -> None: raise ValueError(msg) @staticmethod - def _extract_default_results( - output: Result, metric_name: str - ) -> List[MetricResult]: + def _extract_default_results(output: Result, metric_name: str) -> List[MetricResult]: try: output_scores: List[Dict[str, float]] = output.scores.to_list() - return [ - MetricResult(name=metric_name, score=metric_dict[metric_name]) - for metric_dict in output_scores - ] + return [MetricResult(name=metric_name, score=metric_dict[metric_name]) for metric_dict in output_scores] except KeyError as e: msg = f"Ragas evaluator did not return an expected output for metric '{e.args[0]}'" raise ValueError(msg) from e @staticmethod - def default( - output: Result, metric: RagasMetric, _: Optional[Dict] - ) -> List[MetricResult]: + def default(output: Result, metric: RagasMetric, _: Optional[Dict]) -> List[MetricResult]: metric_name = metric.value return OutputConverters._extract_default_results(output, metric_name) @staticmethod - def aspect_critique( - output: Result, _: RagasMetric, metric_params: Optional[Dict[str, Any]] - ) -> List[MetricResult]: + def aspect_critique(output: Result, _: RagasMetric, metric_params: Optional[Dict[str, Any]]) -> List[MetricResult]: if metric_params is None: msg = "Aspect critique metric requires metric parameters" raise ValueError(msg) diff --git a/integrations/ragas/tests/test_evaluator.py b/integrations/ragas/tests/test_evaluator.py index a72641d59..fc8901c32 100644 --- a/integrations/ragas/tests/test_evaluator.py +++ b/integrations/ragas/tests/test_evaluator.py @@ -5,10 +5,7 @@ import pytest from datasets import Dataset from haystack import DeserializationError -from haystack_integrations.components.evaluators.ragas import ( - RagasEvaluator, 
- RagasMetric, -) +from haystack_integrations.components.evaluators.ragas import RagasEvaluator, RagasMetric from ragas.evaluation import Result from ragas.metrics.base import Metric @@ -47,30 +44,14 @@ def __init__(self, metric: RagasMetric) -> None: def evaluate(self, _, metric: Metric, **kwargs): output_map = { - RagasMetric.ANSWER_CORRECTNESS: Result( - scores=Dataset.from_list([{"answer_correctness": 0.5}]) - ), - RagasMetric.FAITHFULNESS: Result( - scores=Dataset.from_list([{"faithfulness": 1.0}]) - ), - RagasMetric.ANSWER_SIMILARITY: Result( - scores=Dataset.from_list([{"answer_similarity": 1.0}]) - ), - RagasMetric.CONTEXT_PRECISION: Result( - scores=Dataset.from_list([{"context_precision": 0.5}]) - ), - RagasMetric.CONTEXT_UTILIZATION: Result( - scores=Dataset.from_list([{"context_utilization": 1.0}]) - ), - RagasMetric.CONTEXT_RECALL: Result( - scores=Dataset.from_list([{"context_recall": 0.9}]) - ), - RagasMetric.ASPECT_CRITIQUE: Result( - scores=Dataset.from_list([{"harmfulness": 1.0}]) - ), - RagasMetric.ANSWER_RELEVANCY: Result( - scores=Dataset.from_list([{"answer_relevancy": 0.4}]) - ), + RagasMetric.ANSWER_CORRECTNESS: Result(scores=Dataset.from_list([{"answer_correctness": 0.5}])), + RagasMetric.FAITHFULNESS: Result(scores=Dataset.from_list([{"faithfulness": 1.0}])), + RagasMetric.ANSWER_SIMILARITY: Result(scores=Dataset.from_list([{"answer_similarity": 1.0}])), + RagasMetric.CONTEXT_PRECISION: Result(scores=Dataset.from_list([{"context_precision": 0.5}])), + RagasMetric.CONTEXT_UTILIZATION: Result(scores=Dataset.from_list([{"context_utilization": 1.0}])), + RagasMetric.CONTEXT_RECALL: Result(scores=Dataset.from_list([{"context_recall": 0.9}])), + RagasMetric.ASPECT_CRITIQUE: Result(scores=Dataset.from_list([{"harmfulness": 1.0}])), + RagasMetric.ANSWER_RELEVANCY: Result(scores=Dataset.from_list([{"answer_relevancy": 0.4}])), } assert isinstance(metric, Metric) return output_map[self.metric] @@ -147,9 +128,7 @@ def test_evaluator_serde(): assert eval.metric == new_eval.metric assert eval.metric_params == new_eval.metric_params - with pytest.raises( - DeserializationError, match=r"cannot serialize the metric parameters" - ): + with pytest.raises(DeserializationError, match=r"cannot serialize the metric parameters"): init_params3 = copy.deepcopy(init_params) init_params3["metric_params"]["name"] = Unserializable("") eval = RagasEvaluator(**init_params3) @@ -164,31 +143,11 @@ def test_evaluator_serde(): {"questions": [], "responses": [], "ground_truths": []}, {"weights": [0.5, 0.5]}, ), - ( - RagasMetric.FAITHFULNESS, - {"questions": [], "contexts": [], "responses": []}, - None, - ), - ( - RagasMetric.ANSWER_SIMILARITY, - {"responses": [], "ground_truths": []}, - {"threshold": 0.5}, - ), - ( - RagasMetric.CONTEXT_PRECISION, - {"questions": [], "contexts": [], "ground_truths": []}, - None, - ), - ( - RagasMetric.CONTEXT_UTILIZATION, - {"questions": [], "contexts": [], "responses": []}, - None, - ), - ( - RagasMetric.CONTEXT_RECALL, - {"questions": [], "contexts": [], "ground_truths": []}, - None, - ), + (RagasMetric.FAITHFULNESS, {"questions": [], "contexts": [], "responses": []}, None), + (RagasMetric.ANSWER_SIMILARITY, {"responses": [], "ground_truths": []}, {"threshold": 0.5}), + (RagasMetric.CONTEXT_PRECISION, {"questions": [], "contexts": [], "ground_truths": []}, None), + (RagasMetric.CONTEXT_UTILIZATION, {"questions": [], "contexts": [], "responses": []}, None), + (RagasMetric.CONTEXT_RECALL, {"questions": [], "contexts": [], "ground_truths": []}, None), ( 
            RagasMetric.ASPECT_CRITIQUE,
            {"questions": [], "contexts": [], "responses": []},
@@ -199,11 +158,7 @@ def test_evaluator_serde():
                "large?",
            },
        ),
-        (
-            RagasMetric.ANSWER_RELEVANCY,
-            {"questions": [], "contexts": [], "responses": []},
-            {"strictness": 2},
-        ),
+        (RagasMetric.ANSWER_RELEVANCY, {"questions": [], "contexts": [], "responses": []}, {"strictness": 2}),
     ],
 )
 def test_evaluator_valid_inputs(current_metric, inputs, params):
@@ -212,9 +167,7 @@ def test_evaluator_valid_inputs(current_metric, inputs, params):
         "metric_params": params,
     }
     eval = RagasEvaluator(**init_params)
-    eval._backend_callable = lambda dataset, metric: MockBackend(
-        current_metric
-    ).evaluate(dataset, metric)
+    eval._backend_callable = lambda dataset, metric: MockBackend(current_metric).evaluate(dataset, metric)

     output = eval.run(**inputs)
@@ -233,12 +186,7 @@ def test_evaluator_valid_inputs(current_metric, inputs, params):
             "Mismatching counts ",
             {"strictness": 2},
         ),
-        (
-            RagasMetric.ANSWER_RELEVANCY,
-            {"responses": []},
-            "expected input parameter ",
-            {"strictness": 2},
-        ),
+        (RagasMetric.ANSWER_RELEVANCY, {"responses": []}, "expected input parameter ", {"strictness": 2}),
     ],
 )
 def test_evaluator_invalid_inputs(current_metric, inputs, error_string, params):
@@ -248,9 +196,7 @@ def test_evaluator_invalid_inputs(current_metric, inputs, error_string, params):
         "metric_params": params,
     }
     eval = RagasEvaluator(**init_params)
-    eval._backend_callable = lambda dataset, metric: MockBackend(
-        current_metric
-    ).evaluate(dataset, metric)
+    eval._backend_callable = lambda dataset, metric: MockBackend(current_metric).evaluate(dataset, metric)

     output = eval.run(**inputs)
@@ -320,9 +266,7 @@ def test_evaluator_outputs(current_metric, inputs, expected_outputs, metric_para
         "metric_params": metric_params,
     }
     eval = RagasEvaluator(**init_params)
-    eval._backend_callable = lambda dataset, metric: MockBackend(
-        current_metric
-    ).evaluate(dataset, metric)
+    eval._backend_callable = lambda dataset, metric: MockBackend(current_metric).evaluate(dataset, metric)

     results = eval.run(**inputs)["results"]
     assert type(results) == type(expected_outputs)
@@ -331,10 +275,7 @@ def test_evaluator_outputs(current_metric, inputs, expected_outputs, metric_para
     for r, o in zip(results, expected_outputs):
         assert len(r) == len(o)

-        expected = {
-            (name if name is not None else str(current_metric), score)
-            for name, score in o
-        }
+        expected = {(name if name is not None else str(current_metric), score) for name, score in o}
         got = {(x["name"], x["score"]) for x in r}
         assert got == expected
@@ -349,20 +290,12 @@ def test_evaluator_outputs(current_metric, inputs, expected_outputs, metric_para
     [
         (
             RagasMetric.ANSWER_CORRECTNESS,
-            {
-                "questions": DEFAULT_QUESTIONS,
-                "responses": DEFAULT_RESPONSES,
-                "ground_truths": DEFAULT_GROUND_TRUTHS,
-            },
+            {"questions": DEFAULT_QUESTIONS, "responses": DEFAULT_RESPONSES, "ground_truths": DEFAULT_GROUND_TRUTHS},
             {"weights": [0.5, 0.5]},
         ),
         (
             RagasMetric.FAITHFULNESS,
-            {
-                "questions": DEFAULT_QUESTIONS,
-                "contexts": DEFAULT_CONTEXTS,
-                "responses": DEFAULT_RESPONSES,
-            },
+            {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES},
             None,
         ),
         (
@@ -372,38 +305,22 @@ def test_evaluator_outputs(current_metric, inputs, expected_outputs, metric_para
         ),
         (
             RagasMetric.CONTEXT_PRECISION,
-            {
-                "questions": DEFAULT_QUESTIONS,
-                "contexts": DEFAULT_CONTEXTS,
-                "ground_truths": DEFAULT_GROUND_TRUTHS,
-            },
+            {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "ground_truths": DEFAULT_GROUND_TRUTHS},
             None,
         ),
         (
             RagasMetric.CONTEXT_UTILIZATION,
-            {
-                "questions": DEFAULT_QUESTIONS,
-                "contexts": DEFAULT_CONTEXTS,
-                "responses": DEFAULT_RESPONSES,
-            },
+            {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES},
             None,
         ),
         (
             RagasMetric.CONTEXT_RECALL,
-            {
-                "questions": DEFAULT_QUESTIONS,
-                "contexts": DEFAULT_CONTEXTS,
-                "ground_truths": DEFAULT_GROUND_TRUTHS,
-            },
+            {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "ground_truths": DEFAULT_GROUND_TRUTHS},
             None,
         ),
         (
             RagasMetric.ASPECT_CRITIQUE,
-            {
-                "questions": DEFAULT_QUESTIONS,
-                "contexts": DEFAULT_CONTEXTS,
-                "responses": DEFAULT_RESPONSES,
-            },
+            {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES},
             {
                 "name": "harmfulness",
                 "definition": "Does the submission causes or has the potential to "
@@ -413,11 +330,7 @@ def test_evaluator_outputs(current_metric, inputs, expected_outputs, metric_para
         ),
         (
             RagasMetric.ANSWER_RELEVANCY,
-            {
-                "questions": DEFAULT_QUESTIONS,
-                "contexts": DEFAULT_CONTEXTS,
-                "responses": DEFAULT_RESPONSES,
-            },
+            {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES},
             {"strictness": 2},
         ),
     ],
diff --git a/integrations/unstructured/src/haystack_integrations/components/converters/unstructured/converter.py b/integrations/unstructured/src/haystack_integrations/components/converters/unstructured/converter.py
index 88a27ef10..637c0840f 100644
--- a/integrations/unstructured/src/haystack_integrations/components/converters/unstructured/converter.py
+++ b/integrations/unstructured/src/haystack_integrations/components/converters/unstructured/converter.py
@@ -46,9 +46,7 @@ class UnstructuredFileConverter:
     def __init__(
         self,
         api_url: str = UNSTRUCTURED_HOSTED_API_URL,
-        api_key: Optional[Secret] = Secret.from_env_var(
-            "UNSTRUCTURED_API_KEY", strict=False
-        ),  # noqa: B008
+        api_key: Optional[Secret] = Secret.from_env_var("UNSTRUCTURED_API_KEY", strict=False),  # noqa: B008
         document_creation_mode: Literal[
             "one-doc-per-file", "one-doc-per-page", "one-doc-per-element"
         ] = "one-doc-per-file",
@@ -148,11 +146,7 @@ def run(
         paths_obj = [Path(path) for path in paths]
         filepaths = [path for path in paths_obj if path.is_file()]
         filepaths_in_directories = [
-            filepath
-            for path in paths_obj
-            if path.is_dir()
-            for filepath in path.glob("*.*")
-            if filepath.is_file()
+            filepath for path in paths_obj if path.is_dir() for filepath in path.glob("*.*") if filepath.is_file()
         ]
         if filepaths_in_directories and isinstance(meta, list):
             error = """"If providing directories in the `paths` parameter,
@@ -167,9 +161,7 @@ def run(
         meta_list = normalize_metadata(meta, sources_count=len(all_filepaths))

         for filepath, metadata in tqdm(
-            zip(all_filepaths, meta_list),
-            desc="Converting files to Haystack Documents",
-            disable=not self.progress_bar,
+            zip(all_filepaths, meta_list), desc="Converting files to Haystack Documents", disable=not self.progress_bar
         ):
             elements = self._partition_file_into_elements(filepath=filepath)
             docs_for_file = self._create_documents(
@@ -186,9 +178,7 @@ def _create_documents(
         filepath: Path,
         elements: List[Element],
-        document_creation_mode: Literal[
-            "one-doc-per-file", "one-doc-per-page", "one-doc-per-element"
-        ],
+        document_creation_mode: Literal["one-doc-per-file", "one-doc-per-page", "one-doc-per-element"],
         separator: str,
         meta: Dict[str, Any],
     ) -> List[Document]:
@@ -216,10 +206,7 @@ def _create_documents(
                 texts_per_page[page_number] += str(el) + separator
                 meta_per_page[page_number].update(metadata)

-            docs = [
-                Document(content=texts_per_page[page], meta=meta_per_page[page])
-                for page in texts_per_page.keys()
-            ]
+            docs = [Document(content=texts_per_page[page], meta=meta_per_page[page]) for page in texts_per_page.keys()]

         elif document_creation_mode == "one-doc-per-element":
             for index, el in enumerate(elements):
@@ -247,7 +234,5 @@ def _partition_file_into_elements(self, filepath: Path) -> List[Element]:
                 **self.unstructured_kwargs,
             )
         except Exception as e:
-            logger.warning(
-                f"Unstructured could not process file {filepath}. Error: {e}"
-            )
+            logger.warning(f"Unstructured could not process file {filepath}. Error: {e}")
         return elements
diff --git a/integrations/unstructured/tests/test_converter.py b/integrations/unstructured/tests/test_converter.py
index 15d551570..5d1a6c091 100644
--- a/integrations/unstructured/tests/test_converter.py
+++ b/integrations/unstructured/tests/test_converter.py
@@ -2,9 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 import pytest
-from haystack_integrations.components.converters.unstructured import (
-    UnstructuredFileConverter,
-)
+from haystack_integrations.components.converters.unstructured import UnstructuredFileConverter


 class TestUnstructuredFileConverter:
@@ -35,9 +33,7 @@ def test_init_with_parameters(self):

     def test_init_hosted_without_api_key_raises_error(self):
         with pytest.raises(ValueError):
-            UnstructuredFileConverter(
-                api_url="https://api.unstructured.io/general/v0/general"
-            )
+            UnstructuredFileConverter(api_url="https://api.unstructured.io/general/v0/general")

     @pytest.mark.usefixtures("set_env_variables")
     def test_to_dict(self):
@@ -48,11 +44,7 @@ def test_to_dict(self):
             "type": "haystack_integrations.components.converters.unstructured.converter.UnstructuredFileConverter",
             "init_parameters": {
                 "api_url": "https://api.unstructured.io/general/v0/general",
-                "api_key": {
-                    "env_vars": ["UNSTRUCTURED_API_KEY"],
-                    "strict": False,
-                    "type": "env_var",
-                },
+                "api_key": {"env_vars": ["UNSTRUCTURED_API_KEY"], "strict": False, "type": "env_var"},
                 "document_creation_mode": "one-doc-per-file",
                 "separator": "\n\n",
                 "unstructured_kwargs": {},
@@ -66,11 +58,7 @@ def test_from_dict(self, monkeypatch):
             "type": "haystack_integrations.components.converters.unstructured.converter.UnstructuredFileConverter",
             "init_parameters": {
                 "api_url": "http://custom-url:8000/general",
-                "api_key": {
-                    "env_vars": ["UNSTRUCTURED_API_KEY"],
-                    "strict": False,
-                    "type": "env_var",
-                },
+                "api_key": {"env_vars": ["UNSTRUCTURED_API_KEY"], "strict": False, "type": "env_var"},
                 "document_creation_mode": "one-doc-per-element",
                 "separator": "|",
                 "unstructured_kwargs": {"foo": "bar"},
@@ -90,8 +78,7 @@ def test_run_one_doc_per_file(self, samples_path):
         pdf_path = samples_path / "sample_pdf.pdf"

         local_converter = UnstructuredFileConverter(
-            api_url="http://localhost:8000/general/v0/general",
-            document_creation_mode="one-doc-per-file",
+            api_url="http://localhost:8000/general/v0/general", document_creation_mode="one-doc-per-file"
         )

         documents = local_converter.run([pdf_path])["documents"]
@@ -104,8 +91,7 @@ def test_run_one_doc_per_page(self, samples_path):
         pdf_path = samples_path / "sample_pdf.pdf"

         local_converter = UnstructuredFileConverter(
-            api_url="http://localhost:8000/general/v0/general",
-            document_creation_mode="one-doc-per-page",
+            api_url="http://localhost:8000/general/v0/general", document_creation_mode="one-doc-per-page"
         )

         documents = local_converter.run([pdf_path])["documents"]
@@ -120,8 +106,7 @@ def test_run_one_doc_per_element(self, samples_path):
         pdf_path = samples_path / "sample_pdf.pdf"

         local_converter = UnstructuredFileConverter(
-            api_url="http://localhost:8000/general/v0/general",
-            document_creation_mode="one-doc-per-element",
+            api_url="http://localhost:8000/general/v0/general", document_creation_mode="one-doc-per-element"
         )

         documents = local_converter.run([pdf_path])["documents"]
@@ -139,8 +124,7 @@ def test_run_one_doc_per_file_with_meta(self, samples_path):
         pdf_path = samples_path / "sample_pdf.pdf"
         meta = {"custom_meta": "foobar"}
         local_converter = UnstructuredFileConverter(
-            api_url="http://localhost:8000/general/v0/general",
-            document_creation_mode="one-doc-per-file",
+            api_url="http://localhost:8000/general/v0/general", document_creation_mode="one-doc-per-file"
         )

         documents = local_converter.run(paths=[pdf_path], meta=meta)["documents"]
@@ -149,18 +133,14 @@ def test_run_one_doc_per_file_with_meta(self, samples_path):
         assert documents[0].meta["file_path"] == str(pdf_path)
         assert "custom_meta" in documents[0].meta
         assert documents[0].meta["custom_meta"] == "foobar"
-        assert documents[0].meta == {
-            "file_path": str(pdf_path),
-            "custom_meta": "foobar",
-        }
+        assert documents[0].meta == {"file_path": str(pdf_path), "custom_meta": "foobar"}

     @pytest.mark.integration
     def test_run_one_doc_per_page_with_meta(self, samples_path):
         pdf_path = samples_path / "sample_pdf.pdf"
         meta = {"custom_meta": "foobar"}
         local_converter = UnstructuredFileConverter(
-            api_url="http://localhost:8000/general/v0/general",
-            document_creation_mode="one-doc-per-page",
+            api_url="http://localhost:8000/general/v0/general", document_creation_mode="one-doc-per-page"
         )

         documents = local_converter.run(paths=[pdf_path], meta=meta)["documents"]
@@ -176,8 +156,7 @@ def test_run_one_doc_per_element_with_meta(self, samples_path):
         pdf_path = samples_path / "sample_pdf.pdf"
         meta = {"custom_meta": "foobar"}
         local_converter = UnstructuredFileConverter(
-            api_url="http://localhost:8000/general/v0/general",
-            document_creation_mode="one-doc-per-element",
+            api_url="http://localhost:8000/general/v0/general", document_creation_mode="one-doc-per-element"
         )

         documents = local_converter.run(paths=[pdf_path], meta=meta)["documents"]
@@ -203,8 +182,7 @@ def test_run_one_doc_per_element_with_meta_list_two_files(self, samples_path):
             {"custom_meta": "sample_pdf2.pdf", "common_meta": "common"},
         ]
         local_converter = UnstructuredFileConverter(
-            api_url="http://localhost:8000/general/v0/general",
-            document_creation_mode="one-doc-per-element",
+            api_url="http://localhost:8000/general/v0/general", document_creation_mode="one-doc-per-element"
         )

         documents = local_converter.run(paths=pdf_path, meta=meta)["documents"]
@@ -222,13 +200,9 @@ def test_run_one_doc_per_element_with_meta_list_two_files(self, samples_path):
     @pytest.mark.integration
     def test_run_one_doc_per_element_with_meta_list_folder_fail(self, samples_path):
         pdf_path = [samples_path]
-        meta = [
-            {"custom_meta": "foobar", "common_meta": "common"},
-            {"other_meta": "barfoo", "common_meta": "common"},
-        ]
+        meta = [{"custom_meta": "foobar", "common_meta": "common"}, {"other_meta": "barfoo", "common_meta": "common"}]
         local_converter = UnstructuredFileConverter(
-            api_url="http://localhost:8000/general/v0/general",
-            document_creation_mode="one-doc-per-element",
+            api_url="http://localhost:8000/general/v0/general", document_creation_mode="one-doc-per-element"
         )
         with pytest.raises(ValueError):
            local_converter.run(paths=pdf_path, meta=meta)["documents"]
@@ -239,8 +213,7 @@ def test_run_one_doc_per_element_with_meta_list_folder(self, samples_path):
         meta = {"common_meta": "common"}

         local_converter = UnstructuredFileConverter(
-            api_url="http://localhost:8000/general/v0/general",
-            document_creation_mode="one-doc-per-element",
+            api_url="http://localhost:8000/general/v0/general", document_creation_mode="one-doc-per-element"
         )

         documents = local_converter.run(paths=pdf_path, meta=meta)["documents"]
diff --git a/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/bm25_retriever.py b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/bm25_retriever.py
index e9f804c97..fec0b81e6 100644
--- a/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/bm25_retriever.py
+++ b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/bm25_retriever.py
@@ -49,9 +49,7 @@ def __init__(
         self._filters = filters or {}
         self._top_k = top_k
         self._filter_policy = (
-            filter_policy
-            if isinstance(filter_policy, FilterPolicy)
-            else FilterPolicy.from_str(filter_policy)
+            filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
         )

     def to_dict(self) -> Dict[str, Any]:
@@ -85,19 +83,12 @@ def from_dict(cls, data: Dict[str, Any]) -> "WeaviateBM25Retriever":
         # Pipelines serialized with old versions of the component might not
         # have the filter_policy field.
         if filter_policy := data["init_parameters"].get("filter_policy"):
-            data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(
-                filter_policy
-            )
+            data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy)
         return default_from_dict(cls, data)

     @component.output_types(documents=List[Document])
-    def run(
-        self,
-        query: str,
-        filters: Optional[Dict[str, Any]] = None,
-        top_k: Optional[int] = None,
-    ):
+    def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
         """
         Retrieves documents from Weaviate using the BM25 algorithm.
@@ -112,7 +103,5 @@ def run(
         filters = apply_filter_policy(self._filter_policy, self._filters, filters)
         top_k = top_k or self._top_k
-        documents = self._document_store._bm25_retrieval(
-            query=query, filters=filters, top_k=top_k
-        )
+        documents = self._document_store._bm25_retrieval(query=query, filters=filters, top_k=top_k)
         return {"documents": documents}
diff --git a/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/embedding_retriever.py b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/embedding_retriever.py
index 14617126f..8688b4145 100644
--- a/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/embedding_retriever.py
+++ b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/embedding_retriever.py
@@ -56,9 +56,7 @@ def __init__(
         self._distance = distance
         self._certainty = certainty
         self._filter_policy = (
-            filter_policy
-            if isinstance(filter_policy, FilterPolicy)
-            else FilterPolicy.from_str(filter_policy)
+            filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
         )

     def to_dict(self) -> Dict[str, Any]:
@@ -95,9 +93,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "WeaviateEmbeddingRetriever":
         # Pipelines serialized with old versions of the component might not
         # have the filter_policy field.
         if filter_policy := data["init_parameters"].get("filter_policy"):
-            data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(
-                filter_policy
-            )
+            data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy)
         return default_from_dict(cls, data)
diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/__init__.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/__init__.py
index 1338951b7..87c7b6b01 100644
--- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/__init__.py
+++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/__init__.py
@@ -1,13 +1,7 @@
 # SPDX-FileCopyrightText: 2023-present deepset GmbH
 #
 # SPDX-License-Identifier: Apache-2.0
-from .auth import (
-    AuthApiKey,
-    AuthBearerToken,
-    AuthClientCredentials,
-    AuthClientPassword,
-    AuthCredentials,
-)
+from .auth import AuthApiKey, AuthBearerToken, AuthClientCredentials, AuthClientPassword, AuthCredentials
 from .document_store import WeaviateDocumentStore

 __all__ = [
diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/_filters.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/_filters.py
index a3480937b..803274aa4 100644
--- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/_filters.py
+++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/_filters.py
@@ -132,9 +132,7 @@ def _greater_than(field: str, value: Any) -> FilterReturn:
     if type(value) in [list, DataFrame]:
         msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
         raise FilterError(msg)
-    return weaviate.classes.query.Filter.by_property(field).greater_than(
-        _handle_date(value)
-    )
+    return weaviate.classes.query.Filter.by_property(field).greater_than(_handle_date(value))


 def _greater_than_equal(field: str, value: Any) -> FilterReturn:
@@ -156,9 +154,7 @@ def _greater_than_equal(field: str, value: Any) -> FilterReturn:
     if type(value) in [list, DataFrame]:
         msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
         raise FilterError(msg)
-    return weaviate.classes.query.Filter.by_property(field).greater_or_equal(
-        _handle_date(value)
-    )
+    return weaviate.classes.query.Filter.by_property(field).greater_or_equal(_handle_date(value))


 def _less_than(field: str, value: Any) -> FilterReturn:
@@ -180,9 +176,7 @@ def _less_than(field: str, value: Any) -> FilterReturn:
     if type(value) in [list, DataFrame]:
         msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
         raise FilterError(msg)
-    return weaviate.classes.query.Filter.by_property(field).less_than(
-        _handle_date(value)
-    )
+    return weaviate.classes.query.Filter.by_property(field).less_than(_handle_date(value))


 def _less_than_equal(field: str, value: Any) -> FilterReturn:
@@ -204,9 +198,7 @@ def _less_than_equal(field: str, value: Any) -> FilterReturn:
     if type(value) in [list, DataFrame]:
         msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
         raise FilterError(msg)
-    return weaviate.classes.query.Filter.by_property(field).less_or_equal(
-        _handle_date(value)
-    )
+    return weaviate.classes.query.Filter.by_property(field).less_or_equal(_handle_date(value))


 def _in(field: str, value: Any) -> FilterReturn:
@@ -221,9 +213,7 @@ def _not_in(field: str, value: Any) -> FilterReturn:
     if not isinstance(value, list):
         msg = f"{field}'s value must be a list when using 'in' or 'not in' comparators"
         raise FilterError(msg)
-    operands = [
-        weaviate.classes.query.Filter.by_property(field).not_equal(v) for v in value
-    ]
+    operands = [weaviate.classes.query.Filter.by_property(field).not_equal(v) for v in value]
     return Filter.all_of(operands)
@@ -270,8 +260,5 @@ def _match_no_document(field: str) -> FilterReturn:
     between different Document Stores.
     """
-    operands = [
-        weaviate.classes.query.Filter.by_property(field).is_none(val)
-        for val in [False, True]
-    ]
+    operands = [weaviate.classes.query.Filter.by_property(field).is_none(val) for val in [False, True]]
     return Filter.all_of(operands)
diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py
index 702cedea7..33bc30159 100644
--- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py
+++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py
@@ -58,10 +58,7 @@ def to_dict(self) -> Dict[str, Any]:
             else:
                 _fields[_field.name] = getattr(self, _field.name)

-        return {
-            "type": str(SupportedAuthTypes.from_class(self.__class__)),
-            "init_parameters": _fields,
-        }
+        return {"type": str(SupportedAuthTypes.from_class(self.__class__)), "init_parameters": _fields}

     @staticmethod
     def from_dict(data: Dict[str, Any]) -> "AuthCredentials":
@@ -104,9 +101,7 @@ class AuthApiKey(AuthCredentials):
     By default it will load `api_key` from the environment variable `WEAVIATE_API_KEY`.
     """

-    api_key: Secret = field(
-        default_factory=lambda: Secret.from_env_var(["WEAVIATE_API_KEY"])
-    )
+    api_key: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_API_KEY"]))

     @classmethod
     def _from_dict(cls, data: Dict[str, Any]) -> "AuthApiKey":
@@ -127,21 +122,13 @@ class AuthBearerToken(AuthCredentials):
     `WEAVIATE_REFRESH_TOKEN` environment variable is optional.
     """

-    access_token: Secret = field(
-        default_factory=lambda: Secret.from_env_var(["WEAVIATE_ACCESS_TOKEN"])
-    )
+    access_token: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_ACCESS_TOKEN"]))
     expires_in: int = field(default=60)
-    refresh_token: Secret = field(
-        default_factory=lambda: Secret.from_env_var(
-            ["WEAVIATE_REFRESH_TOKEN"], strict=False
-        )
-    )
+    refresh_token: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_REFRESH_TOKEN"], strict=False))

     @classmethod
     def _from_dict(cls, data: Dict[str, Any]) -> "AuthBearerToken":
-        deserialize_secrets_inplace(
-            data["init_parameters"], ["access_token", "refresh_token"]
-        )
+        deserialize_secrets_inplace(data["init_parameters"], ["access_token", "refresh_token"])
         return cls(**data["init_parameters"])

     def resolve_value(self) -> WeaviateAuthBearerToken:
@@ -165,12 +152,8 @@ class AuthClientCredentials(AuthCredentials):
     separated strings. e.g. "scope1" or "scope1 scope2".
     """

-    client_secret: Secret = field(
-        default_factory=lambda: Secret.from_env_var(["WEAVIATE_CLIENT_SECRET"])
-    )
-    scope: Secret = field(
-        default_factory=lambda: Secret.from_env_var(["WEAVIATE_SCOPE"], strict=False)
-    )
+    client_secret: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_CLIENT_SECRET"]))
+    scope: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_SCOPE"], strict=False))

     @classmethod
     def _from_dict(cls, data: Dict[str, Any]) -> "AuthClientCredentials":
@@ -195,21 +178,13 @@ class AuthClientPassword(AuthCredentials):
     separated strings. e.g. "scope1" or "scope1 scope2".
""" - username: Secret = field( - default_factory=lambda: Secret.from_env_var(["WEAVIATE_USERNAME"]) - ) - password: Secret = field( - default_factory=lambda: Secret.from_env_var(["WEAVIATE_PASSWORD"]) - ) - scope: Secret = field( - default_factory=lambda: Secret.from_env_var(["WEAVIATE_SCOPE"], strict=False) - ) + username: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_USERNAME"])) + password: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_PASSWORD"])) + scope: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_SCOPE"], strict=False)) @classmethod def _from_dict(cls, data: Dict[str, Any]) -> "AuthClientPassword": - deserialize_secrets_inplace( - data["init_parameters"], ["username", "password", "scope"] - ) + deserialize_secrets_inplace(data["init_parameters"], ["username", "password", "scope"]) return cls(**data["init_parameters"]) def resolve_value(self) -> WeaviateAuthClientPassword: diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py index 767d697ba..82088dd89 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py @@ -172,19 +172,13 @@ def client(self): if self._client: return self._client - if ( - self._url - and self._url.startswith("http") - and self._url.endswith(".weaviate.network") - ): + if self._url and self._url.startswith("http") and self._url.endswith(".weaviate.network"): # We use this utility function instead of using WeaviateClient directly like in other cases # otherwise we'd have to parse the URL to get some information about the connection. # This utility function does all that for us. self._client = weaviate.connect_to_wcs( self._url, - auth_credentials=self._auth_client_secret.resolve_value() - if self._auth_client_secret - else None, + auth_credentials=self._auth_client_secret.resolve_value() if self._auth_client_secret else None, headers=self._additional_headers, additional_config=self._additional_config, ) @@ -194,16 +188,12 @@ def client(self): self._client = weaviate.WeaviateClient( connection_params=( weaviate.connect.base.ConnectionParams.from_url( - url=self._url, - grpc_port=self._grpc_port, - grpc_secure=self._grpc_secure, + url=self._url, grpc_port=self._grpc_port, grpc_secure=self._grpc_secure ) if self._url else None ), - auth_client_secret=self._auth_client_secret.resolve_value() - if self._auth_client_secret - else None, + auth_client_secret=self._auth_client_secret.resolve_value() if self._auth_client_secret else None, additional_config=self._additional_config, additional_headers=self._additional_headers, embedded_options=self._embedded_options, @@ -235,22 +225,16 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. 
""" - embedded_options = ( - asdict(self._embedded_options) if self._embedded_options else None - ) + embedded_options = asdict(self._embedded_options) if self._embedded_options else None additional_config = ( - json.loads(self._additional_config.model_dump_json(by_alias=True)) - if self._additional_config - else None + json.loads(self._additional_config.model_dump_json(by_alias=True)) if self._additional_config else None ) return default_to_dict( self, url=self._url, collection_settings=self._collection_settings, - auth_client_secret=self._auth_client_secret.to_dict() - if self._auth_client_secret - else None, + auth_client_secret=self._auth_client_secret.to_dict() if self._auth_client_secret else None, additional_headers=self._additional_headers, embedded_options=embedded_options, additional_config=additional_config, @@ -266,24 +250,12 @@ def from_dict(cls, data: Dict[str, Any]) -> "WeaviateDocumentStore": :returns: The deserialized component. """ - if ( - auth_client_secret := data["init_parameters"].get("auth_client_secret") - ) is not None: - data["init_parameters"]["auth_client_secret"] = AuthCredentials.from_dict( - auth_client_secret - ) - if ( - embedded_options := data["init_parameters"].get("embedded_options") - ) is not None: - data["init_parameters"]["embedded_options"] = EmbeddedOptions( - **embedded_options - ) - if ( - additional_config := data["init_parameters"].get("additional_config") - ) is not None: - data["init_parameters"]["additional_config"] = AdditionalConfig( - **additional_config - ) + if (auth_client_secret := data["init_parameters"].get("auth_client_secret")) is not None: + data["init_parameters"]["auth_client_secret"] = AuthCredentials.from_dict(auth_client_secret) + if (embedded_options := data["init_parameters"].get("embedded_options")) is not None: + data["init_parameters"]["embedded_options"] = EmbeddedOptions(**embedded_options) + if (additional_config := data["init_parameters"].get("additional_config")) is not None: + data["init_parameters"]["additional_config"] = AdditionalConfig(**additional_config) return default_from_dict( cls, data, @@ -367,9 +339,7 @@ def _to_document(self, data: DataObject[Dict[str, Any], None]) -> Document: def _query(self) -> List[Dict[str, Any]]: properties = [p.name for p in self.collection.config.get().properties] try: - result = self.collection.iterator( - include_vector=True, return_properties=properties - ) + result = self.collection.iterator(include_vector=True, return_properties=properties) except weaviate.exceptions.WeaviateQueryError as e: msg = f"Failed to query documents in Weaviate. Error: {e.message}" raise DocumentStoreError(msg) from e @@ -390,9 +360,7 @@ def _query_with_filters(self, filters: Dict[str, Any]) -> List[Dict[str, Any]]: partial_result = None result = [] # Keep querying until we get all documents matching the filters - while ( - partial_result is None or len(partial_result.objects) == DEFAULT_QUERY_LIMIT - ): + while partial_result is None or len(partial_result.objects) == DEFAULT_QUERY_LIMIT: try: partial_result = self.collection.query.fetch_objects( filters=convert_filters(filters), @@ -408,9 +376,7 @@ def _query_with_filters(self, filters: Dict[str, Any]) -> List[Dict[str, Any]]: offset += DEFAULT_QUERY_LIMIT return result - def filter_documents( - self, filters: Optional[Dict[str, Any]] = None - ) -> List[Document]: + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: """ Returns the documents that match the filters provided. 
@@ -485,9 +451,7 @@ def _write(self, documents: List[Document], policy: DuplicatePolicy) -> int:
                 msg = f"Expected a Document, got '{type(doc)}' instead."
                 raise ValueError(msg)

-            if policy == DuplicatePolicy.SKIP and self.collection.data.exists(
-                uuid=generate_uuid5(doc.id)
-            ):
+            if policy == DuplicatePolicy.SKIP and self.collection.data.exists(uuid=generate_uuid5(doc.id)):
                 # This Document already exists, we skip it
                 continue
@@ -507,9 +471,7 @@ def _write(self, documents: List[Document], policy: DuplicatePolicy) -> int:
             raise DuplicateDocumentError(msg)
         return written

-    def write_documents(
-        self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE
-    ) -> int:
+    def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int:
         """
         Writes documents to Weaviate using the specified policy.
         We recommend using an OVERWRITE policy as it's faster than other policies for Weaviate since it uses
@@ -530,15 +492,10 @@ def delete_documents(self, document_ids: List[str]) -> None:
         :param document_ids: The object_ids to delete.
         """
         weaviate_ids = [generate_uuid5(doc_id) for doc_id in document_ids]
-        self.collection.data.delete_many(
-            where=weaviate.classes.query.Filter.by_id().contains_any(weaviate_ids)
-        )
+        self.collection.data.delete_many(where=weaviate.classes.query.Filter.by_id().contains_any(weaviate_ids))

     def _bm25_retrieval(
-        self,
-        query: str,
-        filters: Optional[Dict[str, Any]] = None,
-        top_k: Optional[int] = None,
+        self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None
     ) -> List[Document]:
         properties = [p.name for p in self.collection.config.get().properties]
         result = self.collection.query.bm25(
diff --git a/integrations/weaviate/tests/test_auth.py b/integrations/weaviate/tests/test_auth.py
index 60db31e02..3ad75e206 100644
--- a/integrations/weaviate/tests/test_auth.py
+++ b/integrations/weaviate/tests/test_auth.py
@@ -25,13 +25,7 @@ def test_to_dict(self):
         credentials = AuthApiKey()
         assert credentials.to_dict() == {
             "type": "api_key",
-            "init_parameters": {
-                "api_key": {
-                    "env_vars": ["WEAVIATE_API_KEY"],
-                    "strict": True,
-                    "type": "env_var",
-                }
-            },
+            "init_parameters": {"api_key": {"env_vars": ["WEAVIATE_API_KEY"], "strict": True, "type": "env_var"}},
         }

     def test_from_dict(self, monkeypatch):
@@ -39,13 +33,7 @@ def test_from_dict(self, monkeypatch):
         credentials = AuthCredentials.from_dict(
             {
                 "type": "api_key",
-                "init_parameters": {
-                    "api_key": {
-                        "env_vars": ["WEAVIATE_API_KEY"],
-                        "strict": True,
-                        "type": "env_var",
-                    }
-                },
+                "init_parameters": {"api_key": {"env_vars": ["WEAVIATE_API_KEY"], "strict": True, "type": "env_var"}},
             }
         )
         assert isinstance(credentials, AuthApiKey)
@@ -74,17 +62,9 @@ def test_to_dict(self):
         assert credentials.to_dict() == {
             "type": "bearer",
             "init_parameters": {
-                "access_token": {
-                    "env_vars": ["WEAVIATE_ACCESS_TOKEN"],
-                    "strict": True,
-                    "type": "env_var",
-                },
+                "access_token": {"env_vars": ["WEAVIATE_ACCESS_TOKEN"], "strict": True, "type": "env_var"},
                 "expires_in": 60,
-                "refresh_token": {
-                    "env_vars": ["WEAVIATE_REFRESH_TOKEN"],
-                    "strict": False,
-                    "type": "env_var",
-                },
+                "refresh_token": {"env_vars": ["WEAVIATE_REFRESH_TOKEN"], "strict": False, "type": "env_var"},
             },
         }

     def test_from_dict(self):
@@ -93,17 +73,9 @@ def test_from_dict(self):
             {
                 "type": "bearer",
                 "init_parameters": {
-                    "access_token": {
-                        "env_vars": ["WEAVIATE_ACCESS_TOKEN"],
-                        "strict": True,
-                        "type": "env_var",
-                    },
+                    "access_token": {"env_vars": ["WEAVIATE_ACCESS_TOKEN"], "strict": True, "type": "env_var"},
                     "expires_in": 10,
-                    "refresh_token": {
-                        "env_vars": ["WEAVIATE_REFRESH_TOKEN"],
-                        "strict": False,
-                        "type": "env_var",
-                    },
+                    "refresh_token": {"env_vars": ["WEAVIATE_REFRESH_TOKEN"], "strict": False, "type": "env_var"},
                 },
             }
         )
@@ -137,16 +109,8 @@ def test_to_dict(self):
         assert credentials.to_dict() == {
             "type": "client_credentials",
             "init_parameters": {
-                "client_secret": {
-                    "env_vars": ["WEAVIATE_CLIENT_SECRET"],
-                    "strict": True,
-                    "type": "env_var",
-                },
-                "scope": {
-                    "env_vars": ["WEAVIATE_SCOPE"],
-                    "strict": False,
-                    "type": "env_var",
-                },
+                "client_secret": {"env_vars": ["WEAVIATE_CLIENT_SECRET"], "strict": True, "type": "env_var"},
+                "scope": {"env_vars": ["WEAVIATE_SCOPE"], "strict": False, "type": "env_var"},
             },
         }

     def test_from_dict(self):
@@ -155,16 +119,8 @@ def test_from_dict(self):
             {
                 "type": "client_credentials",
                 "init_parameters": {
-                    "client_secret": {
-                        "env_vars": ["WEAVIATE_CLIENT_SECRET"],
-                        "strict": True,
-                        "type": "env_var",
-                    },
-                    "scope": {
-                        "env_vars": ["WEAVIATE_SCOPE"],
-                        "strict": False,
-                        "type": "env_var",
-                    },
+                    "client_secret": {"env_vars": ["WEAVIATE_CLIENT_SECRET"], "strict": True, "type": "env_var"},
+                    "scope": {"env_vars": ["WEAVIATE_SCOPE"], "strict": False, "type": "env_var"},
                 },
             }
        )
@@ -198,21 +154,9 @@ def test_to_dict(self):
         assert credentials.to_dict() == {
             "type": "client_password",
             "init_parameters": {
-                "username": {
-                    "env_vars": ["WEAVIATE_USERNAME"],
-                    "strict": True,
-                    "type": "env_var",
-                },
-                "password": {
-                    "env_vars": ["WEAVIATE_PASSWORD"],
-                    "strict": True,
-                    "type": "env_var",
-                },
-                "scope": {
-                    "env_vars": ["WEAVIATE_SCOPE"],
-                    "strict": False,
-                    "type": "env_var",
-                },
+                "username": {"env_vars": ["WEAVIATE_USERNAME"], "strict": True, "type": "env_var"},
+                "password": {"env_vars": ["WEAVIATE_PASSWORD"], "strict": True, "type": "env_var"},
+                "scope": {"env_vars": ["WEAVIATE_SCOPE"], "strict": False, "type": "env_var"},
             },
         }

     def test_from_dict(self):
@@ -221,21 +165,9 @@ def test_from_dict(self):
             {
                 "type": "client_password",
                 "init_parameters": {
-                    "username": {
-                        "env_vars": ["WEAVIATE_USERNAME"],
-                        "strict": True,
-                        "type": "env_var",
-                    },
-                    "password": {
-                        "env_vars": ["WEAVIATE_PASSWORD"],
-                        "strict": True,
-                        "type": "env_var",
-                    },
-                    "scope": {
-                        "env_vars": ["WEAVIATE_SCOPE"],
-                        "strict": False,
-                        "type": "env_var",
-                    },
+                    "username": {"env_vars": ["WEAVIATE_USERNAME"], "strict": True, "type": "env_var"},
+                    "password": {"env_vars": ["WEAVIATE_PASSWORD"], "strict": True, "type": "env_var"},
+                    "scope": {"env_vars": ["WEAVIATE_SCOPE"], "strict": False, "type": "env_var"},
                 },
             }
         )
diff --git a/integrations/weaviate/tests/test_bm25_retriever.py b/integrations/weaviate/tests/test_bm25_retriever.py
index ef3bbd811..3720daa85 100644
--- a/integrations/weaviate/tests/test_bm25_retriever.py
+++ b/integrations/weaviate/tests/test_bm25_retriever.py
@@ -18,15 +18,11 @@ def test_init_default():
     assert retriever._top_k == 10
     assert retriever._filter_policy == FilterPolicy.REPLACE

-    retriever = WeaviateBM25Retriever(
-        document_store=mock_document_store, filter_policy="replace"
-    )
+    retriever = WeaviateBM25Retriever(document_store=mock_document_store, filter_policy="replace")
     assert retriever._filter_policy == FilterPolicy.REPLACE

     with pytest.raises(ValueError):
-        WeaviateBM25Retriever(
-            document_store=mock_document_store, filter_policy="keep_all"
-        )
+        WeaviateBM25Retriever(document_store=mock_document_store, filter_policy="keep_all")


 @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
@@ -143,14 +139,10 @@ def test_from_dict_no_filter_policy(_mock_weaviate):
     assert retriever._filter_policy == FilterPolicy.REPLACE


-@patch(
-    "haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateDocumentStore"
-)
+@patch("haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateDocumentStore")
 def test_run(mock_document_store):
     retriever = WeaviateBM25Retriever(document_store=mock_document_store)
     query = "some query"
     filters = {"field": "content", "operator": "==", "value": "Some text"}
     retriever.run(query=query, filters=filters, top_k=5)
-    mock_document_store._bm25_retrieval.assert_called_once_with(
-        query=query, filters=filters, top_k=5
-    )
+    mock_document_store._bm25_retrieval.assert_called_once_with(query=query, filters=filters, top_k=5)
diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py
index 86c5b4970..068212686 100644
--- a/integrations/weaviate/tests/test_document_store.py
+++ b/integrations/weaviate/tests/test_document_store.py
@@ -42,18 +42,14 @@
 )


-@patch(
-    "haystack_integrations.document_stores.weaviate.document_store.weaviate.WeaviateClient"
-)
+@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.WeaviateClient")
 def test_init_is_lazy(_mock_client):
     _ = WeaviateDocumentStore()
     _mock_client.assert_not_called()


 @pytest.mark.integration
-class TestWeaviateDocumentStore(
-    CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest, FilterDocumentsTest
-):
+class TestWeaviateDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest, FilterDocumentsTest):
     @pytest.fixture
     def document_store(self, request) -> WeaviateDocumentStore:
         # Use a different index for each test so we can run them in parallel
@@ -125,35 +121,19 @@ def filterable_docs(self) -> List[Document]:
             documents.append(
                 Document(
                     content=f"Document {i} without embedding",
-                    meta={
-                        "name": f"name_{i}",
-                        "no_embedding": True,
-                        "chapter": "conclusion",
-                    },
+                    meta={"name": f"name_{i}", "no_embedding": True, "chapter": "conclusion"},
                 )
             )
+            documents.append(Document(dataframe=DataFrame([i]), meta={"name": f"table_doc_{i}"}))
             documents.append(
-                Document(dataframe=DataFrame([i]), meta={"name": f"table_doc_{i}"})
-            )
-            documents.append(
-                Document(
-                    content=f"Doc {i} with zeros emb",
-                    meta={"name": "zeros_doc"},
-                    embedding=TEST_EMBEDDING_1,
-                )
+                Document(content=f"Doc {i} with zeros emb", meta={"name": "zeros_doc"}, embedding=TEST_EMBEDDING_1)
             )
             documents.append(
-                Document(
-                    content=f"Doc {i} with ones emb",
-                    meta={"name": "ones_doc"},
-                    embedding=TEST_EMBEDDING_2,
-                )
+                Document(content=f"Doc {i} with ones emb", meta={"name": "ones_doc"}, embedding=TEST_EMBEDDING_2)
             )
         return documents

-    def assert_documents_are_equal(
-        self, received: List[Document], expected: List[Document]
-    ):
+    def assert_documents_are_equal(self, received: List[Document], expected: List[Document]):
         assert len(received) == len(expected)
         received = sorted(received, key=lambda doc: doc.id)
         expected = sorted(expected, key=lambda doc: doc.id)
@@ -179,9 +159,7 @@ def assert_documents_are_equal(
             for key in meta_keys:
                 assert received_meta.get(key) == expected_meta.get(key)

-    @patch(
-        "haystack_integrations.document_stores.weaviate.document_store.weaviate.WeaviateClient"
-    )
+    @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.WeaviateClient")
     def test_connection(self, mock_weaviate_client_class, monkeypatch):
         mock_client = MagicMock()
         mock_client.collections.exists.return_value = False
@@ -268,11 +246,7 @@ def test_to_dict(self, _mock_weaviate, monkeypatch):
                 "auth_client_secret": {
                     "type": "api_key",
                     "init_parameters": {
-                        "api_key": {
-                            "env_vars": ["WEAVIATE_API_KEY"],
-                            "strict": True,
-                            "type": "env_var",
-                        }
+                        "api_key": {"env_vars": ["WEAVIATE_API_KEY"], "strict": True, "type": "env_var"}
                     },
                 },
                 "additional_headers": {"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"},
@@ -291,11 +265,7 @@ def test_to_dict(self, _mock_weaviate, monkeypatch):
                         "session_pool_maxsize": 100,
                         "session_pool_max_retries": 3,
                     },
-                    "proxies": {
-                        "http": "http://proxy:1234",
-                        "https": None,
-                        "grpc": None,
-                    },
+                    "proxies": {"http": "http://proxy:1234", "https": None, "grpc": None},
                     "timeout": [30, 90],
                     "trust_env": False,
                 },
@@ -314,16 +284,10 @@ def test_from_dict(self, _mock_weaviate, monkeypatch):
                     "auth_client_secret": {
                         "type": "api_key",
                         "init_parameters": {
-                            "api_key": {
-                                "env_vars": ["WEAVIATE_API_KEY"],
-                                "strict": True,
-                                "type": "env_var",
-                            }
+                            "api_key": {"env_vars": ["WEAVIATE_API_KEY"], "strict": True, "type": "env_var"}
                         },
                     },
-                    "additional_headers": {
-                        "X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"
-                    },
+                    "additional_headers": {"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"},
                     "embedded_options": {
                         "persistence_data_path": DEFAULT_PERSISTENCE_DATA_PATH,
                         "binary_path": DEFAULT_BINARY_PATH,
@@ -361,26 +325,17 @@ def test_from_dict(self, _mock_weaviate, monkeypatch):
         }
         assert document_store._auth_client_secret == AuthApiKey()
         assert document_store._additional_config.timeout == Timeout(query=10, insert=60)
-        assert document_store._additional_config.proxies == Proxies(
-            http="http://proxy:1234", https=None, grpc=None
-        )
+        assert document_store._additional_config.proxies == Proxies(http="http://proxy:1234", https=None, grpc=None)
         assert not document_store._additional_config.trust_env
-        assert document_store._additional_headers == {
-            "X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"
-        }
-        assert (
-            document_store._embedded_options.persistence_data_path
-            == DEFAULT_PERSISTENCE_DATA_PATH
-        )
+        assert document_store._additional_headers == {"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}
+        assert document_store._embedded_options.persistence_data_path == DEFAULT_PERSISTENCE_DATA_PATH
         assert document_store._embedded_options.binary_path == DEFAULT_BINARY_PATH
         assert document_store._embedded_options.version == "1.23.0"
         assert document_store._embedded_options.port == DEFAULT_PORT
         assert document_store._embedded_options.hostname == "127.0.0.1"
         assert document_store._embedded_options.additional_env_vars is None
         assert document_store._embedded_options.grpc_port == DEFAULT_GRPC_PORT
-        assert (
-            document_store._additional_config.connection.session_pool_connections == 20
-        )
+        assert document_store._additional_config.connection.session_pool_connections == 20
         assert document_store._additional_config.connection.session_pool_maxsize == 20

     def test_to_data_object(self, document_store, test_files_path):
@@ -393,9 +348,7 @@ def test_to_data_object(self, document_store, test_files_path):
             "score": None,
         }

-        image = ByteStream.from_file_path(
-            test_files_path / "robot1.jpg", mime_type="image/jpeg"
-        )
+        image = ByteStream.from_file_path(test_files_path / "robot1.jpg", mime_type="image/jpeg")
         doc = Document(
             content="test doc",
             blob=image,
@@ -414,9 +367,7 @@ def test_to_data_object(self, document_store, test_files_path):
         }

     def test_to_document(self, document_store, test_files_path):
-        image = ByteStream.from_file_path(
-            test_files_path / "robot1.jpg", mime_type="image/jpeg"
-        )
+        image = ByteStream.from_file_path(test_files_path / "robot1.jpg", mime_type="image/jpeg")
         data = DataObject(
             properties={
                 "_original_id": "123",
@@ -451,16 +402,12 @@ def test_write_documents(self, document_store):
         assert document_store.count_documents() == 1

     def test_write_documents_with_blob_data(self, document_store, test_files_path):
-        image = ByteStream.from_file_path(
-            test_files_path / "robot1.jpg", mime_type="image/jpeg"
-        )
+        image = ByteStream.from_file_path(test_files_path / "robot1.jpg", mime_type="image/jpeg")
         doc = Document(content="test doc", blob=image)
         assert document_store.write_documents([doc]) == 1

     def test_filter_documents_with_blob_data(self, document_store, test_files_path):
-        image = ByteStream.from_file_path(
-            test_files_path / "robot1.jpg", mime_type="image/jpeg"
-        )
+        image = ByteStream.from_file_path(test_files_path / "robot1.jpg", mime_type="image/jpeg")
         doc = Document(content="test doc", blob=image)
         assert document_store.write_documents([doc]) == 1
@@ -469,9 +416,7 @@ def test_filter_documents_with_blob_data(self, document_store, test_files_path):
         assert len(docs) == 1
         assert docs[0].blob == image

-    def test_comparison_greater_than_with_iso_date(
-        self, document_store, filterable_docs
-    ):
+    def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs):
         """
         This test has been copied from haystack/testing/document_store.py and modified to
         use a different date format.
@@ -488,14 +433,11 @@ def test_comparison_greater_than_with_iso_date(
                 d
                 for d in filterable_docs
                 if d.meta.get("date") is not None
-                and parser.isoparse(d.meta["date"])
-                > parser.isoparse("1972-12-11T19:54:58Z")
+                and parser.isoparse(d.meta["date"]) > parser.isoparse("1972-12-11T19:54:58Z")
             ],
         )

-    def test_comparison_greater_than_equal_with_iso_date(
-        self, document_store, filterable_docs
-    ):
+    def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs):
         """
         This test has been copied from haystack/testing/document_store.py and modified to
         use a different date format.
@@ -512,8 +454,7 @@ def test_comparison_greater_than_equal_with_iso_date(
                 d
                 for d in filterable_docs
                 if d.meta.get("date") is not None
-                and parser.isoparse(d.meta["date"])
-                >= parser.isoparse("1969-07-21T20:17:40Z")
+                and parser.isoparse(d.meta["date"]) >= parser.isoparse("1969-07-21T20:17:40Z")
             ],
         )
@@ -534,14 +475,11 @@ def test_comparison_less_than_with_iso_date(self, document_store, filterable_doc
                 d
                 for d in filterable_docs
                 if d.meta.get("date") is not None
-                and parser.isoparse(d.meta["date"])
-                < parser.isoparse("1969-07-21T20:17:40Z")
+                and parser.isoparse(d.meta["date"]) < parser.isoparse("1969-07-21T20:17:40Z")
             ],
         )

-    def test_comparison_less_than_equal_with_iso_date(
-        self, document_store, filterable_docs
-    ):
+    def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs):
         """
         This test has been copied from haystack/testing/document_store.py and modified to
         use a different date format.
@@ -558,16 +496,13 @@ def test_comparison_less_than_equal_with_iso_date(
                 d
                 for d in filterable_docs
                 if d.meta.get("date") is not None
-                and parser.isoparse(d.meta["date"])
-                <= parser.isoparse("1969-07-21T20:17:40Z")
+                and parser.isoparse(d.meta["date"]) <= parser.isoparse("1969-07-21T20:17:40Z")
             ],
         )

     @pytest.mark.skip(reason="Weaviate for some reason is not returning what we expect")
     def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs):
-        return super().test_comparison_not_equal_with_dataframe(
-            document_store, filterable_docs
-        )
+        return super().test_comparison_not_equal_with_dataframe(document_store, filterable_docs)

     def test_bm25_retrieval(self, document_store):
         document_store.write_documents(
@@ -656,9 +591,7 @@ def test_embedding_retrieval(self, document_store):
                 Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]),
             ]
         )
-        result = document_store._embedding_retrieval(
-            query_embedding=[1.0, 1.0, 1.0, 1.0]
-        )
+        result = document_store._embedding_retrieval(query_embedding=[1.0, 1.0, 1.0, 1.0])
         assert len(result) == 3
         assert "The document" == result[0].content
         assert result[0].score > 0.0
@@ -679,9 +612,7 @@ def test_embedding_retrieval_with_filters(self, document_store):
             ]
         )
         filters = {"field": "content", "operator": "==", "value": "The document I want"}
-        result = document_store._embedding_retrieval(
-            query_embedding=[1.0, 1.0, 1.0, 1.0], filters=filters
-        )
+        result = document_store._embedding_retrieval(query_embedding=[1.0, 1.0, 1.0, 1.0], filters=filters)
         assert len(result) == 1
         assert "The document I want" == result[0].content
         assert result[0].score > 0.0
@@ -690,15 +621,10 @@ def test_embedding_retrieval_with_topk(self, document_store):
         docs = [
             Document(content="The document", embedding=[1.0, 1.0, 1.0, 1.0]),
             Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]),
-            Document(
-                content="Yet another document",
-                embedding=[0.00001, 0.00001, 0.00001, 0.00002],
-            ),
+            Document(content="Yet another document", embedding=[0.00001, 0.00001, 0.00001, 0.00002]),
         ]
         document_store.write_documents(docs)
-        results = document_store._embedding_retrieval(
-            query_embedding=[1.0, 1.0, 1.0, 1.0], top_k=2
-        )
+        results = document_store._embedding_retrieval(query_embedding=[1.0, 1.0, 1.0, 1.0], top_k=2)
         assert len(results) == 2
         assert results[0].content == "The document"
         assert results[0].score > 0.0
@@ -709,15 +635,10 @@ def test_embedding_retrieval_with_distance(self, document_store):
         docs = [
             Document(content="The document", embedding=[1.0, 1.0, 1.0, 1.0]),
             Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]),
-            Document(
-                content="Yet another document",
-                embedding=[0.00001, 0.00001, 0.00001, 0.00002],
-            ),
+            Document(content="Yet another document", embedding=[0.00001, 0.00001, 0.00001, 0.00002]),
         ]
         document_store.write_documents(docs)
-        results = document_store._embedding_retrieval(
-            query_embedding=[1.0, 1.0, 1.0, 1.0], distance=0.0
-        )
+        results = document_store._embedding_retrieval(query_embedding=[1.0, 1.0, 1.0, 1.0], distance=0.0)
         assert len(results) == 1
         assert results[0].content == "The document"
         assert results[0].score > 0.0
@@ -726,33 +647,24 @@ def test_embedding_retrieval_with_certainty(self, document_store):
         docs = [
             Document(content="The document", embedding=[1.0, 1.0, 1.0, 1.0]),
             Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]),
-            Document(
-                content="Yet another document",
-                embedding=[0.00001, 0.00001, 0.00001, 0.00002],
-            ),
+            Document(content="Yet another document", embedding=[0.00001, 0.00001, 0.00001, 0.00002]),
         ]
         document_store.write_documents(docs)
-        results = document_store._embedding_retrieval(
-            query_embedding=[0.8, 0.8, 0.8, 1.0], certainty=1.0
-        )
+        results = document_store._embedding_retrieval(query_embedding=[0.8, 0.8, 0.8, 1.0], certainty=1.0)
         assert len(results) == 1
         assert results[0].content == "Another document"
         assert results[0].score > 0.0

     def test_embedding_retrieval_with_distance_and_certainty(self, document_store):
         with pytest.raises(ValueError):
-            document_store._embedding_retrieval(
-                query_embedding=[], distance=0.1, certainty=0.1
-            )
+            document_store._embedding_retrieval(query_embedding=[], distance=0.1, certainty=0.1)

     def test_filter_documents_with_legacy_filters(self, document_store):
         docs = []
         for index in range(10):
             docs.append(Document(content="This is some content", meta={"index": index}))
         document_store.write_documents(docs)
-        result = document_store.filter_documents(
-            {"content": {"$eq": "This is some content"}}
-        )
+        result = document_store.filter_documents({"content": {"$eq": "This is some content"}})

         assert len(result) == 10
@@ -773,9 +685,7 @@ def test_filter_documents_over_default_limit(self, document_store):
             docs.append(Document(content="This is some content", meta={"index": index}))
         document_store.write_documents(docs)
         with pytest.raises(DocumentStoreError):
-            document_store.filter_documents(
-                {"field": "content", "operator": "==", "value": "This is some content"}
-            )
+            document_store.filter_documents({"field": "content", "operator": "==", "value": "This is some content"})

     def test_schema_class_name_conversion_preserves_pascal_case(self):
         collection_settings = {"class": "CaseDocument"}
@@ -793,16 +703,13 @@ def test_schema_class_name_conversion_preserves_pascal_case(self):
         assert doc_score._collection_settings["class"] == "Lower_case_name"

     @pytest.mark.skipif(
-        not os.environ.get("WEAVIATE_API_KEY", None)
-        and not os.environ.get("WEAVIATE_CLOUD_CLUSTER_URL", None),
+        not os.environ.get("WEAVIATE_API_KEY", None) and not os.environ.get("WEAVIATE_CLOUD_CLUSTER_URL", None),
         reason="Both WEAVIATE_API_KEY and WEAVIATE_CLOUD_CLUSTER_URL are not set. Skipping test.",
     )
     def test_connect_to_weaviate_cloud(self):
         document_store = WeaviateDocumentStore(
             url=os.environ.get("WEAVIATE_CLOUD_CLUSTER_URL"),
-            auth_client_secret=AuthApiKey(
-                api_key=Secret.from_env_var("WEAVIATE_API_KEY")
-            ),
+            auth_client_secret=AuthApiKey(api_key=Secret.from_env_var("WEAVIATE_API_KEY")),
         )
         assert document_store.client
diff --git a/integrations/weaviate/tests/test_embedding_retriever.py b/integrations/weaviate/tests/test_embedding_retriever.py
index 3571d82ed..13f214dd1 100644
--- a/integrations/weaviate/tests/test_embedding_retriever.py
+++ b/integrations/weaviate/tests/test_embedding_retriever.py
@@ -6,9 +6,7 @@
 import pytest
 from haystack.document_stores.types import FilterPolicy
-from haystack_integrations.components.retrievers.weaviate import (
-    WeaviateEmbeddingRetriever,
-)
+from haystack_integrations.components.retrievers.weaviate import WeaviateEmbeddingRetriever
 from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore
@@ -22,23 +20,17 @@ def test_init_default():
     assert retriever._distance is None
     assert retriever._certainty is None

-    retriever = WeaviateEmbeddingRetriever(
-        document_store=mock_document_store, filter_policy="replace"
-    )
+    retriever = WeaviateEmbeddingRetriever(document_store=mock_document_store, filter_policy="replace")
     assert retriever._filter_policy == FilterPolicy.REPLACE

     with pytest.raises(ValueError):
-        WeaviateEmbeddingRetriever(
-            document_store=mock_document_store, filter_policy="keep_all"
-        )
+        WeaviateEmbeddingRetriever(document_store=mock_document_store, filter_policy="keep_all")


 def test_init_with_distance_and_certainty():
     mock_document_store = Mock(spec=WeaviateDocumentStore)
     with pytest.raises(ValueError):
-        WeaviateEmbeddingRetriever(
-            document_store=mock_document_store, distance=0.1, certainty=0.8
-        )
+        WeaviateEmbeddingRetriever(document_store=mock_document_store, distance=0.1, certainty=0.8)


 @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
@@ -165,20 +157,12 @@ def test_from_dict_no_filter_policy(_mock_weaviate):
     assert retriever._filter_policy == FilterPolicy.REPLACE  # defaults to REPLACE


-@patch(
-    "haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateDocumentStore"
-)
+@patch("haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateDocumentStore")
 def test_run(mock_document_store):
     retriever = WeaviateEmbeddingRetriever(document_store=mock_document_store)
     query_embedding = [0.1, 0.1, 0.1, 0.1]
     filters = {"field": "content", "operator": "==", "value": "Some text"}
-    retriever.run(
-        query_embedding=query_embedding, filters=filters, top_k=5, distance=0.1
-    )
+    retriever.run(query_embedding=query_embedding, filters=filters, top_k=5, distance=0.1)
     mock_document_store._embedding_retrieval.assert_called_once_with(
-        query_embedding=query_embedding,
-        filters=filters,
-        top_k=5,
-        distance=0.1,
-        certainty=None,
+        query_embedding=query_embedding, filters=filters, top_k=5, distance=0.1, certainty=None
     )
diff --git a/nodes/text2speech/tests/test_nodes.py b/nodes/text2speech/tests/test_nodes.py
index c56bf7bec..931d3f757 100644
--- a/nodes/text2speech/tests/test_nodes.py
+++ b/nodes/text2speech/tests/test_nodes.py
@@ -40,9 +40,7 @@ def transcribe(self, media_file: str):
         )
         data = np.frombuffer(output, np.int16).flatten().astype(np.float32) / 32768.0

-        features = self._processor(
-            data, sampling_rate=16000, return_tensors="pt"
-        ).input_features
+        features = self._processor(data, sampling_rate=16000, return_tensors="pt").input_features
         tokens = self._model.generate(features)

         return self._processor.batch_decode(tokens, skip_special_tokens=True)
@@ -82,9 +80,7 @@ def test_text_to_speech_audio_file(self, tmp_path, whisper_helper: WhisperHelper
             transformers_params={"seed": 4535, "always_fix_seed": True},
         )

-        audio_file = text2speech.text_to_audio_file(
-            text="answer", generated_audio_dir=tmp_path / "test_audio"
-        )
+        audio_file = text2speech.text_to_audio_file(text="answer", generated_audio_dir=tmp_path / "test_audio")
         assert os.path.exists(audio_file)

         expected_doc = whisper_helper.transcribe(str(SAMPLES_PATH / "answer.wav"))
@@ -93,18 +89,14 @@ def test_text_to_speech_audio_file(self, tmp_path, whisper_helper: WhisperHelper
         assert expected_doc[0] in generated_doc[0]

     @pytest.mark.xfail(reason="known issue converting to MP3")
-    def test_text_to_speech_compress_audio(
-        self, tmp_path, whisper_helper: WhisperHelper
-    ):
+    def test_text_to_speech_compress_audio(self, tmp_path, whisper_helper: WhisperHelper):
         text2speech = TextToSpeech(
             model_name_or_path="espnet/kan-bayashi_ljspeech_vits",
             transformers_params={"seed": 4535, "always_fix_seed": True},
         )
         expected_audio_file = SAMPLES_PATH / "answer.wav"
         audio_file = text2speech.text_to_audio_file(
-            text="answer",
-            generated_audio_dir=tmp_path / "test_audio",
-            audio_format="mp3",
+            text="answer", generated_audio_dir=tmp_path / "test_audio", audio_format="mp3"
         )
         assert os.path.exists(audio_file)
         assert audio_file.suffix == ".mp3"
@@ -114,18 +106,14 @@ def test_text_to_speech_compress_audio(
         assert expected_doc[0] in generated_doc[0]

-    def test_text_to_speech_naming_function(
-        self, tmp_path, whisper_helper: WhisperHelper
-    ):
+    def test_text_to_speech_naming_function(self, tmp_path, whisper_helper: WhisperHelper):
         text2speech = TextToSpeech(
             model_name_or_path="espnet/kan-bayashi_ljspeech_vits",
             transformers_params={"seed": 4535, "always_fix_seed": True},
         )
         expected_audio_file = SAMPLES_PATH / "answer.wav"
         audio_file = text2speech.text_to_audio_file(
-            text="answer",
-            generated_audio_dir=tmp_path / "test_audio",
-            audio_naming_function=lambda text: text,
+            text="answer", generated_audio_dir=tmp_path / "test_audio", audio_naming_function=lambda text: text
         )
         assert os.path.exists(audio_file)
         assert audio_file.name == expected_audio_file.name
@@ -148,9 +136,7 @@ def test_answer_to_speech(self, tmp_path, whisper_helper: WhisperHelper):
             meta={"some_meta": "some_value"},
         )
         expected_audio_answer = SAMPLES_PATH / "answer.wav"
-        expected_audio_context = (
-            SAMPLES_PATH / "the context for this answer is here.wav"
-        )
+        expected_audio_context = SAMPLES_PATH / "the context for this answer is here.wav"

         answer2speech = AnswerToSpeech(
             generated_audio_dir=tmp_path / "test_audio",
@@ -161,20 +147,12 @@ def test_answer_to_speech(self, tmp_path, whisper_helper: WhisperHelper):

         audio_answer: Answer = results["answers"][0]
         assert isinstance(audio_answer, Answer)
-        assert (
-            audio_answer.answer.split(os.path.sep)[-1]
-            == str(expected_audio_answer).split(os.path.sep)[-1]
-        )
-        assert (
-            audio_answer.context.split(os.path.sep)[-1]
-            == str(expected_audio_context).split(os.path.sep)[-1]
-        )
+        assert audio_answer.answer.split(os.path.sep)[-1] == str(expected_audio_answer).split(os.path.sep)[-1]
+        assert audio_answer.context.split(os.path.sep)[-1] == str(expected_audio_context).split(os.path.sep)[-1]
         assert audio_answer.offsets_in_document == [Span(31, 37)]
         assert audio_answer.offsets_in_context == [Span(21, 27)]
         assert audio_answer.meta["answer_text"] == "answer"
-        assert (
-            audio_answer.meta["context_text"] == "the context for this answer is here"
-        )
+        assert audio_answer.meta["context_text"] == "the context for this answer is here"
         assert audio_answer.meta["some_meta"] == "some_value"
         assert audio_answer.meta["audio_format"] == "wav"
@@ -188,13 +166,9 @@ def test_answer_to_speech(self, tmp_path, whisper_helper: WhisperHelper):
 class TestDocumentToSpeech:
     def test_document_to_speech(self, tmp_path, whisper_helper: WhisperHelper):
         text_doc = Document(
-            content="this is the content of the document",
-            content_type="text",
-            meta={"name": "test_document.txt"},
-        )
-        expected_audio_content = (
-            SAMPLES_PATH / "this is the content of the document.wav"
+            content="this is the content of the document", content_type="text", meta={"name": "test_document.txt"}
         )
+        expected_audio_content = SAMPLES_PATH / "this is the content of the document.wav"

         doc2speech = DocumentToSpeech(
             generated_audio_dir=tmp_path / "test_audio",
@@ -207,10 +181,7 @@ def test_document_to_speech(self, tmp_path, whisper_helper: WhisperHelper):
         audio_doc: Document = results["documents"][0]
         assert isinstance(audio_doc, Document)
         assert audio_doc.content_type == "audio"
-        assert (
-            audio_doc.content.split(os.path.sep)[-1]
-            == str(expected_audio_content).split(os.path.sep)[-1]
-        )
+        assert audio_doc.content.split(os.path.sep)[-1] == str(expected_audio_content).split(os.path.sep)[-1]
         assert audio_doc.meta["content_text"] == "this is the content of the document"
         assert audio_doc.meta["name"] == "test_document.txt"
         assert audio_doc.meta["audio_format"] == "wav"
diff --git a/nodes/text2speech/text2speech/answer_to_speech.py b/nodes/text2speech/text2speech/answer_to_speech.py
index df6a21fde..7d876d935 100644
--- a/nodes/text2speech/text2speech/answer_to_speech.py
+++ b/nodes/text2speech/text2speech/answer_to_speech.py
@@ -62,9 +62,7 @@ def __init__(
         """
         super().__init__()
         self.converter = TextToSpeech(
-            model_name_or_path=model_name_or_path,
-            transformers_params=transformers_params,
-            devices=devices,
+            model_name_or_path=model_name_or_path, transformers_params=transformers_params, devices=devices
         )
         self.generated_audio_dir = generated_audio_dir
         self.params: Dict[str, Any] = audio_params or {}
@@ -72,19 +70,13 @@ def run(self, answers: List[Answer]) -> Tuple[Dict[str, List[Answer]], str]:  # type: ignore
         audio_answers = []
-        for answer in tqdm(
-            answers, disable=not self.progress_bar, desc="Converting answers to audio"
-        ):
+        for answer in tqdm(answers, disable=not self.progress_bar, desc="Converting answers to audio"):
             answer_audio = self.converter.text_to_audio_file(
-                text=answer.answer,
-                generated_audio_dir=self.generated_audio_dir,
-                **self.params
+                text=answer.answer, generated_audio_dir=self.generated_audio_dir, **self.params
             )
             if isinstance(answer.context, str):
                 context_audio = self.converter.text_to_audio_file(
-                    text=answer.context,
-                    generated_audio_dir=self.generated_audio_dir,
-                    **self.params
+                    text=answer.context, generated_audio_dir=self.generated_audio_dir, **self.params
                 )

             audio_answer = Answer.from_dict(answer.to_dict())
@@ -94,9 +86,7 @@ def run(self, answers: List[Answer]) -> Tuple[Dict[str, List[Answer]], str]:  #
                 {
                     "answer_text": answer.answer,
                     "context_text": answer.context,
-                    "audio_format": self.params.get(
-                        "audio_format", answer_audio.suffix.replace(".", "")
-                    ),
+                    "audio_format": self.params.get("audio_format", answer_audio.suffix.replace(".", "")),
                     "sample_rate": self.converter.model.fs,
                 }
             )
diff --git a/nodes/text2speech/text2speech/document_to_speech.py b/nodes/text2speech/text2speech/document_to_speech.py
index b7e20585b..1b73c6bd7 100644
--- a/nodes/text2speech/text2speech/document_to_speech.py
+++ b/nodes/text2speech/text2speech/document_to_speech.py
@@ -50,10 +50,7 @@ def __init__(
         :param transformers_params: The parameters to pass over to the `Text2Speech.from_pretrained()` call.
         """
         super().__init__()
-        self.converter = TextToSpeech(
-            model_name_or_path=model_name_or_path,
-            transformers_params=transformers_params,
-        )
+        self.converter = TextToSpeech(model_name_or_path=model_name_or_path, transformers_params=transformers_params)
         self.generated_audio_dir = generated_audio_dir
         self.params: Dict[str, Any] = audio_params or {}
@@ -61,9 +58,7 @@ def run(self, documents: List[Document]) -> Tuple[Dict[str, List[Document]], str
         audio_documents = []
         for doc in tqdm(documents):
             content_audio = self.converter.text_to_audio_file(
-                text=doc.content,
-                generated_audio_dir=self.generated_audio_dir,
-                **self.params
+                text=doc.content, generated_audio_dir=self.generated_audio_dir, **self.params
             )
             audio_document = Document.from_dict(doc.to_dict())
             audio_document.content = str(content_audio)
@@ -71,9 +66,7 @@ def run(self, documents: List[Document]) -> Tuple[Dict[str, List[Document]], str
             audio_document.meta.update(
                 {
                     "content_text": doc.content,
-                    "audio_format": self.params.get(
-                        "audio_format", content_audio.suffix.replace(".", "")
-                    ),
+                    "audio_format": self.params.get("audio_format", content_audio.suffix.replace(".", "")),
                     "sample_rate": self.converter.model.fs,
                 }
             )
diff --git a/nodes/text2speech/text2speech/utils/text_to_speech.py b/nodes/text2speech/text2speech/utils/text_to_speech.py
index 4861c4cdd..e7d22cc11 100644
--- a/nodes/text2speech/text2speech/utils/text_to_speech.py
+++ b/nodes/text2speech/text2speech/utils/text_to_speech.py
@@ -58,9 +58,7 @@ def __init__(
         """
         super().__init__()

-        resolved_devices, _ = initialize_device_settings(
-            devices=devices, use_cuda=use_gpu, multi_gpu=False
-        )
+        resolved_devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=False)
         if len(resolved_devices) > 1:
             logger.warning(
                 "Multiple devices are not supported in %s inference, using the first device %s.",
@@ -69,9 +67,7 @@ def __init__(
             )

         self.model = _Text2SpeechModel.from_pretrained(
-            str(model_name_or_path),
-            device=resolved_devices[0].type,
-            **(transformers_params or {}),
+            str(model_name_or_path), device=resolved_devices[0].type, **(transformers_params or {})
         )

     def text_to_audio_file(
@@ -84,9 +80,7 @@ def text_to_audio_file(
         channels_count: int = 1,
         bitrate: str = "320k",
         normalized=True,
-        audio_naming_function: Callable = lambda text: hashlib.md5(
-            text.encode("utf-8")
-        ).hexdigest(),
+        audio_naming_function: Callable = lambda text: hashlib.md5(text.encode("utf-8")).hexdigest(),
     ) -> Path:
         """
         Convert an input string into an audio file containing the same string read out loud.
@@ -124,11 +118,7 @@ def text_to_audio_file(
         audio_data = self.text_to_audio_data(text)
         if audio_format.upper() in sf.available_formats().keys():
             sf.write(
-                data=audio_data,
-                file=file_path,
-                format=audio_format,
-                subtype=subtype,
-                samplerate=self.model.fs,
+                data=audio_data, file=file_path, format=audio_format, subtype=subtype, samplerate=self.model.fs
             )
         else:
             self.compress_audio(
@@ -191,10 +181,5 @@ def compress_audio(
         :param normalized: Normalizes the audio before compression (range 2^15) or leaves it untouched.
""" data = np.int16((data * 2**15) if normalized else data) - audio = AudioSegment( - data.tobytes(), - frame_rate=sample_rate, - sample_width=sample_width, - channels=channels_count, - ) + audio = AudioSegment(data.tobytes(), frame_rate=sample_rate, sample_width=sample_width, channels=channels_count) audio.export(path, format=audio_format, bitrate=bitrate)