diff --git a/integrations/optimum/pyproject.toml b/integrations/optimum/pyproject.toml index 91c5eb13c..6e1cae85c 100644 --- a/integrations/optimum/pyproject.toml +++ b/integrations/optimum/pyproject.toml @@ -27,6 +27,11 @@ classifiers = [ dependencies = [ "haystack-ai", "transformers[sentencepiece]", + # The main export function of Optimum into ONNX has hidden dependencies. + # It depends on either "sentence-transformers", "diffusers" or "timm", based + # on which model is loaded from HF Hub. + # Ref: https://github.com/huggingface/optimum/blob/8651c0ca1cccf095458bc80329dec9df4601edb4/optimum/exporters/onnx/__main__.py#L164 + # "sentence-transformers" has been added, since most embedding models use it "sentence-transformers>=2.2.0", "optimum[onnxruntime]" ] diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/backends/optimum_backend.py b/integrations/optimum/src/haystack_integrations/components/embedders/backends/optimum_backend.py index 515d92938..eca285a88 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/backends/optimum_backend.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/backends/optimum_backend.py @@ -3,6 +3,7 @@ import numpy as np import torch from haystack.utils.auth import Secret +from haystack_integrations.components.embedders.pooling import Pooling, PoolingMode from optimum.onnxruntime import ORTModelForFeatureExtraction from tqdm import tqdm from transformers import AutoTokenizer @@ -38,25 +39,11 @@ def __init__(self, model: str, token: Optional[Secret] = None, model_kwargs: Opt self.model = ORTModelForFeatureExtraction.from_pretrained(**model_kwargs, export=True) self.tokenizer = AutoTokenizer.from_pretrained(model, token=token) - def mean_pooling(self, model_output: torch.tensor, attention_mask: torch.tensor) -> torch.tensor: - """ - Perform Mean Pooling on the output of the Embedding model. - - :param model_output: The output of the embedding model. - :param attention_mask: The attention mask of the tokenized text. - :return: The embeddings of the text after mean pooling. - """ - # First element of model_output contains all token embeddings - token_embeddings = model_output[0] - input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() - sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) - sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) - return sum_embeddings / sum_mask - def embed( self, texts_to_embed: Union[str, List[str]], normalize_embeddings: bool, + pooling_mode: PoolingMode = PoolingMode.MEAN, progress_bar: bool = False, batch_size: int = 1, ) -> Union[List[List[float]], List[float]]: @@ -65,6 +52,7 @@ def embed( :param texts_to_embed: T :param normalize_embeddings: Whether to normalize the embeddings to unit length. + :param pooling_mode: The pooling mode to use. Defaults to PoolingMode.MEAN. :param progress_bar: Whether to show a progress bar or not, defaults to False. :param batch_size: Batch size to use, defaults to 1. :return: A single embedding if the input is a single string. 
A list of embeddings if the input is a list of strings. @@ -97,8 +85,13 @@ def embed( # Compute token embeddings model_output = self.model(**encoded_input) - # Perform mean pooling - sentence_embeddings = self.mean_pooling(model_output, encoded_input["attention_mask"].to(device)) + # Pool Embeddings + pooling = Pooling( + pooling_mode=pooling_mode, + attention_mask=encoded_input["attention_mask"].to(device), + model_output=model_output, + ) + sentence_embeddings = pooling.pool_embeddings() all_embeddings.extend(sentence_embeddings.tolist()) diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum_document_embedder.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum_document_embedder.py index 9e3ae2e61..4dee89633 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum_document_embedder.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum_document_embedder.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from haystack import Document, component, default_from_dict, default_to_dict from haystack.utils import Secret, deserialize_secrets_inplace @@ -6,8 +6,10 @@ from haystack_integrations.components.embedders.backends.optimum_backend import ( _OptimumEmbeddingBackendFactory, ) +from haystack_integrations.components.embedders.pooling import HFPoolingMode, PoolingMode +@component class OptimumDocumentEmbedder: """ A component for computing Document embeddings using models loaded with the HuggingFace Optimum library. @@ -52,6 +54,7 @@ def __init__( suffix: str = "", normalize_embeddings: bool = True, onnx_execution_provider: str = "CPUExecutionProvider", + pooling_mode: Optional[Union[str, PoolingMode]] = None, model_kwargs: Optional[Dict[str, Any]] = None, batch_size: int = 32, progress_bar: bool = True, @@ -68,7 +71,40 @@ def __init__( :param suffix: A string to add to the end of each text. :param normalize_embeddings: Whether to normalize the embeddings to unit length. :param onnx_execution_provider: The execution provider to use for ONNX models. Defaults to - "CPUExecutionProvider". See https://onnxruntime.ai/docs/execution-providers/ for possible providers. + "CPUExecutionProvider". See https://onnxruntime.ai/docs/execution-providers/ for possible providers. + + Note: Using the TensorRT execution provider + TensorRT requires building its inference engine ahead of inference, which takes some time due to model + optimization and node fusion. To avoid rebuilding the engine every time the model is loaded, ONNX Runtime + provides a pair of options to save the engine: `trt_engine_cache_enable` and `trt_engine_cache_path`. When + using the TensorRT execution provider, we recommend setting these two provider options via the model_kwargs + parameter. The usage is as follows: + ```python + embedder = OptimumDocumentEmbedder( + model="sentence-transformers/all-mpnet-base-v2", + onnx_execution_provider="TensorrtExecutionProvider", + model_kwargs={ + "provider_options": { + "trt_engine_cache_enable": True, + "trt_engine_cache_path": "tmp/trt_cache", + } + }, + ) + ``` + :param pooling_mode: The pooling mode to use. Defaults to None. When None, the pooling mode is inferred from + the model config. If it cannot be determined, "mean" pooling is used. + The supported pooling modes are: + - "cls": Perform CLS Pooling on the output of the embedding model. Uses the first token (CLS token) as text + representations.
+ - "max": Perform Max Pooling on the output of the embedding model. Uses max in each dimension over all + the tokens. + - "mean": Perform Mean Pooling on the output of the embedding model. + - "mean_sqrt_len": Perform mean-pooling on the output of the embedding model, but divide by + sqrt(input_length). + - "weighted_mean": Perform Weighted (position) Mean Pooling on the output of the embedding model. See + https://arxiv.org/abs/2202.08904. + - "last_token": Perform Last Token Pooling on the output of the embedding model. See + https://arxiv.org/abs/2202.08904 & https://arxiv.org/abs/2201.10005. :param model_kwargs: Dictionary containing additional keyword arguments to pass to the model. In case of duplication, these kwargs override `model`, `onnx_execution_provider`, and `token` initialization parameters. @@ -84,6 +120,15 @@ def __init__( self.token = token token = token.resolve_value() if token else None + if isinstance(pooling_mode, str): + pooling_mode = PoolingMode.from_str(pooling_mode) + # Infer the pooling mode from the model config if it is not provided. + if pooling_mode is None: + pooling_mode = HFPoolingMode.get_pooling_mode(model, token) + # Default to "mean" pooling if it is not found in the model config and not specified by the user. + if pooling_mode is None: + pooling_mode = PoolingMode.MEAN + self.pooling_mode = pooling_mode + self.prefix = prefix self.suffix = suffix self.normalize_embeddings = normalize_embeddings @@ -124,6 +169,7 @@ def to_dict(self) -> Dict[str, Any]: suffix=self.suffix, normalize_embeddings=self.normalize_embeddings, onnx_execution_provider=self.onnx_execution_provider, + pooling_mode=self.pooling_mode.value, batch_size=self.batch_size, progress_bar=self.progress_bar, meta_fields_to_embed=self.meta_fields_to_embed, @@ -143,6 +189,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "OptimumDocumentEmbedder": """ Deserialize this component from a dictionary. """ + data["init_parameters"]["pooling_mode"] = PoolingMode.from_str(data["init_parameters"]["pooling_mode"]) deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) deserialize_hf_model_kwargs(data["init_parameters"]["model_kwargs"]) return default_from_dict(cls, data) @@ -193,6 +240,7 @@ def run(self, documents: List[Document]): embeddings = self.embedding_backend.embed( texts_to_embed=texts_to_embed, normalize_embeddings=self.normalize_embeddings, + pooling_mode=self.pooling_mode, progress_bar=self.progress_bar, batch_size=self.batch_size, ) diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum_text_embedder.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum_text_embedder.py index 1531c85ef..828c86111 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum_text_embedder.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum_text_embedder.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict from haystack.utils import Secret, deserialize_secrets_inplace @@ -6,8 +6,10 @@ from haystack_integrations.components.embedders.backends.optimum_backend import ( _OptimumEmbeddingBackendFactory, ) +from haystack_integrations.components.embedders.pooling import HFPoolingMode, PoolingMode +@component class OptimumTextEmbedder: """ A component to embed text using models loaded with the HuggingFace Optimum library.
@@ -48,6 +50,7 @@ def __init__( suffix: str = "", normalize_embeddings: bool = True, onnx_execution_provider: str = "CPUExecutionProvider", + pooling_mode: Optional[Union[str, PoolingMode]] = None, model_kwargs: Optional[Dict[str, Any]] = None, ): """ @@ -60,7 +63,40 @@ def __init__( :param suffix: A string to add to the end of each text. :param normalize_embeddings: Whether to normalize the embeddings to unit length. :param onnx_execution_provider: The execution provider to use for ONNX models. Defaults to - "CPUExecutionProvider". See https://onnxruntime.ai/docs/execution-providers/ for possible providers. + "CPUExecutionProvider". See https://onnxruntime.ai/docs/execution-providers/ for possible providers. + + Note: Using the TensorRT execution provider + TensorRT requires building its inference engine ahead of inference, which takes some time due to model + optimization and node fusion. To avoid rebuilding the engine every time the model is loaded, ONNX Runtime + provides a pair of options to save the engine: `trt_engine_cache_enable` and `trt_engine_cache_path`. When + using the TensorRT execution provider, we recommend setting these two provider options via the model_kwargs + parameter. The usage is as follows: + ```python + embedder = OptimumTextEmbedder( + model="sentence-transformers/all-mpnet-base-v2", + onnx_execution_provider="TensorrtExecutionProvider", + model_kwargs={ + "provider_options": { + "trt_engine_cache_enable": True, + "trt_engine_cache_path": "tmp/trt_cache", + } + }, + ) + ``` + :param pooling_mode: The pooling mode to use. Defaults to None. When None, the pooling mode is inferred from + the model config. If it cannot be determined, "mean" pooling is used. + The supported pooling modes are: + - "cls": Perform CLS Pooling on the output of the embedding model. Uses the first token (CLS token) as text + representations. + - "max": Perform Max Pooling on the output of the embedding model. Uses max in each dimension over all + the tokens. + - "mean": Perform Mean Pooling on the output of the embedding model. + - "mean_sqrt_len": Perform mean-pooling on the output of the embedding model, but divide by + sqrt(input_length). + - "weighted_mean": Perform Weighted (position) Mean Pooling on the output of the embedding model. See + https://arxiv.org/abs/2202.08904. + - "last_token": Perform Last Token Pooling on the output of the embedding model. See + https://arxiv.org/abs/2202.08904 & https://arxiv.org/abs/2201.10005. :param model_kwargs: Dictionary containing additional keyword arguments to pass to the model. In case of duplication, these kwargs override `model`, `onnx_execution_provider`, and `token` initialization parameters.
@@ -71,6 +107,15 @@ def __init__( self.token = token token = token.resolve_value() if token else None + if isinstance(pooling_mode, str): + pooling_mode = PoolingMode.from_str(pooling_mode) + # Infer the pooling mode from the model config if it is not provided. + if pooling_mode is None: + pooling_mode = HFPoolingMode.get_pooling_mode(model, token) + # Default to "mean" pooling if it is not found in the model config and not specified by the user. + if pooling_mode is None: + pooling_mode = PoolingMode.MEAN + self.pooling_mode = pooling_mode + self.prefix = prefix self.suffix = suffix self.normalize_embeddings = normalize_embeddings @@ -107,6 +152,7 @@ def to_dict(self) -> Dict[str, Any]: suffix=self.suffix, normalize_embeddings=self.normalize_embeddings, onnx_execution_provider=self.onnx_execution_provider, + pooling_mode=self.pooling_mode.value, model_kwargs=self.model_kwargs, token=self.token.to_dict() if self.token else None, ) @@ -122,13 +168,15 @@ def from_dict(cls, data: Dict[str, Any]) -> "OptimumTextEmbedder": """ Deserialize this component from a dictionary. """ + data["init_parameters"]["pooling_mode"] = PoolingMode.from_str(data["init_parameters"]["pooling_mode"]) deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) deserialize_hf_model_kwargs(data["init_parameters"]["model_kwargs"]) return default_from_dict(cls, data) @component.output_types(embedding=List[float]) def run(self, text: str): - """Embed a string. + """ + Embed a string. :param text: The text to embed. :return: The embeddings of the text. @@ -147,7 +195,7 @@ def run(self, text: str): text_to_embed = self.prefix + text + self.suffix embedding = self.embedding_backend.embed( - texts_to_embed=text_to_embed, normalize_embeddings=self.normalize_embeddings + texts_to_embed=text_to_embed, normalize_embeddings=self.normalize_embeddings, pooling_mode=self.pooling_mode ) return {"embedding": embedding} diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/pooling.py b/integrations/optimum/src/haystack_integrations/components/embedders/pooling.py new file mode 100644 index 000000000..aa011b8ce --- /dev/null +++ b/integrations/optimum/src/haystack_integrations/components/embedders/pooling.py @@ -0,0 +1,231 @@ +import json +from enum import Enum +from typing import Optional + +import torch +from haystack.utils import Secret +from huggingface_hub import hf_hub_download + + +class PoolingMode(Enum): + """ + Pooling modes supported by the Optimum Embedders. + """ + + CLS = "cls" + MEAN = "mean" + MAX = "max" + MEAN_SQRT_LEN = "mean_sqrt_len" + WEIGHTED_MEAN = "weighted_mean" + LAST_TOKEN = "last_token" + + def __str__(self): + return self.value + + @classmethod + def from_str(cls, string: str) -> "PoolingMode": + """ + Create a pooling mode from a string. + + :param string: + The string to convert. + :returns: + The pooling mode. + """ + enum_map = {e.value: e for e in PoolingMode} + pooling_mode = enum_map.get(string) + if pooling_mode is None: + msg = f"Unknown Pooling mode '{string}'. Supported modes are: {list(enum_map.keys())}" + raise ValueError(msg) + return pooling_mode + + +class HFPoolingMode: + """ + Gets the pooling mode of the Sentence Transformer model from the Hugging Face Hub. + """ + + @staticmethod + def get_pooling_mode(model: str, token: Optional[Secret] = None) -> Optional[PoolingMode]: + try: + pooling_config_path = hf_hub_download(repo_id=model, token=token, filename="1_Pooling/config.json") + + with open(pooling_config_path) as f: + pooling_config = json.load(f) + + # Filter only those keys that start with "pooling_mode" and are True + true_pooling_modes = [ + key for key, value in pooling_config.items() if key.startswith("pooling_mode") and value + ] + + pooling_modes_map = { + "pooling_mode_cls_token": PoolingMode.CLS, + "pooling_mode_mean_tokens": PoolingMode.MEAN, + "pooling_mode_max_tokens": PoolingMode.MAX, + "pooling_mode_mean_sqrt_len_tokens": PoolingMode.MEAN_SQRT_LEN, + "pooling_mode_weightedmean_tokens": PoolingMode.WEIGHTED_MEAN, + "pooling_mode_last_token": PoolingMode.LAST_TOKEN, + } + + # If exactly one True pooling mode is found, return it + if len(true_pooling_modes) == 1: + pooling_mode_from_config = true_pooling_modes[0] + pooling_mode = pooling_modes_map.get(pooling_mode_from_config) + # If no True pooling modes or more than one True pooling mode is found, return None + else: + pooling_mode = None + return pooling_mode + except Exception: + return None + + +class Pooling: + """ + Class to manage pooling of the embeddings. + + :param pooling_mode: The pooling mode to use. + :param attention_mask: The attention mask of the tokenized text. + :param model_output: The output of the embedding model. + """ + + def __init__(self, pooling_mode: PoolingMode, attention_mask: torch.tensor, model_output: torch.tensor): + self.pooling_mode = pooling_mode + self.attention_mask = attention_mask + self.model_output = model_output + + def _cls_pooling(self, token_embeddings: torch.tensor) -> torch.tensor: + """ + Perform CLS Pooling on the output of the embedding model. Uses the first token (CLS token) as text + representations. + + :param token_embeddings: The token embeddings output by the embedding model. + :return: The embeddings of the text after CLS pooling. + """ + embeddings = token_embeddings[:, 0] + return embeddings + + def _max_pooling(self, token_embeddings: torch.tensor, attention_mask: torch.tensor) -> torch.tensor: + """ + Perform Max Pooling on the output of the embedding model. Uses max in each dimension over all tokens. + + :param token_embeddings: The token embeddings output by the embedding model. + :param attention_mask: The attention mask of the tokenized text. + :return: The embeddings of the text after max pooling. + """ + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + # Set padding tokens to large negative value + token_embeddings[input_mask_expanded == 0] = -1e9 + embeddings = torch.max(token_embeddings, 1)[0] + return embeddings + + def _mean_pooling(self, token_embeddings: torch.tensor, attention_mask: torch.tensor) -> torch.tensor: + """ + Perform Mean Pooling on the output of the embedding model. + + :param token_embeddings: The token embeddings output by the embedding model. + :param attention_mask: The attention mask of the tokenized text. + :return: The embeddings of the text after mean pooling. + """ + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) + sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) + return sum_embeddings / sum_mask + + def _mean_sqrt_len_pooling(self, token_embeddings: torch.tensor, attention_mask: torch.tensor) -> torch.tensor: + """ + Perform mean-pooling on the output of the embedding model, but divide by sqrt(input_length). + + :param token_embeddings: The token embeddings output by the embedding model. + :param attention_mask: The attention mask of the tokenized text. + :return: The embeddings of the text after mean-sqrt-length pooling. + """ + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) + sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) + return sum_embeddings / torch.sqrt(sum_mask) + + def _weighted_mean_pooling(self, token_embeddings: torch.tensor, attention_mask: torch.tensor) -> torch.tensor: + """ + Perform Weighted (position) Mean Pooling on the output of the embedding model. + See https://arxiv.org/abs/2202.08904. + + :param token_embeddings: The token embeddings output by the embedding model. + :param attention_mask: The attention mask of the tokenized text. + :return: The embeddings of the text after weighted mean pooling. + """ + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + # token_embeddings shape: bs, seq, hidden_dim + weights = ( + torch.arange(start=1, end=token_embeddings.shape[1] + 1) + .unsqueeze(0) + .unsqueeze(-1) + .expand(token_embeddings.size()) + .float() + .to(token_embeddings.device) + ) + input_mask_expanded = input_mask_expanded * weights + sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) + sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) + return sum_embeddings / sum_mask + + def _last_token_pooling(self, token_embeddings: torch.tensor, attention_mask: torch.tensor) -> torch.tensor: + """ + Perform Last Token Pooling on the output of the embedding model. See https://arxiv.org/abs/2202.08904 & + https://arxiv.org/abs/2201.10005. + + :param token_embeddings: The token embeddings output by the embedding model. + :param attention_mask: The attention mask of the tokenized text. + :return: The embeddings of the text after last token pooling. + """ + bs, seq_len, hidden_dim = token_embeddings.shape + # attention_mask shape: (bs, seq_len) + # Get shape [bs] indices of the last token (i.e.
the last token for each batch item) + # argmin gives us the index of the first 0 in the attention mask; We get the last 1 index by subtracting 1 + # Any sequence where min == 1, we use the entire sequence length since argmin = 0 + values, indices = torch.min(attention_mask, 1, keepdim=False) + gather_indices = torch.where(values == 0, indices, seq_len) - 1 # Shape [bs] + # There are empty sequences, where the index would become -1 which will crash + gather_indices = torch.clamp(gather_indices, min=0) + + # Turn indices from shape [bs] --> [bs, 1, hidden_dim] + gather_indices = gather_indices.unsqueeze(-1).repeat(1, hidden_dim) + gather_indices = gather_indices.unsqueeze(1) + + # Gather along the 1st dim (seq_len) (bs, seq_len, hidden_dim -> bs, hidden_dim) + # Actually no need for the attention mask as we gather the last token where attn_mask = 1 + # but as we set some indices (which shouldn't be attended to) to 0 with clamp, we + # use the attention mask to ignore them again + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + embeddings = torch.gather(token_embeddings * input_mask_expanded, 1, gather_indices).squeeze(dim=1) + return embeddings + + def pool_embeddings(self) -> torch.tensor: + """ + Perform pooling on the output of the embedding model. + + :param pooling_mode: The pooling mode to use. + :param attention_mask: The attention mask of the tokenized text. + :param model_output: The output of the embedding model. + :return: The embeddings of the text after pooling. + """ + pooling_func_map = { + PoolingMode.CLS: self._cls_pooling, + PoolingMode.MEAN: self._mean_pooling, + PoolingMode.MAX: self._max_pooling, + PoolingMode.MEAN_SQRT_LEN: self._mean_sqrt_len_pooling, + PoolingMode.WEIGHTED_MEAN: self._weighted_mean_pooling, + PoolingMode.LAST_TOKEN: self._last_token_pooling, + } + self._pooling_function = pooling_func_map[self.pooling_mode] + + # First element of model_output contains all token embeddings + token_embeddings = self.model_output[0] + + embeddings = ( + self._pooling_function(token_embeddings, self.attention_mask) # type: ignore + if self._pooling_function != self._cls_pooling + else self._pooling_function(token_embeddings) # type: ignore + ) + + return embeddings diff --git a/integrations/optimum/tests/test_optimum_document_embedder.py b/integrations/optimum/tests/test_optimum_document_embedder.py index b7d285196..612a9ab0e 100644 --- a/integrations/optimum/tests/test_optimum_document_embedder.py +++ b/integrations/optimum/tests/test_optimum_document_embedder.py @@ -4,6 +4,7 @@ from haystack.dataclasses import Document from haystack.utils.auth import Secret from haystack_integrations.components.embedders import OptimumDocumentEmbedder +from haystack_integrations.components.embedders.pooling import PoolingMode from huggingface_hub.utils import RepositoryNotFoundError @@ -16,8 +17,17 @@ def mock_check_valid_model(): yield mock +@pytest.fixture +def mock_get_pooling_mode(): + with patch( + "haystack_integrations.components.embedders.optimum_text_embedder.HFPoolingMode.get_pooling_mode", + MagicMock(return_value=PoolingMode.MEAN), + ) as mock: + yield mock + + class TestOptimumDocumentEmbedder: - def test_init_default(self, monkeypatch, mock_check_valid_model): # noqa: ARG002 + def test_init_default(self, monkeypatch, mock_check_valid_model, mock_get_pooling_mode): # noqa: ARG002 monkeypatch.setenv("HF_API_TOKEN", "fake-api-token") embedder = OptimumDocumentEmbedder() @@ -27,6 +37,7 @@ def test_init_default(self, monkeypatch, 
mock_check_valid_model): # noqa: ARG00 assert embedder.suffix == "" assert embedder.normalize_embeddings is True assert embedder.onnx_execution_provider == "CPUExecutionProvider" + assert embedder.pooling_mode == PoolingMode.MEAN assert embedder.batch_size == 32 assert embedder.progress_bar is True assert embedder.meta_fields_to_embed == [] @@ -48,6 +59,7 @@ def test_init_with_parameters(self, mock_check_valid_model): # noqa: ARG002 meta_fields_to_embed=["test_field"], embedding_separator=" | ", normalize_embeddings=False, + pooling_mode="max", onnx_execution_provider="CUDAExecutionProvider", model_kwargs={"trust_remote_code": True}, ) @@ -62,6 +74,7 @@ def test_init_with_parameters(self, mock_check_valid_model): # noqa: ARG002 assert embedder.embedding_separator == " | " assert embedder.normalize_embeddings is False assert embedder.onnx_execution_provider == "CUDAExecutionProvider" + assert embedder.pooling_mode == PoolingMode.MAX assert embedder.model_kwargs == { "trust_remote_code": True, "model_id": "sentence-transformers/all-minilm-l6-v2", @@ -69,12 +82,7 @@ def test_init_with_parameters(self, mock_check_valid_model): # noqa: ARG002 "use_auth_token": "fake-api-token", } - def test_initialize_with_invalid_model(self, mock_check_valid_model): - mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") - with pytest.raises(RepositoryNotFoundError): - OptimumDocumentEmbedder(model="invalid_model_id") - - def test_to_dict(self, mock_check_valid_model): # noqa: ARG002 + def test_to_dict(self, mock_check_valid_model, mock_get_pooling_mode): # noqa: ARG002 component = OptimumDocumentEmbedder() data = component.to_dict() @@ -91,6 +99,7 @@ def test_to_dict(self, mock_check_valid_model): # noqa: ARG002 "embedding_separator": "\n", "normalize_embeddings": True, "onnx_execution_provider": "CPUExecutionProvider", + "pooling_mode": "mean", "model_kwargs": { "model_id": "sentence-transformers/all-mpnet-base-v2", "provider": "CPUExecutionProvider", @@ -99,7 +108,7 @@ def test_to_dict(self, mock_check_valid_model): # noqa: ARG002 }, } - def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model): # noqa: ARG002 + def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model, mock_get_pooling_mode): # noqa: ARG002 component = OptimumDocumentEmbedder( model="sentence-transformers/all-minilm-l6-v2", token=Secret.from_env_var("ENV_VAR", strict=False), @@ -111,6 +120,7 @@ def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model): # n embedding_separator=" | ", normalize_embeddings=False, onnx_execution_provider="CUDAExecutionProvider", + pooling_mode="max", model_kwargs={"trust_remote_code": True}, ) data = component.to_dict() @@ -128,6 +138,7 @@ def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model): # n "embedding_separator": " | ", "normalize_embeddings": False, "onnx_execution_provider": "CUDAExecutionProvider", + "pooling_mode": "max", "model_kwargs": { "trust_remote_code": True, "model_id": "sentence-transformers/all-minilm-l6-v2", @@ -137,6 +148,52 @@ def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model): # n }, } + def test_initialize_with_invalid_model(self, mock_check_valid_model): + mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") + with pytest.raises(RepositoryNotFoundError): + OptimumDocumentEmbedder(model="invalid_model_id") + + def test_initialize_with_invalid_pooling_mode(self, mock_check_valid_model): # noqa: ARG002 + mock_get_pooling_mode.side_effect = 
ValueError("Invalid pooling mode") + with pytest.raises(ValueError): + OptimumDocumentEmbedder( + model="sentence-transformers/all-mpnet-base-v2", pooling_mode="Invalid_pooling_mode" + ) + + def test_infer_pooling_mode_from_str(self): + """ + Test that the pooling mode is correctly inferred from a string. + The pooling mode is "mean" as per the model config. + """ + for pooling_mode in PoolingMode: + embedder = OptimumDocumentEmbedder( + model="sentence-transformers/all-minilm-l6-v2", + pooling_mode=pooling_mode.value, + ) + + assert embedder.model == "sentence-transformers/all-minilm-l6-v2" + assert embedder.pooling_mode == pooling_mode + + @pytest.mark.integration + def test_default_pooling_mode_when_config_not_found(self, mock_check_valid_model): # noqa: ARG002 + embedder = OptimumDocumentEmbedder( + model="embedding_model_finetuned", + pooling_mode=None, + ) + + assert embedder.model == "embedding_model_finetuned" + assert embedder.pooling_mode == PoolingMode.MEAN + + @pytest.mark.integration + def test_infer_pooling_mode_from_hf(self): + embedder = OptimumDocumentEmbedder( + model="sentence-transformers/all-minilm-l6-v2", + pooling_mode=None, + ) + + assert embedder.model == "sentence-transformers/all-minilm-l6-v2" + assert embedder.pooling_mode == PoolingMode.MEAN + def test_prepare_texts_to_embed_w_metadata(self, mock_check_valid_model): # noqa: ARG002 documents = [ Document(content=f"document number {i}: content", meta={"meta_field": f"meta_value {i}"}) for i in range(5) @@ -146,6 +203,7 @@ def test_prepare_texts_to_embed_w_metadata(self, mock_check_valid_model): # noq model="sentence-transformers/all-minilm-l6-v2", meta_fields_to_embed=["meta_field"], embedding_separator=" | ", + pooling_mode="mean", ) prepared_texts = embedder._prepare_texts_to_embed(documents) @@ -165,6 +223,7 @@ def test_prepare_texts_to_embed_w_suffix(self, mock_check_valid_model): # noqa: model="sentence-transformers/all-minilm-l6-v2", prefix="my_prefix ", suffix=" my_suffix", + pooling_mode="mean", ) prepared_texts = embedder._prepare_texts_to_embed(documents) @@ -178,9 +237,7 @@ def test_prepare_texts_to_embed_w_suffix(self, mock_check_valid_model): # noqa: ] def test_run_wrong_input_format(self, mock_check_valid_model): # noqa: ARG002 - embedder = OptimumDocumentEmbedder( - model="sentence-transformers/all-mpnet-base-v2", - ) + embedder = OptimumDocumentEmbedder(model="sentence-transformers/all-mpnet-base-v2", pooling_mode="mean") embedder.warm_up() # wrong formats string_input = "text" diff --git a/integrations/optimum/tests/test_optimum_text_embedder.py b/integrations/optimum/tests/test_optimum_text_embedder.py index 2f5b60b45..52124944c 100644 --- a/integrations/optimum/tests/test_optimum_text_embedder.py +++ b/integrations/optimum/tests/test_optimum_text_embedder.py @@ -3,6 +3,7 @@ import pytest from haystack.utils.auth import Secret from haystack_integrations.components.embedders import OptimumTextEmbedder +from haystack_integrations.components.embedders.pooling import PoolingMode from huggingface_hub.utils import RepositoryNotFoundError @@ -15,8 +16,17 @@ def mock_check_valid_model(): yield mock +@pytest.fixture +def mock_get_pooling_mode(): + with patch( + "haystack_integrations.components.embedders.optimum_text_embedder.HFPoolingMode.get_pooling_mode", + MagicMock(return_value=PoolingMode.MEAN), + ) as mock: + yield mock + + class TestOptimumTextEmbedder: - def test_init_default(self, monkeypatch, mock_check_valid_model): # noqa: ARG002 + def test_init_default(self, monkeypatch, 
mock_check_valid_model, mock_get_pooling_mode): # noqa: ARG002 monkeypatch.setenv("HF_API_TOKEN", "fake-api-token") embedder = OptimumTextEmbedder() @@ -26,6 +36,7 @@ def test_init_default(self, monkeypatch, mock_check_valid_model): # noqa: ARG00 assert embedder.suffix == "" assert embedder.normalize_embeddings is True assert embedder.onnx_execution_provider == "CPUExecutionProvider" + assert embedder.pooling_mode == PoolingMode.MEAN assert embedder.model_kwargs == { "model_id": "sentence-transformers/all-mpnet-base-v2", "provider": "CPUExecutionProvider", @@ -39,6 +50,7 @@ def test_init_with_parameters(self, mock_check_valid_model): # noqa: ARG002 prefix="prefix", suffix="suffix", normalize_embeddings=False, + pooling_mode="max", onnx_execution_provider="CUDAExecutionProvider", model_kwargs={"trust_remote_code": True}, ) @@ -49,6 +61,7 @@ def test_init_with_parameters(self, mock_check_valid_model): # noqa: ARG002 assert embedder.suffix == "suffix" assert embedder.normalize_embeddings is False assert embedder.onnx_execution_provider == "CUDAExecutionProvider" + assert embedder.pooling_mode == PoolingMode.MAX assert embedder.model_kwargs == { "trust_remote_code": True, "model_id": "sentence-transformers/all-minilm-l6-v2", @@ -56,12 +69,7 @@ def test_init_with_parameters(self, mock_check_valid_model): # noqa: ARG002 "use_auth_token": "fake-api-token", } - def test_initialize_with_invalid_model(self, mock_check_valid_model): - mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") - with pytest.raises(RepositoryNotFoundError): - OptimumTextEmbedder(model="invalid_model_id") - - def test_to_dict(self, mock_check_valid_model): # noqa: ARG002 + def test_to_dict(self, mock_check_valid_model, mock_get_pooling_mode): # noqa: ARG002 component = OptimumTextEmbedder() data = component.to_dict() @@ -74,6 +82,7 @@ def test_to_dict(self, mock_check_valid_model): # noqa: ARG002 "suffix": "", "normalize_embeddings": True, "onnx_execution_provider": "CPUExecutionProvider", + "pooling_mode": "mean", "model_kwargs": { "model_id": "sentence-transformers/all-mpnet-base-v2", "provider": "CPUExecutionProvider", @@ -90,6 +99,7 @@ def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model): # n suffix="suffix", normalize_embeddings=False, onnx_execution_provider="CUDAExecutionProvider", + pooling_mode="max", model_kwargs={"trust_remote_code": True}, ) data = component.to_dict() @@ -103,6 +113,7 @@ def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model): # n "suffix": "suffix", "normalize_embeddings": False, "onnx_execution_provider": "CUDAExecutionProvider", + "pooling_mode": "max", "model_kwargs": { "trust_remote_code": True, "model_id": "sentence-transformers/all-minilm-l6-v2", @@ -112,10 +123,55 @@ def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model): # n }, } + def test_initialize_with_invalid_model(self, mock_check_valid_model): + mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") + with pytest.raises(RepositoryNotFoundError): + OptimumTextEmbedder(model="invalid_model_id", pooling_mode="max") + + def test_initialize_with_invalid_pooling_mode(self, mock_check_valid_model): # noqa: ARG002 + mock_get_pooling_mode.side_effect = ValueError("Invalid pooling mode") + with pytest.raises(ValueError): + OptimumTextEmbedder(model="sentence-transformers/all-mpnet-base-v2", pooling_mode="Invalid_pooling_mode") + + def test_infer_pooling_mode_from_str(self): + """ + Test that the pooling mode is correctly 
inferred from a string. + The pooling mode is "mean" as per the model config. + """ + for pooling_mode in PoolingMode: + embedder = OptimumTextEmbedder( + model="sentence-transformers/all-minilm-l6-v2", + pooling_mode=pooling_mode.value, + ) + + assert embedder.model == "sentence-transformers/all-minilm-l6-v2" + assert embedder.pooling_mode == pooling_mode + + @pytest.mark.integration + def test_default_pooling_mode_when_config_not_found(self, mock_check_valid_model): # noqa: ARG002 + embedder = OptimumTextEmbedder( + model="embedding_model_finetuned", + pooling_mode=None, + ) + + assert embedder.model == "embedding_model_finetuned" + assert embedder.pooling_mode == PoolingMode.MEAN + + @pytest.mark.integration + def test_infer_pooling_mode_from_hf(self): + embedder = OptimumTextEmbedder( + model="sentence-transformers/all-minilm-l6-v2", + pooling_mode=None, + ) + + assert embedder.model == "sentence-transformers/all-minilm-l6-v2" + assert embedder.pooling_mode == PoolingMode.MEAN + def test_run_wrong_input_format(self, mock_check_valid_model): # noqa: ARG002 embedder = OptimumTextEmbedder( model="sentence-transformers/all-mpnet-base-v2", token=Secret.from_token("fake-api-token"), + pooling_mode="mean", ) embedder.warm_up()
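
For reviewers, here is a minimal usage sketch of the new `pooling_mode` parameter introduced by this diff. It is not part of the change itself; the model name and the chosen pooling mode are illustrative, while `warm_up()` and `run()` are used exactly as in the tests above:

```python
from haystack_integrations.components.embedders import OptimumTextEmbedder

# Explicit pooling mode; passing pooling_mode=None instead would infer the mode from the
# model's 1_Pooling/config.json on the Hub and fall back to "mean" if it cannot be determined.
embedder = OptimumTextEmbedder(
    model="sentence-transformers/all-minilm-l6-v2",
    pooling_mode="cls",  # a string or a PoolingMode enum member
)
embedder.warm_up()
result = embedder.run(text="The food was delicious")
print(len(result["embedding"]))  # dimensionality of the pooled embedding
```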
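A small self-contained check of the new `Pooling` helper may also be useful, assuming the `pooling.py` module added by this diff is importable. It shows that `PoolingMode.MEAN` reproduces the masked average that the removed `mean_pooling` method of the backend used to compute; the toy tensors are illustrative only:

```python
import torch

from haystack_integrations.components.embedders.pooling import Pooling, PoolingMode

token_embeddings = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]])  # (batch=1, seq=3, dim=2)
attention_mask = torch.tensor([[1, 1, 0]])  # third token is padding

pooling = Pooling(
    pooling_mode=PoolingMode.MEAN,
    attention_mask=attention_mask,
    model_output=(token_embeddings,),  # pool_embeddings() reads the token embeddings from model_output[0]
)
print(pooling.pool_embeddings())  # tensor([[2., 3.]]) -> mean of the two unmasked tokens
```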