Add pooling sub module; Update pyproject.toml with dependency info
awinml committed Feb 20, 2024
1 parent fe7fa36 commit 9eb400f
Showing 7 changed files with 479 additions and 41 deletions.
5 changes: 5 additions & 0 deletions integrations/optimum/pyproject.toml
@@ -27,6 +27,11 @@ classifiers = [
dependencies = [
"haystack-ai",
"transformers[sentencepiece]",
# Optimum's main ONNX export function has hidden dependencies:
# it requires either "sentence-transformers", "diffusers" or "timm", depending
# on which model is loaded from the HF Hub.
# Ref: https://github.com/huggingface/optimum/blob/8651c0ca1cccf095458bc80329dec9df4601edb4/optimum/exporters/onnx/__main__.py#L164
# "sentence-transformers" is added here, since most embedding models use it
"sentence-transformers>=2.2.0",
"optimum[onnxruntime]"
]
@@ -3,6 +3,7 @@
import numpy as np
import torch
from haystack.utils.auth import Secret
from haystack_integrations.components.embedders.pooling import Pooling, PoolingMode
from optimum.onnxruntime import ORTModelForFeatureExtraction
from tqdm import tqdm
from transformers import AutoTokenizer
@@ -38,25 +39,11 @@ def __init__(self, model: str, token: Optional[Secret] = None, model_kwargs: Opt
self.model = ORTModelForFeatureExtraction.from_pretrained(**model_kwargs, export=True)
self.tokenizer = AutoTokenizer.from_pretrained(model, token=token)

def mean_pooling(self, model_output: torch.tensor, attention_mask: torch.tensor) -> torch.tensor:
"""
Perform Mean Pooling on the output of the Embedding model.
:param model_output: The output of the embedding model.
:param attention_mask: The attention mask of the tokenized text.
:return: The embeddings of the text after mean pooling.
"""
# First element of model_output contains all token embeddings
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
return sum_embeddings / sum_mask
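
For intuition, here is a tiny self-contained check of the mean-pooling arithmetic in the deleted helper above, using hypothetical toy tensors (not part of the commit):

```python
import torch

# Toy batch: one sequence, three tokens, 2-dim embeddings; the last token is padding.
token_embeddings = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])
attention_mask = torch.tensor([[1, 1, 0]])

mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
pooled = torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
print(pooled)  # tensor([[2., 3.]]) -- the padded token is excluded from the mean
```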

def embed(
self,
texts_to_embed: Union[str, List[str]],
normalize_embeddings: bool,
pooling_mode: PoolingMode = PoolingMode.MEAN,
progress_bar: bool = False,
batch_size: int = 1,
) -> Union[List[List[float]], List[float]]:
@@ -65,6 +52,7 @@ def embed(
:param texts_to_embed: The text or list of texts to embed.
:param normalize_embeddings: Whether to normalize the embeddings to unit length.
:param pooling_mode: The pooling mode to use. Defaults to PoolingMode.MEAN.
:param progress_bar: Whether to show a progress bar or not, defaults to False.
:param batch_size: Batch size to use, defaults to 1.
:return: A single embedding if the input is a single string. A list of embeddings if the input is a list of
strings.
@@ -97,8 +85,13 @@ def embed(
# Compute token embeddings
model_output = self.model(**encoded_input)

# Perform mean pooling
sentence_embeddings = self.mean_pooling(model_output, encoded_input["attention_mask"].to(device))
# Pool Embeddings
pooling = Pooling(
pooling_mode=pooling_mode,
attention_mask=encoded_input["attention_mask"].to(device),
model_output=model_output,
)
sentence_embeddings = pooling.pool_embeddings()

all_embeddings.extend(sentence_embeddings.tolist())
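
The new `pooling` sub-module itself is not loaded in this view of the diff. A minimal sketch of the interface used above, restricted to the mean mode (the constructor arguments match the call site; the method body is an assumption based on the deleted `mean_pooling` helper):

```python
import torch

class Pooling:
    def __init__(self, pooling_mode, attention_mask, model_output):
        self.pooling_mode = pooling_mode
        self.attention_mask = attention_mask
        self.model_output = model_output

    def pool_embeddings(self) -> torch.Tensor:
        # First element of model_output contains all token embeddings.
        token_embeddings = self.model_output[0]
        mask = self.attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # Mean pooling shown here; the real module presumably also dispatches on
        # self.pooling_mode for cls, max, mean_sqrt_len, weighted_mean, last_token.
        return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
```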

@@ -1,13 +1,15 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import HFModelType, check_valid_model, deserialize_hf_model_kwargs, serialize_hf_model_kwargs
from haystack_integrations.components.embedders.backends.optimum_backend import (
_OptimumEmbeddingBackendFactory,
)
from haystack_integrations.components.embedders.pooling import HFPoolingMode, PoolingMode


@component
class OptimumDocumentEmbedder:
"""
A component for computing Document embeddings using models loaded with the HuggingFace Optimum library.
@@ -52,6 +54,7 @@ def __init__(
suffix: str = "",
normalize_embeddings: bool = True,
onnx_execution_provider: str = "CPUExecutionProvider",
pooling_mode: Optional[Union[str, PoolingMode]] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
batch_size: int = 32,
progress_bar: bool = True,
@@ -68,7 +71,40 @@ def __init__(
:param suffix: A string to add to the end of each text.
:param normalize_embeddings: Whether to normalize the embeddings to unit length.
:param onnx_execution_provider: The execution provider to use for ONNX models. Defaults to
"CPUExecutionProvider". See https://onnxruntime.ai/docs/execution-providers/ for possible providers.
"CPUExecutionProvider". See https://onnxruntime.ai/docs/execution-providers/ for possible providers.
Note: Using the TensorRT execution provider
TensorRT requires building its inference engine ahead of inference, which takes some time due to model
optimization and node fusion. To avoid rebuilding the engine every time the model is loaded, ONNX Runtime
provides a pair of options to save the engine: `trt_engine_cache_enable` and `trt_engine_cache_path`. We
recommend setting these two provider options via the `model_kwargs` parameter when using the TensorRT
execution provider, as follows:
```python
embedder = OptimumDocumentEmbedder(
model="sentence-transformers/all-mpnet-base-v2",
onnx_execution_provider="TensorrtExecutionProvider",
model_kwargs={
"provider_options": {
"trt_engine_cache_enable": True,
"trt_engine_cache_path": "tmp/trt_cache",
}
},
)
```
:param pooling_mode: The pooling mode to use. Defaults to None. When None, the pooling mode is inferred from
the model config. If it cannot be found there, "mean" pooling is used.
The supported pooling modes are:
- "cls": Perform CLS Pooling on the output of the embedding model. Uses the first token (CLS token) as the
text representation.
- "max": Perform Max Pooling on the output of the embedding model. Uses the max value in each dimension
over all the tokens.
- "mean": Perform Mean Pooling on the output of the embedding model.
- "mean_sqrt_len": Perform Mean Pooling on the output of the embedding model, but divide by
sqrt(input_length).
- "weighted_mean": Perform Weighted (position) Mean Pooling on the output of the embedding model. See
https://arxiv.org/abs/2202.08904.
- "last_token": Perform Last Token Pooling on the output of the embedding model. See
https://arxiv.org/abs/2202.08904 and https://arxiv.org/abs/2201.10005.
:param model_kwargs: Dictionary containing additional keyword arguments to pass to the model.
In case of duplication, these kwargs override `model`, `onnx_execution_provider`, and `token` initialization
parameters.
@@ -84,6 +120,15 @@ def __init__(
self.token = token
token = token.resolve_value() if token else None

if isinstance(pooling_mode, str):
    self.pooling_mode = PoolingMode.from_str(pooling_mode)
else:
    self.pooling_mode = pooling_mode
# Infer pooling mode from model config if not provided
if self.pooling_mode is None:
    self.pooling_mode = HFPoolingMode.get_pooling_mode(model, token)
# Fall back to "mean" if not found in the model config and not specified by the user
if self.pooling_mode is None:
    self.pooling_mode = PoolingMode.MEAN

self.prefix = prefix
self.suffix = suffix
self.normalize_embeddings = normalize_embeddings
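
`HFPoolingMode.get_pooling_mode` also lives in the new `pooling` sub-module, which this view does not load. A plausible sketch, assuming it reads the `1_Pooling/config.json` that sentence-transformers models publish on the Hub (the flag names are the standard sentence-transformers keys; the function body is an assumption):

```python
import json
from typing import Optional

from haystack_integrations.components.embedders.pooling import PoolingMode
from huggingface_hub import hf_hub_download

def get_pooling_mode(model: str, token: Optional[str] = None) -> Optional[PoolingMode]:
    try:
        # sentence-transformers models record their pooling strategy in this file.
        path = hf_hub_download(repo_id=model, filename="1_Pooling/config.json", token=token)
    except Exception:
        return None  # the model does not publish a pooling config
    with open(path) as f:
        config = json.load(f)
    if config.get("pooling_mode_cls_token"):
        return PoolingMode.CLS
    if config.get("pooling_mode_mean_tokens"):
        return PoolingMode.MEAN
    if config.get("pooling_mode_max_tokens"):
        return PoolingMode.MAX
    return None
```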
@@ -124,6 +169,7 @@ def to_dict(self) -> Dict[str, Any]:
suffix=self.suffix,
normalize_embeddings=self.normalize_embeddings,
onnx_execution_provider=self.onnx_execution_provider,
pooling_mode=self.pooling_mode.value,
batch_size=self.batch_size,
progress_bar=self.progress_bar,
meta_fields_to_embed=self.meta_fields_to_embed,
@@ -143,6 +189,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "OptimumDocumentEmbedder":
"""
Deserialize this component from a dictionary.
"""
data["init_parameters"]["pooling_mode"] = PoolingMode.from_str(data["init_parameters"]["pooling_mode"])
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
deserialize_hf_model_kwargs(data["init_parameters"]["model_kwargs"])
return default_from_dict(cls, data)
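
For the serialization round trip in `to_dict`/`from_dict` above, `PoolingMode` only needs to be a string-valued enum with a `from_str` helper. A minimal sketch (assumed, since the sub-module is not loaded in this view):

```python
from enum import Enum

class PoolingMode(Enum):
    CLS = "cls"
    MAX = "max"
    MEAN = "mean"
    MEAN_SQRT_LEN = "mean_sqrt_len"
    WEIGHTED_MEAN = "weighted_mean"
    LAST_TOKEN = "last_token"

    @classmethod
    def from_str(cls, string: str) -> "PoolingMode":
        # Look up the member by its serialized value: from_str("mean") -> PoolingMode.MEAN
        return cls(string)
```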
@@ -193,6 +240,7 @@ def run(self, documents: List[Document]):
embeddings = self.embedding_backend.embed(
texts_to_embed=texts_to_embed,
normalize_embeddings=self.normalize_embeddings,
pooling_mode=self.pooling_mode,
progress_bar=self.progress_bar,
batch_size=self.batch_size,
)
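
A hedged end-to-end sketch of the new parameter on the document embedder (the import path and the `warm_up()` call are assumptions, mirroring other Haystack embedders; the output shape assumes the standard `documents` contract):

```python
from haystack import Document

from haystack_integrations.components.embedders import OptimumDocumentEmbedder  # assumed path

embedder = OptimumDocumentEmbedder(
    model="sentence-transformers/all-mpnet-base-v2",
    pooling_mode="cls",  # or a PoolingMode member; None infers the mode from the model config
)
embedder.warm_up()  # assumption: the ONNX backend is created on warm-up

result = embedder.run(documents=[Document(content="Pooling modes are now configurable.")])
print(len(result["documents"][0].embedding))
```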
@@ -1,13 +1,15 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import HFModelType, check_valid_model, deserialize_hf_model_kwargs, serialize_hf_model_kwargs
from haystack_integrations.components.embedders.backends.optimum_backend import (
_OptimumEmbeddingBackendFactory,
)
from haystack_integrations.components.embedders.pooling import HFPoolingMode, PoolingMode


@component
class OptimumTextEmbedder:
"""
A component to embed text using models loaded with the HuggingFace Optimum library.
@@ -48,6 +50,7 @@ def __init__(
suffix: str = "",
normalize_embeddings: bool = True,
onnx_execution_provider: str = "CPUExecutionProvider",
pooling_mode: Optional[Union[str, PoolingMode]] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
):
"""
@@ -60,7 +63,40 @@ def __init__(
:param suffix: A string to add to the end of each text.
:param normalize_embeddings: Whether to normalize the embeddings to unit length.
:param onnx_execution_provider: The execution provider to use for ONNX models. Defaults to
"CPUExecutionProvider". See https://onnxruntime.ai/docs/execution-providers/ for possible providers.
"CPUExecutionProvider". See https://onnxruntime.ai/docs/execution-providers/ for possible providers.
Note: Using the TensorRT execution provider
TensorRT requires building its inference engine ahead of inference, which takes some time due to model
optimization and node fusion. To avoid rebuilding the engine every time the model is loaded, ONNX Runtime
provides a pair of options to save the engine: `trt_engine_cache_enable` and `trt_engine_cache_path`. We
recommend setting these two provider options via the `model_kwargs` parameter when using the TensorRT
execution provider, as follows:
```python
embedder = OptimumTextEmbedder(
model="sentence-transformers/all-mpnet-base-v2",
onnx_execution_provider="TensorrtExecutionProvider",
model_kwargs={
"provider_options": {
"trt_engine_cache_enable": True,
"trt_engine_cache_path": "tmp/trt_cache",
}
},
)
```
:param pooling_mode: The pooling mode to use. Defaults to None. When None, the pooling mode is inferred from
the model config. If it cannot be found there, "mean" pooling is used.
The supported pooling modes are:
- "cls": Perform CLS Pooling on the output of the embedding model. Uses the first token (CLS token) as the
text representation.
- "max": Perform Max Pooling on the output of the embedding model. Uses the max value in each dimension
over all the tokens.
- "mean": Perform Mean Pooling on the output of the embedding model.
- "mean_sqrt_len": Perform Mean Pooling on the output of the embedding model, but divide by
sqrt(input_length).
- "weighted_mean": Perform Weighted (position) Mean Pooling on the output of the embedding model. See
https://arxiv.org/abs/2202.08904.
- "last_token": Perform Last Token Pooling on the output of the embedding model. See
https://arxiv.org/abs/2202.08904 and https://arxiv.org/abs/2201.10005.
:param model_kwargs: Dictionary containing additional keyword arguments to pass to the model.
In case of duplication, these kwargs override `model`, `onnx_execution_provider`, and `token` initialization
parameters.
@@ -71,6 +107,15 @@ def __init__(
self.token = token
token = token.resolve_value() if token else None

if isinstance(pooling_mode, str):
    self.pooling_mode = PoolingMode.from_str(pooling_mode)
else:
    self.pooling_mode = pooling_mode
# Infer pooling mode from model config if not provided
if self.pooling_mode is None:
    self.pooling_mode = HFPoolingMode.get_pooling_mode(model, token)
# Fall back to "mean" if not found in the model config and not specified by the user
if self.pooling_mode is None:
    self.pooling_mode = PoolingMode.MEAN

self.prefix = prefix
self.suffix = suffix
self.normalize_embeddings = normalize_embeddings
@@ -107,6 +152,7 @@ def to_dict(self) -> Dict[str, Any]:
suffix=self.suffix,
normalize_embeddings=self.normalize_embeddings,
onnx_execution_provider=self.onnx_execution_provider,
pooling_mode=self.pooling_mode.value,
model_kwargs=self.model_kwargs,
token=self.token.to_dict() if self.token else None,
)
@@ -122,13 +168,15 @@ def from_dict(cls, data: Dict[str, Any]) -> "OptimumTextEmbedder":
"""
Deserialize this component from a dictionary.
"""
data["init_parameters"]["pooling_mode"] = PoolingMode.from_str(data["init_parameters"]["pooling_mode"])
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
deserialize_hf_model_kwargs(data["init_parameters"]["model_kwargs"])
return default_from_dict(cls, data)

@component.output_types(embedding=List[float])
def run(self, text: str):
"""Embed a string.
"""
Embed a string.
:param text: The text to embed.
:return: The embeddings of the text.
@@ -147,7 +195,7 @@ def run(self, text: str):
text_to_embed = self.prefix + text + self.suffix

embedding = self.embedding_backend.embed(
texts_to_embed=text_to_embed, normalize_embeddings=self.normalize_embeddings
texts_to_embed=text_to_embed, normalize_embeddings=self.normalize_embeddings, pooling_mode=self.pooling_mode
)

return {"embedding": embedding}