Commit
Add embedding backend
awinml committed Feb 19, 2024
1 parent aa1d04d commit fe7fa36
Showing 7 changed files with 183 additions and 180 deletions.
4 changes: 2 additions & 2 deletions integrations/optimum/pyproject.toml
@@ -26,9 +26,9 @@ classifiers = [
]
dependencies = [
  "haystack-ai",
-  "transformers[sentencepiece]==4.36.2",
+  "transformers[sentencepiece]",
  "sentence-transformers>=2.2.0",
-  "optimum[onnxruntime]==1.15.0"
+  "optimum[onnxruntime]"
]

[project.urls]
Empty file.
@@ -0,0 +1,116 @@
from typing import Any, ClassVar, Dict, List, Optional, Union

import numpy as np
import torch
from haystack.utils.auth import Secret
from optimum.onnxruntime import ORTModelForFeatureExtraction
from tqdm import tqdm
from transformers import AutoTokenizer


class _OptimumEmbeddingBackendFactory:
    """
    Factory class to create instances of Optimum embedding backends.
    """

    _instances: ClassVar[Dict[str, "_OptimumEmbeddingBackend"]] = {}

    @staticmethod
    def get_embedding_backend(
        model: str, token: Optional[Secret] = None, model_kwargs: Optional[Dict[str, Any]] = None
    ):
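        # Backends are cached per model/token pair, so repeated calls reuse the same ONNX model.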
        embedding_backend_id = f"{model}{token}"

        if embedding_backend_id in _OptimumEmbeddingBackendFactory._instances:
            return _OptimumEmbeddingBackendFactory._instances[embedding_backend_id]
        embedding_backend = _OptimumEmbeddingBackend(model=model, token=token, model_kwargs=model_kwargs)
        _OptimumEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
        return embedding_backend


class _OptimumEmbeddingBackend:
    """
    Class to manage Optimum embeddings.
    """

    def __init__(self, model: str, token: Optional[Secret] = None, model_kwargs: Optional[Dict[str, Any]] = None):
        # export=True converts the model to ONNX on the fly
        self.model = ORTModelForFeatureExtraction.from_pretrained(**model_kwargs, export=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model, token=token)

    def mean_pooling(self, model_output: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Perform Mean Pooling on the output of the Embedding model.
        :param model_output: The output of the embedding model.
        :param attention_mask: The attention mask of the tokenized text.
        :return: The embeddings of the text after mean pooling.
        """
        # First element of model_output contains all token embeddings
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def embed(
        self,
        texts_to_embed: Union[str, List[str]],
        normalize_embeddings: bool,
        progress_bar: bool = False,
        batch_size: int = 1,
    ) -> Union[List[List[float]], List[float]]:
        """
        Embed text or list of texts using the Optimum model.
        :param texts_to_embed: The text or list of texts to embed.
        :param normalize_embeddings: Whether to normalize the embeddings to unit length.
        :param progress_bar: Whether to show a progress bar or not, defaults to False.
        :param batch_size: Batch size to use, defaults to 1.
        :return: A single embedding if the input is a single string. A list of embeddings if the input is a list of
            strings.
        """
        if isinstance(texts_to_embed, str):
            texts = [texts_to_embed]
        else:
            texts = texts_to_embed

        # Determine device for tokenizer output
        device = self.model.device

        # Sorting by length
        length_sorted_idx = np.argsort([-len(sen) for sen in texts])
        sentences_sorted = [texts[idx] for idx in length_sorted_idx]

        all_embeddings = []
        for i in tqdm(
            range(0, len(sentences_sorted), batch_size), disable=not progress_bar, desc="Calculating embeddings"
        ):
            batch = sentences_sorted[i : i + batch_size]
            encoded_input = self.tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(device)
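            # truncation=True caps each text at the model's maximum sequence length.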

            # Only pass required inputs otherwise onnxruntime can raise an error
            inputs_to_remove = set(encoded_input.keys()).difference(self.model.inputs_names)
            for key in inputs_to_remove:
                encoded_input.pop(key)

            # Compute token embeddings
            model_output = self.model(**encoded_input)

            # Perform mean pooling
            sentence_embeddings = self.mean_pooling(model_output, encoded_input["attention_mask"].to(device))

            all_embeddings.extend(sentence_embeddings.tolist())

        # Reorder embeddings according to original order
        all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]

        # Normalize all embeddings
        if normalize_embeddings:
            all_embeddings = torch.nn.functional.normalize(torch.tensor(all_embeddings), p=2, dim=1).tolist()

        if isinstance(texts_to_embed, str):
            # Return the embedding if only one text was passed
            all_embeddings = all_embeddings[0]

        return all_embeddings
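For orientation, a minimal sketch of how the new backend could be exercised on its own. The import path follows the embedder changes below; the model name and the "model_id" key inside model_kwargs are illustrative assumptions, since the factory forwards model_kwargs verbatim to ORTModelForFeatureExtraction.from_pretrained:

from haystack_integrations.components.embedders.backends.optimum_backend import (
    _OptimumEmbeddingBackendFactory,
)

# Example model; model_kwargs is passed straight through to from_pretrained.
backend = _OptimumEmbeddingBackendFactory.get_embedding_backend(
    model="sentence-transformers/all-mpnet-base-v2",
    model_kwargs={"model_id": "sentence-transformers/all-mpnet-base-v2"},
)

vector = backend.embed("I love pizza!", normalize_embeddings=True)  # single string -> one embedding
vectors = backend.embed(["first text", "second text"], normalize_embeddings=True)  # list -> list of embeddings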
@@ -1,13 +1,11 @@
from typing import Any, Dict, List, Optional

-import numpy as np
-import torch
from haystack import Document, component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import HFModelType, check_valid_model, deserialize_hf_model_kwargs, serialize_hf_model_kwargs
-from optimum.onnxruntime import ORTModelForFeatureExtraction
-from tqdm import tqdm
-from transformers import AutoTokenizer
+from haystack_integrations.components.embedders.backends.optimum_backend import (
+    _OptimumEmbeddingBackendFactory,
+)


class OptimumDocumentEmbedder:
@@ -24,7 +22,7 @@ class OptimumDocumentEmbedder:
    doc = Document(content="I love pizza!")
-    document_embedder = OptimumDocumentEmbedder(model="BAAI/bge-small-en-v1.5")
+    document_embedder = OptimumDocumentEmbedder(model="sentence-transformers/all-mpnet-base-v2")
    document_embedder.warm_up()
    result = document_embedder.run([doc])
@@ -48,7 +46,7 @@ class OptimumDocumentEmbedder:

    def __init__(
        self,
-        model: str = "BAAI/bge-small-en-v1.5",
+        model: str = "sentence-transformers/all-mpnet-base-v2",
        token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),  # noqa: B008
        prefix: str = "",
        suffix: str = "",
@@ -63,13 +61,14 @@ def __init__(
"""
Create a OptimumDocumentEmbedder component.
:param model: A string representing the model id on HF Hub. Default is "BAAI/bge-small-en-v1.5".
:param model: A string representing the model id on HF Hub. Defaults to
"sentence-transformers/all-mpnet-base-v2".
:param token: The HuggingFace token to use as HTTP bearer authorization.
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
:param normalize_embeddings: Whether to normalize the embeddings to unit length.
:param onnx_execution_provider: The execution provider to use for ONNX models. Defaults to
"CPUExecutionProvider".
"CPUExecutionProvider". See https://onnxruntime.ai/docs/execution-providers/ for possible providers.
:param model_kwargs: Dictionary containing additional keyword arguments to pass to the model.
In case of duplication, these kwargs override `model`, `onnx_execution_provider`, and `token` initialization
parameters.
@@ -107,26 +106,12 @@ def __init__(

    def warm_up(self):
        """
-        Convert the model to ONNX.
-        The model is cached if the "TensorrtExecutionProvider" is used, since it takes a while to build the TensorRT
-        engine.
+        Load the embedding backend.
        """
-        if self.embedding_model is None:
-            self.tokenizer = AutoTokenizer.from_pretrained(self.model)
-
-            if self.onnx_execution_provider == "TensorrtExecutionProvider":
-                # Cache engine for TensorRT
-                provider_options = {
-                    "trt_engine_cache_enable": True,
-                    "trt_engine_cache_path": f"tmp/trt_cache_{self.model}",
-                }
-                self.embedding_model = ORTModelForFeatureExtraction.from_pretrained(
-                    **self.model_kwargs, use_cache=False, provider_options=provider_options
-                )
-            else:
-                # export=True converts the model to ONNX on the fly
-                self.embedding_model = ORTModelForFeatureExtraction.from_pretrained(**self.model_kwargs, export=True)
+        if not hasattr(self, "embedding_backend"):
+            self.embedding_backend = _OptimumEmbeddingBackendFactory.get_embedding_backend(
+                model=self.model, token=self.token, model_kwargs=self.model_kwargs
+            )

    def to_dict(self) -> Dict[str, Any]:
        """
@@ -162,21 +147,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "OptimumDocumentEmbedder":
        deserialize_hf_model_kwargs(data["init_parameters"]["model_kwargs"])
        return default_from_dict(cls, data)

-    def mean_pooling(self, model_output: torch.tensor, attention_mask: torch.tensor) -> torch.tensor:
-        """
-        Perform Mean Pooling on the output of the Embedding model.
-        :param model_output: The output of the embedding model.
-        :param attention_mask: The attention mask of the tokenized text.
-        :return: The embeddings of the text after mean pooling.
-        """
-        # First element of model_output contains all token embeddings
-        token_embeddings = model_output[0]
-        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
-        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
-        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
-        return sum_embeddings / sum_mask
-
    def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
        """
        Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
@@ -194,47 +164,6 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
            texts_to_embed.append(text_to_embed)
        return texts_to_embed

-    def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> List[List[float]]:
-        """
-        Embed a list of texts in batches.
-        """
-        # Determine device for tokenizer output
-        device = (
-            "cuda"
-            if self.onnx_execution_provider
-            in ["CUDAExecutionProvider", "ROCMExecutionProvider", "TensorrtExecutionProvider"]
-            else "cpu"
-        )
-
-        # Sorting by length
-        length_sorted_idx = np.argsort([-len(sen) for sen in texts_to_embed])
-        sentences_sorted = [texts_to_embed[idx] for idx in length_sorted_idx]
-
-        all_embeddings = []
-        for i in tqdm(
-            range(0, len(sentences_sorted), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
-        ):
-            batch = sentences_sorted[i : i + batch_size]
-            encoded_input = self.tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(device)  # type: ignore
-
-            # Compute token embeddings
-            with torch.no_grad():
-                model_output = self.embedding_model(**encoded_input)  # type: ignore
-
-            # Perform mean pooling
-            sentence_embeddings = self.mean_pooling(model_output, encoded_input["attention_mask"].cpu())
-
-            all_embeddings.extend(sentence_embeddings.tolist())
-
-        # Reorder embeddings according to original order
-        all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
-
-        # Normalize all embeddings
-        if self.normalize_embeddings:
-            all_embeddings = torch.nn.functional.normalize(torch.tensor(all_embeddings), p=2, dim=1).tolist()
-
-        return all_embeddings

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]):
        """
@@ -251,7 +180,7 @@ def run(self, documents: List[Document]):
            )
            raise TypeError(msg)

-        if not (self.embedding_model and self.tokenizer):
+        if not hasattr(self, "embedding_backend"):
            msg = "The embedding model has not been loaded. Please call warm_up() before running."
            raise RuntimeError(msg)

@@ -261,7 +190,12 @@ def run(self, documents: List[Document]):

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

-        embeddings = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
+        embeddings = self.embedding_backend.embed(
+            texts_to_embed=texts_to_embed,
+            normalize_embeddings=self.normalize_embeddings,
+            progress_bar=self.progress_bar,
+            batch_size=self.batch_size,
+        )

        for doc, emb in zip(documents, embeddings):
            doc.embedding = emb
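With these changes, run() delegates batching, mean pooling, and normalization to the shared backend built in warm_up(). A minimal sketch of the resulting call flow, assuming the package's public import path and using the new default model as an example:

from haystack import Document
from haystack_integrations.components.embedders import OptimumDocumentEmbedder

embedder = OptimumDocumentEmbedder(model="sentence-transformers/all-mpnet-base-v2")
embedder.warm_up()  # builds (or reuses) the ONNX backend via _OptimumEmbeddingBackendFactory

docs = [Document(content="I love pizza!"), Document(content="ONNX inference is fast.")]
result = embedder.run(docs)
print(len(result["documents"][0].embedding))  # embedding dimension, e.g. 768 for all-mpnet-base-v2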