diff --git a/.github/workflows/optimum.yml b/.github/workflows/optimum.yml index 3b0d137da..f5f59ec89 100644 --- a/.github/workflows/optimum.yml +++ b/.github/workflows/optimum.yml @@ -52,9 +52,9 @@ jobs: if: matrix.python-version == '3.9' && runner.os == 'Linux' run: hatch run lint:all - - name: Generate docs - if: matrix.python-version == '3.9' && runner.os == 'Linux' - run: hatch run docs + # - name: Generate docs + # if: matrix.python-version == '3.9' && runner.os == 'Linux' + # run: hatch run docs - name: Run tests run: hatch run cov diff --git a/integrations/optimum/pydoc/config.yml b/integrations/optimum/pydoc/config.yml index 617eb4aed..996678c55 100644 --- a/integrations/optimum/pydoc/config.yml +++ b/integrations/optimum/pydoc/config.yml @@ -6,6 +6,8 @@ loaders: "haystack_integrations.components.embedders.optimum.optimum_document_embedder", "haystack_integrations.components.embedders.optimum.optimum_text_embedder", "haystack_integrations.components.embedders.optimum.pooling", + "haystack_integrations.components.embedders.optimum.optimization", + "haystack_integrations.components.embedders.optimum.quantization", ] ignore_when_discovered: ["__init__"] processors: diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py index e2ab2d6b7..02e56b34c 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py @@ -2,8 +2,18 @@ # # SPDX-License-Identifier: Apache-2.0 +from .optimization import OptimumEmbedderOptimizationConfig, OptimumEmbedderOptimizationMode from .optimum_document_embedder import OptimumDocumentEmbedder from .optimum_text_embedder import OptimumTextEmbedder from .pooling import OptimumEmbedderPooling +from .quantization import OptimumEmbedderQuantizationConfig, OptimumEmbedderQuantizationMode -__all__ = ["OptimumDocumentEmbedder", "OptimumEmbedderPooling", "OptimumTextEmbedder"] +__all__ = [ + "OptimumDocumentEmbedder", + "OptimumEmbedderOptimizationMode", + "OptimumEmbedderOptimizationConfig", + "OptimumEmbedderPooling", + "OptimumEmbedderQuantizationMode", + "OptimumEmbedderQuantizationConfig", + "OptimumTextEmbedder", +] diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/_backend.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/_backend.py index fc4f0b1ae..a6d226ecc 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/_backend.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/_backend.py @@ -1,7 +1,8 @@ import copy import json from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -11,10 +12,17 @@ from sentence_transformers.models import Pooling as SentenceTransformerPoolingLayer from tqdm import tqdm from transformers import AutoTokenizer +from transformers.modeling_outputs import BaseModelOutput -from optimum.onnxruntime import ORTModelForFeatureExtraction +from optimum.onnxruntime import ( + ORTModelForFeatureExtraction, + ORTOptimizer, + ORTQuantizer, +) +from .optimization import OptimumEmbedderOptimizationConfig from .pooling import OptimumEmbedderPooling +from .quantization import 
OptimumEmbedderQuantizationConfig @dataclass @@ -29,16 +37,29 @@ class _EmbedderParams: progress_bar: bool pooling_mode: Optional[Union[str, OptimumEmbedderPooling]] model_kwargs: Optional[Dict[str, Any]] + working_dir: Optional[str] + optimizer_settings: Optional[OptimumEmbedderOptimizationConfig] + quantizer_settings: Optional[OptimumEmbedderQuantizationConfig] def serialize(self) -> Dict[str, Any]: out = {} for field in self.__dataclass_fields__.keys(): + if field in [ + "pooling_mode", + "token", + "optimizer_settings", + "quantizer_settings", + ]: + continue out[field] = copy.deepcopy(getattr(self, field)) # Fixups. assert isinstance(self.pooling_mode, OptimumEmbedderPooling) - out["pooling_mode"] = self.pooling_mode.value + out["pooling_mode"] = str(self.pooling_mode) out["token"] = self.token.to_dict() if self.token else None + out["optimizer_settings"] = self.optimizer_settings.to_dict() if self.optimizer_settings else None + out["quantizer_settings"] = self.quantizer_settings.to_dict() if self.quantizer_settings else None + out["model_kwargs"].pop("use_auth_token", None) serialize_hf_model_kwargs(out["model_kwargs"]) return out @@ -46,6 +67,11 @@ def serialize(self) -> Dict[str, Any]: @classmethod def deserialize_inplace(cls, data: Dict[str, Any]) -> Dict[str, Any]: data["pooling_mode"] = OptimumEmbedderPooling.from_str(data["pooling_mode"]) + if data["optimizer_settings"] is not None: + data["optimizer_settings"] = OptimumEmbedderOptimizationConfig.from_dict(data["optimizer_settings"]) + if data["quantizer_settings"] is not None: + data["quantizer_settings"] = OptimumEmbedderQuantizationConfig.from_dict(data["quantizer_settings"]) + deserialize_secrets_inplace(data, keys=["token"]) deserialize_hf_model_kwargs(data["model_kwargs"]) return data @@ -71,6 +97,11 @@ def __init__(self, params: _EmbedderParams): params.model_kwargs = params.model_kwargs or {} + if params.optimizer_settings or params.quantizer_settings: + if not params.working_dir: + msg = "Working directory is required for optimization and quantization" + raise ValueError(msg) + # Check if the model_kwargs contain the parameters, otherwise, populate them with values from init parameters params.model_kwargs.setdefault("model_id", params.model) params.model_kwargs.setdefault("provider", params.onnx_execution_provider) @@ -82,18 +113,48 @@ def __init__(self, params: _EmbedderParams): self.pooling_layer = None def warm_up(self): - self.model = ORTModelForFeatureExtraction.from_pretrained(**self.params.model_kwargs, export=True) + assert self.params.model_kwargs + model_kwargs = copy.deepcopy(self.params.model_kwargs) + model = ORTModelForFeatureExtraction.from_pretrained(**model_kwargs, export=True) + + # Model ID will be passed explicitly if optimization/quantization is enabled. + model_kwargs.pop("model_id", None) + + optimized_model = False + if self.params.optimizer_settings: + assert self.params.working_dir + optimizer = ORTOptimizer.from_pretrained(model) + save_dir = optimizer.optimize( + save_dir=self.params.working_dir, optimization_config=self.params.optimizer_settings.to_optimum_config() + ) + model = ORTModelForFeatureExtraction.from_pretrained(model_id=save_dir, **model_kwargs) + optimized_model = True + + if self.params.quantizer_settings: + assert self.params.working_dir + + # We need to create a subfolder for models that were optimized before quantization + # since Optimum expects no more than one ONNX model in the working directory. There's + # a file name parameter, but the optimizer only returns the working directory. + working_dir = ( + Path(self.params.working_dir) if not optimized_model else Path(self.params.working_dir) / "quantized" + ) + quantizer = ORTQuantizer.from_pretrained(model) + save_dir = quantizer.quantize( + save_dir=working_dir, quantization_config=self.params.quantizer_settings.to_optimum_config() + ) + model = ORTModelForFeatureExtraction.from_pretrained(model_id=save_dir, **model_kwargs) + + self.model = model self.tokenizer = AutoTokenizer.from_pretrained( self.params.model, token=self.params.token.resolve_value() if self.params.token else None ) # We need the width of the embeddings to initialize the pooling layer # so we do a dummy forward pass with the model. - dummy_input = self.tokenizer(["dummy input"], padding=True, truncation=True, return_tensors="pt").to( - self.model.device - ) - dummy_output = self.model(input_ids=dummy_input["input_ids"], attention_mask=dummy_input["attention_mask"]) - width = dummy_output[0].size(dim=2) # BaseModelOutput.last_hidden_state + width = self._tokenize_and_generate_outputs(["dummy input"])[1][0].size( + dim=2 + ) # BaseModelOutput.last_hidden_state self.pooling_layer = SentenceTransformerPoolingLayer( width, @@ -105,6 +166,17 @@ def warm_up(self): pooling_mode_lasttoken=self.params.pooling_mode == OptimumEmbedderPooling.LAST_TOKEN, ) + def _tokenize_and_generate_outputs(self, texts: List[str]) -> Tuple[Dict[str, Any], BaseModelOutput]: + assert self.model is not None + assert self.tokenizer is not None + + tokenizer_outputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to( + self.model.device + ) + model_inputs = {k: v for k, v in tokenizer_outputs.items() if k in self.model.inputs_names} + model_outputs = self.model(**model_inputs) + return tokenizer_outputs, model_outputs + @property def parameters(self) -> _EmbedderParams: return self.params @@ -140,11 +212,8 @@ def embed_texts( desc="Calculating embeddings", ): batch = sentences_sorted[i : i + self.params.batch_size] - encoded_input = self.tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(device) - model_output = self.model( - input_ids=encoded_input["input_ids"], attention_mask=encoded_input["attention_mask"] - ) - sentence_embeddings = self.pool_embeddings(model_output[0], encoded_input["attention_mask"].to(device)) + tokenizer_output, model_output = self._tokenize_and_generate_outputs(batch) + sentence_embeddings = self.pool_embeddings(model_output[0], tokenizer_output["attention_mask"].to(device)) all_embeddings.append(sentence_embeddings) embeddings = torch.cat(all_embeddings, dim=0)
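Reviewer note: for anyone not familiar with the Optimum APIs chained in `warm_up()` above, the export → optimize → quantize flow maps onto plain Optimum calls roughly as follows. This is a minimal standalone sketch, not the component's code; the model id and directory names are placeholders.

```python
from optimum.onnxruntime import ORTModelForFeatureExtraction, ORTOptimizer, ORTQuantizer
from optimum.onnxruntime.configuration import AutoOptimizationConfig, AutoQuantizationConfig

# Export the PyTorch checkpoint to ONNX.
model = ORTModelForFeatureExtraction.from_pretrained(
    "sentence-transformers/all-mpnet-base-v2", export=True
)

# Graph optimization: the optimizer writes the optimized model into save_dir.
optimizer = ORTOptimizer.from_pretrained(model)
save_dir = optimizer.optimize(
    save_dir="working_dir", optimization_config=AutoOptimizationConfig.O1(for_gpu=False)
)
model = ORTModelForFeatureExtraction.from_pretrained(save_dir)

# Dynamic quantization: a subfolder keeps a single ONNX model per directory,
# which is why warm_up() nests "quantized" under the working directory.
quantizer = ORTQuantizer.from_pretrained(model)
save_dir = quantizer.quantize(
    save_dir="working_dir/quantized",
    quantization_config=AutoQuantizationConfig.avx2(is_static=False),
)
model = ORTModelForFeatureExtraction.from_pretrained(save_dir)
```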
diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimization.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimization.py new file mode 100644 index 000000000..5a4447570 --- /dev/null +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimization.py @@ -0,0 +1,105 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict + +from optimum.onnxruntime.configuration import AutoOptimizationConfig, OptimizationConfig + + +class OptimumEmbedderOptimizationMode(Enum): + """ + [ONNX Optimization Modes](https://huggingface.co/docs/optimum/onnxruntime/usage_guides/optimization.html) + supported by the Optimum Embedders. + """ + + #: Basic general optimizations. + O1 = "o1" + + #: Basic and extended general optimizations, transformers-specific fusions. + O2 = "o2" + + #: Same as O2 with Gelu approximation. + O3 = "o3" + + #: Same as O3 with mixed precision. + O4 = "o4" + + def __str__(self): + return self.value + + @classmethod + def from_str(cls, string: str) -> "OptimumEmbedderOptimizationMode": + """ + Create an optimization mode from a string. + + :param string: + String to convert. + :returns: + Optimization mode. + """ + enum_map = {e.value: e for e in OptimumEmbedderOptimizationMode} + opt_mode = enum_map.get(string) + if opt_mode is None: + msg = f"Unknown optimization mode '{string}'. Supported modes are: {list(enum_map.keys())}" + raise ValueError(msg) + return opt_mode + + +@dataclass(frozen=True) +class OptimumEmbedderOptimizationConfig: + """ + Configuration for Optimum Embedder Optimization. + + :param mode: + Optimization mode. + :param for_gpu: + Whether to optimize for GPUs. + """ + + mode: OptimumEmbedderOptimizationMode + for_gpu: bool = True + + def to_optimum_config(self) -> OptimizationConfig: + """ + Convert the configuration to an Optimum configuration. + + :returns: + Optimum configuration. + """ + if self.mode == OptimumEmbedderOptimizationMode.O1: + return AutoOptimizationConfig.O1(for_gpu=self.for_gpu) + elif self.mode == OptimumEmbedderOptimizationMode.O2: + return AutoOptimizationConfig.O2(for_gpu=self.for_gpu) + elif self.mode == OptimumEmbedderOptimizationMode.O3: + return AutoOptimizationConfig.O3(for_gpu=self.for_gpu) + elif self.mode == OptimumEmbedderOptimizationMode.O4: + return AutoOptimizationConfig.O4(for_gpu=self.for_gpu) + else: + msg = f"Unknown optimization mode '{self.mode}'" + raise ValueError(msg) + + def to_dict(self) -> Dict[str, Any]: + """ + Convert the configuration to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return { + "mode": str(self.mode), + "for_gpu": self.for_gpu, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "OptimumEmbedderOptimizationConfig": + """ + Create an optimization configuration from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Optimization configuration. + """ + return OptimumEmbedderOptimizationConfig( + mode=OptimumEmbedderOptimizationMode.from_str(data["mode"]), + for_gpu=data["for_gpu"], + )
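As a quick orientation for reviewers, this is how the new optimization config is meant to be used, based solely on the code above (values are hypothetical):

```python
from haystack_integrations.components.embedders.optimum import (
    OptimumEmbedderOptimizationConfig,
    OptimumEmbedderOptimizationMode,
)

config = OptimumEmbedderOptimizationConfig(mode=OptimumEmbedderOptimizationMode.O2, for_gpu=False)

# Bridges to Optimum's AutoOptimizationConfig.O2(for_gpu=False).
optimum_config = config.to_optimum_config()

# The dict round-trip is what backs the embedders' to_dict()/from_dict().
assert config.to_dict() == {"mode": "o2", "for_gpu": False}
assert OptimumEmbedderOptimizationConfig.from_dict(config.to_dict()) == config
```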
diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py index 2f49bd0b3..a6db47090 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py @@ -4,14 +4,17 @@ from haystack.utils import Secret from ._backend import _EmbedderBackend, _EmbedderParams +from .optimization import OptimumEmbedderOptimizationConfig from .pooling import OptimumEmbedderPooling +from .quantization import OptimumEmbedderQuantizationConfig @component class OptimumDocumentEmbedder: """ - A component for computing Document embeddings using models loaded with the HuggingFace Optimum library. - This component is designed to seamlessly inference models using the high speed ONNX runtime. + A component for computing `Document` embeddings using models loaded with the + [HuggingFace Optimum](https://huggingface.co/docs/optimum/index) library, + leveraging the ONNX runtime for high-speed inference.
The embedding of each Document is stored in the `embedding` field of the Document. @@ -30,18 +33,6 @@ class OptimumDocumentEmbedder: # [0.017020374536514282, -0.023255806416273117, ...] ``` - - Key Features and Compatibility: - - **Primary Compatibility**: Designed to work seamlessly with any embedding model present on the Hugging Face - Hub. - - **Conversion to ONNX**: The models are converted to ONNX using the HuggingFace Optimum library. This is - performed in real-time, during the warm-up step. - - **Accelerated Inference on GPU**: Supports using different execution providers such as CUDA and TensorRT, to - accelerate ONNX Runtime inference on GPUs. - Simply pass the execution provider as the onnx_execution_provider parameter. Additonal parameters can be passed - to the model using the model_kwargs parameter. - For more details refer to the HuggingFace documentation: - https://huggingface.co/docs/optimum/onnxruntime/usage_guides/gpu. """ def __init__( @@ -54,6 +45,9 @@ def __init__( onnx_execution_provider: str = "CPUExecutionProvider", pooling_mode: Optional[Union[str, OptimumEmbedderPooling]] = None, model_kwargs: Optional[Dict[str, Any]] = None, + working_dir: Optional[str] = None, + optimizer_settings: Optional[OptimumEmbedderOptimizationConfig] = None, + quantizer_settings: Optional[OptimumEmbedderQuantizationConfig] = None, batch_size: int = 32, progress_bar: bool = True, meta_fields_to_embed: Optional[List[str]] = None, @@ -62,13 +56,19 @@ def __init__( """ Create a OptimumDocumentEmbedder component. - :param model: A string representing the model id on HF Hub. - :param token: The HuggingFace token to use as HTTP bearer authorization. - :param prefix: A string to add to the beginning of each text. - :param suffix: A string to add to the end of each text. - :param normalize_embeddings: Whether to normalize the embeddings to unit length. - :param onnx_execution_provider: The execution provider to use for ONNX models. See - https://onnxruntime.ai/docs/execution-providers/ for possible providers. + :param model: + A string representing the model id on HF Hub. + :param token: + The HuggingFace token to use as HTTP bearer authorization. + :param prefix: + A string to add to the beginning of each text. + :param suffix: + A string to add to the end of each text. + :param normalize_embeddings: + Whether to normalize the embeddings to unit length. + :param onnx_execution_provider: + The [execution provider](https://onnxruntime.ai/docs/execution-providers/) + to use for ONNX models. Note: Using the TensorRT execution provider TensorRT requires to build its inference engine ahead of inference, which takes some time due to the model @@ -88,16 +88,31 @@ def __init__( }, ) ``` - :param pooling_mode: The pooling mode to use. When None, pooling mode will be inferred from the model config. - Refer to the OptimumEmbedderPooling enum for supported pooling modes. - :param model_kwargs: Dictionary containing additional keyword arguments to pass to the model. - In case of duplication, these kwargs override `model`, `onnx_execution_provider`, and `token` initialization - parameters. - :param batch_size: Number of Documents to encode at once. - :param progress_bar: Whether to show a progress bar or not. Can be helpful to disable in production deployments - to keep the logs clean. - :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text. - :param embedding_separator: Separator used to concatenate the meta fields to the Document text. 
+ :param pooling_mode: + The pooling mode to use. When `None`, pooling mode will be inferred from the model config. + :param model_kwargs: + Dictionary containing additional keyword arguments to pass to the model. + In case of duplication, these kwargs override `model`, `onnx_execution_provider` + and `token` initialization parameters. + :param working_dir: + The directory to use for storing intermediate files + generated during model optimization/quantization. + + Required for optimization and quantization. + :param optimizer_settings: + Configuration for Optimum Embedder Optimization. + If `None`, no additional optimization is applied. + :param quantizer_settings: + Configuration for Optimum Embedder Quantization. + If `None`, no quantization is applied. + :param batch_size: + Number of Documents to encode at once. + :param progress_bar: + Whether to show a progress bar or not. + :param meta_fields_to_embed: + List of meta fields that should be embedded along with the Document text. + :param embedding_separator: + Separator used to concatenate the meta fields to the Document text. """ params = _EmbedderParams( model=model, token=token, prefix=prefix, suffix=suffix, normalize_embeddings=normalize_embeddings, onnx_execution_provider=onnx_execution_provider, batch_size=batch_size, progress_bar=progress_bar, pooling_mode=pooling_mode, model_kwargs=model_kwargs, + working_dir=working_dir, + optimizer_settings=optimizer_settings, + quantizer_settings=quantizer_settings, ) self.meta_fields_to_embed = meta_fields_to_embed or [] self.embedding_separator = embedding_separator @@ -119,7 +137,7 @@ def __init__( def warm_up(self): """ - Load the embedding backend. + Initializes the component. """ if self._initialized: return @@ -129,7 +147,10 @@ def warm_up(self): def to_dict(self) -> Dict[str, Any]: """ - Serialize this component to a dictionary. + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. """ init_params = self._backend.parameters.serialize() init_params["meta_fields_to_embed"] = self.meta_fields_to_embed @@ -139,7 +160,12 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "OptimumDocumentEmbedder": """ - Deserialize this component from a dictionary. + Deserializes the component from a dictionary. + + :param data: + The dictionary to deserialize from. + :returns: + The deserialized component. """ _EmbedderParams.deserialize_inplace(data["init_parameters"]) return default_from_dict(cls, data) @@ -169,8 +195,10 @@ def run(self, documents: List[Document]): Embed a list of Documents. The embedding of each Document is stored in the `embedding` field of the Document. - :param documents: A list of Documents to embed. - :return: A dictionary containing the updated Documents with their embeddings. + :param documents: + A list of Documents to embed. + :returns: + The updated Documents with their embeddings. """ if not self._initialized: msg = "The embedding model has not been loaded. Please call warm_up() before running."
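Putting the new parameters together, an end-to-end use of the document embedder would look roughly like this. Untested sketch: the model id mirrors the docstring example, and the working directory is a placeholder.

```python
from haystack.dataclasses import Document
from haystack_integrations.components.embedders.optimum import (
    OptimumDocumentEmbedder,
    OptimumEmbedderOptimizationConfig,
    OptimumEmbedderOptimizationMode,
)

embedder = OptimumDocumentEmbedder(
    model="sentence-transformers/all-mpnet-base-v2",
    working_dir="onnx_artifacts",  # required whenever optimization/quantization is enabled
    optimizer_settings=OptimumEmbedderOptimizationConfig(
        mode=OptimumEmbedderOptimizationMode.O1, for_gpu=False
    ),
)
embedder.warm_up()  # exports to ONNX, then optimizes into working_dir

result = embedder.run(documents=[Document(content="I love pizza!")])
print(result["documents"][0].embedding)
```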
diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_text_embedder.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_text_embedder.py index 64454bf9f..394ea04ad 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_text_embedder.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_text_embedder.py @@ -4,14 +4,17 @@ from haystack.utils import Secret from ._backend import _EmbedderBackend, _EmbedderParams +from .optimization import OptimumEmbedderOptimizationConfig from .pooling import OptimumEmbedderPooling +from .quantization import OptimumEmbedderQuantizationConfig @component class OptimumTextEmbedder: """ - A component to embed text using models loaded with the HuggingFace Optimum library. - This component is designed to seamlessly inference models using the high speed ONNX runtime. + A component to embed text using models loaded with the + [HuggingFace Optimum](https://huggingface.co/docs/optimum/index) library, + leveraging the ONNX runtime for high-speed inference. Usage example: ```python @@ -26,18 +29,6 @@ class OptimumTextEmbedder: # {'embedding': [-0.07804739475250244, 0.1498992145061493,, ...]} ``` - - Key Features and Compatibility: - - **Primary Compatibility**: Designed to work seamlessly with any embedding model present on the Hugging Face - Hub. - - **Conversion to ONNX**: The models are converted to ONNX using the HuggingFace Optimum library. This is - performed in real-time, during the warm-up step. - - **Accelerated Inference on GPU**: Supports using different execution providers such as CUDA and TensorRT, to - accelerate ONNX Runtime inference on GPUs. - Simply pass the execution provider as the onnx_execution_provider parameter. Additonal parameters can be passed - to the model using the model_kwargs parameter. - For more details refer to the HuggingFace documentation: - https://huggingface.co/docs/optimum/onnxruntime/usage_guides/gpu. """ def __init__( @@ -50,41 +41,62 @@ def __init__( onnx_execution_provider: str = "CPUExecutionProvider", pooling_mode: Optional[Union[str, OptimumEmbedderPooling]] = None, model_kwargs: Optional[Dict[str, Any]] = None, + working_dir: Optional[str] = None, + optimizer_settings: Optional[OptimumEmbedderOptimizationConfig] = None, + quantizer_settings: Optional[OptimumEmbedderQuantizationConfig] = None, ): """ - Create a OptimumTextEmbedder component. - - :param model: A string representing the model id on HF Hub. - :param token: The HuggingFace token to use as HTTP bearer authorization. - :param prefix: A string to add to the beginning of each text. - :param suffix: A string to add to the end of each text. - :param normalize_embeddings: Whether to normalize the embeddings to unit length. - :param onnx_execution_provider: The execution provider to use for ONNX models. See - https://onnxruntime.ai/docs/execution-providers/ for possible providers. - - Note: Using the TensorRT execution provider - TensorRT requires to build its inference engine ahead of inference, which takes some time due to the model - optimization and nodes fusion. To avoid rebuilding the engine every time the model is loaded, ONNX Runtime - provides a pair of options to save the engine: `trt_engine_cache_enable` and `trt_engine_cache_path`. We - recommend setting these two provider options using the model_kwargs parameter, when using the TensorRT - execution provider. 
The usage is as follows: - ```python - embedder = OptimumTextEmbedder( - model="sentence-transformers/all-mpnet-base-v2", - onnx_execution_provider="TensorrtExecutionProvider", - model_kwargs={ - "provider_options": { - "trt_engine_cache_enable": True, - "trt_engine_cache_path": "tmp/trt_cache", - } - }, - ) - ``` - :param pooling_mode: The pooling mode to use. When None, pooling mode will be inferred from the model config. - Refer to the OptimumEmbedderPooling enum for supported pooling modes. - :param model_kwargs: Dictionary containing additional keyword arguments to pass to the model. - In case of duplication, these kwargs override `model`, `onnx_execution_provider`, and `token` initialization - parameters. + Create a OptimumTextEmbedder component. + + :param model: + A string representing the model id on HF Hub. + :param token: + The HuggingFace token to use as HTTP bearer authorization. + :param prefix: + A string to add to the beginning of each text. + :param suffix: + A string to add to the end of each text. + :param normalize_embeddings: + Whether to normalize the embeddings to unit length. + :param onnx_execution_provider: + The [execution provider](https://onnxruntime.ai/docs/execution-providers/) + to use for ONNX models. + + Note: Using the TensorRT execution provider + TensorRT requires to build its inference engine ahead of inference, which takes some time due to the model + optimization and nodes fusion. To avoid rebuilding the engine every time the model is loaded, ONNX Runtime + provides a pair of options to save the engine: `trt_engine_cache_enable` and `trt_engine_cache_path`. We + recommend setting these two provider options using the model_kwargs parameter, when using the TensorRT + execution provider. The usage is as follows: + ```python + embedder = OptimumTextEmbedder( + model="sentence-transformers/all-mpnet-base-v2", + onnx_execution_provider="TensorrtExecutionProvider", + model_kwargs={ + "provider_options": { + "trt_engine_cache_enable": True, + "trt_engine_cache_path": "tmp/trt_cache", + } + }, + ) + ``` + :param pooling_mode: + The pooling mode to use. When `None`, pooling mode will be inferred from the model config. + :param model_kwargs: + Dictionary containing additional keyword arguments to pass to the model. + In case of duplication, these kwargs override `model`, `onnx_execution_provider` + and `token` initialization parameters. + :param working_dir: + The directory to use for storing intermediate files + generated during model optimization/quantization. + + Required for optimization and quantization. + :param optimizer_settings: + Configuration for Optimum Embedder Optimization. + If `None`, no additional optimization is applied. + :param quantizer_settings: + Configuration for Optimum Embedder Quantization. + If `None`, no quantization is applied. """ params = _EmbedderParams( model=model, token=token, prefix=prefix, suffix=suffix, normalize_embeddings=normalize_embeddings, onnx_execution_provider=onnx_execution_provider, @@ -97,13 +109,16 @@ def __init__( progress_bar=False, pooling_mode=pooling_mode, model_kwargs=model_kwargs, + working_dir=working_dir, + optimizer_settings=optimizer_settings, + quantizer_settings=quantizer_settings, ) self._backend = _EmbedderBackend(params) self._initialized = False def warm_up(self): """ - Load the embedding backend. + Initializes the component. """ if self._initialized: return @@ -113,7 +128,10 @@ def warm_up(self): def to_dict(self) -> Dict[str, Any]: """ - Serialize this component to a dictionary. + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data.
""" init_params = self._backend.parameters.serialize() # Remove init params that are not provided to the text embedder. @@ -124,7 +142,12 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "OptimumTextEmbedder": """ - Deserialize this component from a dictionary. + Deserializes the component from a dictionary. + + :param data: + The dictionary to deserialize from. + :returns: + The deserialized component. """ _EmbedderParams.deserialize_inplace(data["init_parameters"]) return default_from_dict(cls, data) @@ -134,8 +157,10 @@ def run(self, text: str): """ Embed a string. - :param text: The text to embed. - :return: The embeddings of the text. + :param text: + The text to embed. + :returns: + The embeddings of the text. """ if not self._initialized: msg = "The embedding model has not been loaded. Please call warm_up() before running." diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/pooling.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/pooling.py index c4d195b8e..41aa24d64 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/pooling.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/pooling.py @@ -22,11 +22,10 @@ class OptimumEmbedderPooling(Enum): MEAN_SQRT_LEN = "mean_sqrt_len" #: Perform weighted (position) mean pooling on the output of the - #: embedding model. See https://arxiv.org/abs/2202.08904. + #: embedding model. WEIGHTED_MEAN = "weighted_mean" #: Perform Last Token Pooling on the output of the embedding model. - #: See https://arxiv.org/abs/2202.08904 & https://arxiv.org/abs/2201.10005. LAST_TOKEN = "last_token" def __str__(self): @@ -38,9 +37,9 @@ def from_str(cls, string: str) -> "OptimumEmbedderPooling": Create a pooling mode from a string. :param string: - The string to convert. + String to convert. :returns: - The pooling mode. + Pooling mode. """ enum_map = {e.value: e for e in OptimumEmbedderPooling} pooling_mode = enum_map.get(string) diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/quantization.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/quantization.py new file mode 100644 index 000000000..2e68081b5 --- /dev/null +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/quantization.py @@ -0,0 +1,105 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict + +from optimum.onnxruntime.configuration import AutoQuantizationConfig, QuantizationConfig + + +class OptimumEmbedderQuantizationMode(Enum): + """ + [Dynamic Quantization Modes](https://huggingface.co/docs/optimum/onnxruntime/usage_guides/quantization) + support by the Optimum Embedders. + """ + + #: Quantization for the ARM64 architecture. + ARM64 = "arm64" + + #: Quantization with AVX-2 instructions. + AVX2 = "avx2" + + #: Quantization with AVX-512 instructions. + AVX512 = "avx512" + + #: Quantization with AVX-512 and VNNI instructions. + AVX512_VNNI = "avx512_vnni" + + def __str__(self): + return self.value + + @classmethod + def from_str(cls, string: str) -> "OptimumEmbedderQuantizationMode": + """ + Create an quantization mode from a string. + + :param string: + String to convert. + :returns: + Quantization mode. + """ + enum_map = {e.value: e for e in OptimumEmbedderQuantizationMode} + q_mode = enum_map.get(string) + if q_mode is None: + msg = f"Unknown quantization mode '{string}'. 
diff --git a/integrations/optimum/tests/test_optimum_document_embedder.py b/integrations/optimum/tests/test_optimum_document_embedder.py index bcbccd533..9288bb688 100644 --- a/integrations/optimum/tests/test_optimum_document_embedder.py +++ b/integrations/optimum/tests/test_optimum_document_embedder.py @@ -1,12 +1,21 @@ from unittest.mock import MagicMock, patch +import tempfile +import copy import pytest from haystack.dataclasses import Document from haystack.utils.auth import Secret from haystack_integrations.components.embedders.optimum import OptimumDocumentEmbedder from haystack_integrations.components.embedders.optimum.pooling import OptimumEmbedderPooling +from haystack_integrations.components.embedders.optimum.optimization import ( + OptimumEmbedderOptimizationConfig, + OptimumEmbedderOptimizationMode, +) +from haystack_integrations.components.embedders.optimum.quantization import ( + OptimumEmbedderQuantizationConfig, + OptimumEmbedderQuantizationMode, +) from huggingface_hub.utils import RepositoryNotFoundError -import copy @pytest.fixture @@ -63,6 +72,9 @@ def test_init_with_parameters(self, mock_check_valid_model): # noqa: ARG002 pooling_mode="max", onnx_execution_provider="CUDAExecutionProvider", model_kwargs={"trust_remote_code": True}, + working_dir="working_dir", + optimizer_settings=None, + quantizer_settings=None, ) assert embedder._backend.parameters.model == "sentence-transformers/all-minilm-l6-v2" @@ -82,6 +94,9 @@ def test_init_with_parameters(self, mock_check_valid_model): # noqa: ARG002 "provider": "CUDAExecutionProvider", "use_auth_token": "fake-api-token", } + assert embedder._backend.parameters.working_dir == "working_dir" + assert 
embedder._backend.parameters.optimizer_settings is None + assert embedder._backend.parameters.quantizer_settings is None def test_to_and_from_dict(self, mock_check_valid_model, mock_get_pooling_mode): # noqa: ARG002 component = OptimumDocumentEmbedder() @@ -105,6 +120,9 @@ def test_to_and_from_dict(self, mock_check_valid_model, mock_get_pooling_mode): "model_id": "sentence-transformers/all-mpnet-base-v2", "provider": "CPUExecutionProvider", }, + "working_dir": None, + "optimizer_settings": None, + "quantizer_settings": None, }, } @@ -125,6 +143,9 @@ def test_to_and_from_dict(self, mock_check_valid_model, mock_get_pooling_mode): "provider": "CPUExecutionProvider", "use_auth_token": None, } + assert embedder._backend.parameters.working_dir is None + assert embedder._backend.parameters.optimizer_settings is None + assert embedder._backend.parameters.quantizer_settings is None def test_to_and_from_dict_with_custom_init_parameters( self, mock_check_valid_model, mock_get_pooling_mode ): @@ -142,6 +163,11 @@ def test_to_and_from_dict_with_custom_init_parameters( onnx_execution_provider="CUDAExecutionProvider", pooling_mode="max", model_kwargs={"trust_remote_code": True}, + working_dir="working_dir", + optimizer_settings=OptimumEmbedderOptimizationConfig(OptimumEmbedderOptimizationMode.O1, for_gpu=True), + quantizer_settings=OptimumEmbedderQuantizationConfig( + OptimumEmbedderQuantizationMode.ARM64, per_channel=True + ), ) data = component.to_dict() @@ -164,6 +190,9 @@ def test_to_and_from_dict_with_custom_init_parameters( "model_id": "sentence-transformers/all-minilm-l6-v2", "provider": "CUDAExecutionProvider", }, + "working_dir": "working_dir", + "optimizer_settings": {"mode": "o1", "for_gpu": True}, + "quantizer_settings": {"mode": "arm64", "per_channel": True}, }, } @@ -185,6 +214,13 @@ def test_to_and_from_dict_with_custom_init_parameters( "provider": "CUDAExecutionProvider", "use_auth_token": None, } + assert embedder._backend.parameters.working_dir == "working_dir" + assert embedder._backend.parameters.optimizer_settings == OptimumEmbedderOptimizationConfig( + OptimumEmbedderOptimizationMode.O1, for_gpu=True + ) + assert embedder._backend.parameters.quantizer_settings == OptimumEmbedderQuantizationConfig( + OptimumEmbedderQuantizationMode.ARM64, per_channel=True + ) def test_initialize_with_invalid_model(self, mock_check_valid_model): mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") @@ -287,7 +323,7 @@ def test_run_wrong_input_format(self, mock_check_valid_model): # noqa: ARG002 def test_run_on_empty_list(self, mock_check_valid_model): # noqa: ARG002 embedder = OptimumDocumentEmbedder( - model="sentence-transformers/all-mpnet-base-v2", + model="sentence-transformers/paraphrase-albert-small-v2", ) embedder.warm_up() empty_list_input = [] @@ -297,7 +333,24 @@ def test_run_on_empty_list(self, mock_check_valid_model): # noqa: ARG002 assert not result["documents"] # empty list @pytest.mark.integration - def test_run(self): + @pytest.mark.parametrize( + "opt_config, quant_config", + [ + (None, None), + ( + OptimumEmbedderOptimizationConfig(OptimumEmbedderOptimizationMode.O1, for_gpu=False), + None, + ), + (None, OptimumEmbedderQuantizationConfig(OptimumEmbedderQuantizationMode.AVX2)), + # onnxruntime 1.17.x breaks support for quantizing optimized models. 
+ # cf. https://discuss.huggingface.co/t/optimize-and-quantize-with-optimum/23675/12 + # ( + # OptimumEmbedderOptimizationConfig(OptimumEmbedderOptimizationMode.O2, for_gpu=False), + # OptimumEmbedderQuantizationConfig(OptimumEmbedderQuantizationMode.AVX2), + # ), + ], + ) + def test_run(self, opt_config, quant_config): docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), ] docs_copy = copy.deepcopy(docs) - embedder = OptimumDocumentEmbedder( - model="sentence-transformers/all-mpnet-base-v2", - prefix="prefix ", - suffix=" suffix", - meta_fields_to_embed=["topic"], - embedding_separator=" | ", - batch_size=1, - ) - embedder.warm_up() + with tempfile.TemporaryDirectory() as tmpdirname: + embedder = OptimumDocumentEmbedder( + model="sentence-transformers/paraphrase-albert-small-v2", + prefix="prefix ", + suffix=" suffix", + meta_fields_to_embed=["topic"], + embedding_separator=" | ", + batch_size=1, + working_dir=tmpdirname, + optimizer_settings=opt_config, + quantizer_settings=quant_config, + ) + embedder.warm_up() - result = embedder.run(documents=docs) - expected = [embedder.run([d]) for d in docs_copy] + result = embedder.run(documents=docs) + expected = [embedder.run([d]) for d in docs_copy] documents_with_embeddings = result["documents"] diff --git a/integrations/optimum/tests/test_optimum_text_embedder.py b/integrations/optimum/tests/test_optimum_text_embedder.py index ce5bc2ffb..ad0e7d800 100644 --- a/integrations/optimum/tests/test_optimum_text_embedder.py +++ b/integrations/optimum/tests/test_optimum_text_embedder.py @@ -4,6 +4,14 @@ from haystack.utils.auth import Secret from haystack_integrations.components.embedders.optimum import OptimumTextEmbedder from haystack_integrations.components.embedders.optimum.pooling import OptimumEmbedderPooling +from haystack_integrations.components.embedders.optimum.optimization import ( + OptimumEmbedderOptimizationConfig, + OptimumEmbedderOptimizationMode, +) +from haystack_integrations.components.embedders.optimum.quantization import ( + OptimumEmbedderQuantizationConfig, + OptimumEmbedderQuantizationMode, +) from huggingface_hub.utils import RepositoryNotFoundError @@ -53,6 +61,9 @@ def test_init_with_parameters(self, mock_check_valid_model): # noqa: ARG002 pooling_mode="max", onnx_execution_provider="CUDAExecutionProvider", model_kwargs={"trust_remote_code": True}, + working_dir="working_dir", + optimizer_settings=None, + quantizer_settings=None, ) assert embedder._backend.parameters.model == "sentence-transformers/all-minilm-l6-v2" @@ -68,6 +79,9 @@ def test_init_with_parameters(self, mock_check_valid_model): # noqa: ARG002 "provider": "CUDAExecutionProvider", "use_auth_token": "fake-api-token", } + assert embedder._backend.parameters.working_dir == "working_dir" + assert embedder._backend.parameters.optimizer_settings is None + assert embedder._backend.parameters.quantizer_settings is None def test_to_and_from_dict(self, mock_check_valid_model, mock_get_pooling_mode): # noqa: ARG002 component = OptimumTextEmbedder() @@ -83,10 +97,13 @@ def test_to_and_from_dict(self, mock_check_valid_model, mock_get_pooling_mode): "normalize_embeddings": True, "onnx_execution_provider": "CPUExecutionProvider", "pooling_mode": "mean", + "working_dir": None, "model_kwargs": { "model_id": "sentence-transformers/all-mpnet-base-v2", "provider": "CPUExecutionProvider", }, + "optimizer_settings": None, + 
"quantizer_settings": None, }, } @@ -103,6 +120,9 @@ def test_to_and_from_dict(self, mock_check_valid_model, mock_get_pooling_mode): "provider": "CPUExecutionProvider", "use_auth_token": None, } + assert embedder._backend.parameters.working_dir is None + assert embedder._backend.parameters.optimizer_settings is None + assert embedder._backend.parameters.quantizer_settings is None def test_to_and_from_dict_with_custom_init_parameters(self, mock_check_valid_model): # noqa: ARG002 component = OptimumTextEmbedder( @@ -114,6 +134,11 @@ def test_to_and_from_dict_with_custom_init_parameters(self, mock_check_valid_mod onnx_execution_provider="CUDAExecutionProvider", pooling_mode="max", model_kwargs={"trust_remote_code": True}, + working_dir="working_dir", + optimizer_settings=OptimumEmbedderOptimizationConfig(OptimumEmbedderOptimizationMode.O1, for_gpu=True), + quantizer_settings=OptimumEmbedderQuantizationConfig( + OptimumEmbedderQuantizationMode.ARM64, per_channel=True + ), ) data = component.to_dict() @@ -132,6 +157,9 @@ def test_to_and_from_dict_with_custom_init_parameters(self, mock_check_valid_mod "model_id": "sentence-transformers/all-minilm-l6-v2", "provider": "CUDAExecutionProvider", }, + "working_dir": "working_dir", + "optimizer_settings": {"mode": "o1", "for_gpu": True}, + "quantizer_settings": {"mode": "arm64", "per_channel": True}, }, } @@ -149,6 +177,13 @@ def test_to_and_from_dict_with_custom_init_parameters(self, mock_check_valid_mod "provider": "CUDAExecutionProvider", "use_auth_token": None, } + assert embedder._backend.parameters.working_dir == "working_dir" + assert embedder._backend.parameters.optimizer_settings == OptimumEmbedderOptimizationConfig( + OptimumEmbedderOptimizationMode.O1, for_gpu=True + ) + assert embedder._backend.parameters.quantizer_settings == OptimumEmbedderQuantizationConfig( + OptimumEmbedderQuantizationMode.ARM64, per_channel=True + ) def test_initialize_with_invalid_model(self, mock_check_valid_model): mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") @@ -194,7 +229,7 @@ def test_infer_pooling_mode_from_hf(self): def test_run_wrong_input_format(self, mock_check_valid_model): # noqa: ARG002 embedder = OptimumTextEmbedder( - model="sentence-transformers/all-mpnet-base-v2", + model="sentence-transformers/paraphrase-albert-small-v2", token=Secret.from_token("fake-api-token"), pooling_mode="mean", ) @@ -209,7 +244,7 @@ def test_run_wrong_input_format(self, mock_check_valid_model): # noqa: ARG002 def test_run(self): for pooling_mode in OptimumEmbedderPooling: embedder = OptimumTextEmbedder( - model="sentence-transformers/all-mpnet-base-v2", + model="sentence-transformers/paraphrase-albert-small-v2", prefix="prefix ", suffix=" suffix", pooling_mode=pooling_mode,