Skip to content

Commit

Permalink
feat: TransformerSimilarityRanker add batching across Documents durin…
Browse files Browse the repository at this point in the history
…g inference (#8344)

* First pass at adding batch support to TransformersSimilarityRanker

* Add test

* Add reno
  • Loading branch information
sjrl authored Sep 11, 2024
1 parent 675cf43 commit 7227bcf
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 8 deletions.
29 changes: 26 additions & 3 deletions haystack/components/rankers/transformers_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
import accelerate # pylint: disable=unused-import # the library is used but not directly referenced
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer


Expand All @@ -42,7 +43,7 @@ class TransformersSimilarityRanker:
```
"""

def __init__(
def __init__( # noqa: PLR0913
self,
model: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2",
device: Optional[ComponentDevice] = None,
Expand All @@ -57,6 +58,7 @@ def __init__(
score_threshold: Optional[float] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
batch_size: int = 16,
):
"""
Creates an instance of TransformersSimilarityRanker.
Expand Down Expand Up @@ -93,6 +95,9 @@ def __init__(
:param tokenizer_kwargs:
Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
Refer to specific model documentation for available kwargs.
:param batch_size:
The batch size to use for inference. The higher the batch size, the more memory is required.
If you run into memory issues, reduce the batch size.
:raises ValueError:
If `top_k` is not > 0.
Expand All @@ -117,6 +122,7 @@ def __init__(
model_kwargs = resolve_hf_device_map(device=device, model_kwargs=model_kwargs)
self.model_kwargs = model_kwargs
self.tokenizer_kwargs = tokenizer_kwargs or {}
self.batch_size = batch_size

# Parameter validation
if self.scale_score and self.calibration_factor is None:
Expand Down Expand Up @@ -261,11 +267,28 @@ def run(
text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content or ""])
query_doc_pairs.append([self.query_prefix + query, self.document_prefix + text_to_embed])

features = self.tokenizer(query_doc_pairs, padding=True, truncation=True, return_tensors="pt").to( # type: ignore
class _Dataset(Dataset):
def __init__(self, batch_encoding):
self.batch_encoding = batch_encoding

def __len__(self):
return len(self.batch_encoding["input_ids"])

def __getitem__(self, item):
return {key: self.batch_encoding.data[key][item] for key in self.batch_encoding.data.keys()}

batch_enc = self.tokenizer(query_doc_pairs, padding=True, truncation=True, return_tensors="pt").to( # type: ignore
self.device.first_device.to_torch()
)
dataset = _Dataset(batch_enc)
inp_dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

similarity_scores = []
with torch.inference_mode():
similarity_scores = self.model(**features).logits.squeeze(dim=1) # type: ignore
for features in inp_dataloader:
model_preds = self.model(**features).logits.squeeze(dim=1) # type: ignore
similarity_scores.extend(model_preds)
similarity_scores = torch.stack(similarity_scores)

if scale_score:
similarity_scores = torch.sigmoid(similarity_scores * calibration_factor)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
enhancements:
- |
We added batching during inference time to the TransformerSimilarityRanker to help prevent OOMs when ranking large amounts of Documents.
60 changes: 55 additions & 5 deletions test/components/rankers/test_transformers_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
from transformers.modeling_outputs import SequenceClassifierOutput

from haystack import ComponentError, Document
from haystack import Document
from haystack.components.rankers.transformers_similarity import TransformersSimilarityRanker
from haystack.utils.auth import Secret
from haystack.utils.device import ComponentDevice, DeviceMap
Expand Down Expand Up @@ -202,7 +202,9 @@ def test_from_dict_no_default_parameters(self):

@patch("torch.sigmoid")
@patch("torch.sort")
def test_embed_meta(self, mocked_sort, mocked_sigmoid):
@patch("torch.stack")
def test_embed_meta(self, mocked_stack, mocked_sort, mocked_sigmoid):
mocked_stack.return_value = torch.tensor([0])
mocked_sort.return_value = (None, torch.tensor([0]))
mocked_sigmoid.return_value = torch.tensor([0])
embedder = TransformersSimilarityRanker(
Expand Down Expand Up @@ -232,7 +234,9 @@ def test_embed_meta(self, mocked_sort, mocked_sigmoid):

@patch("torch.sigmoid")
@patch("torch.sort")
def test_prefix(self, mocked_sort, mocked_sigmoid):
@patch("torch.stack")
def test_prefix(self, mocked_stack, mocked_sort, mocked_sigmoid):
mocked_stack.return_value = torch.tensor([0])
mocked_sort.return_value = (None, torch.tensor([0]))
mocked_sigmoid.return_value = torch.tensor([0])
embedder = TransformersSimilarityRanker(
Expand Down Expand Up @@ -261,7 +265,9 @@ def test_prefix(self, mocked_sort, mocked_sigmoid):
)

@patch("torch.sort")
def test_scale_score_false(self, mocked_sort):
@patch("torch.stack")
def test_scale_score_false(self, mocked_stack, mocked_sort):
mocked_stack.return_value = torch.FloatTensor([-10.6859, -8.9874])
mocked_sort.return_value = (None, torch.tensor([0, 1]))
embedder = TransformersSimilarityRanker(model="model", scale_score=False)
embedder.model = MagicMock()
Expand All @@ -277,7 +283,9 @@ def test_scale_score_false(self, mocked_sort):
assert out["documents"][1].score == pytest.approx(-8.9874, abs=1e-4)

@patch("torch.sort")
def test_score_threshold(self, mocked_sort):
@patch("torch.stack")
def test_score_threshold(self, mocked_stack, mocked_sort):
mocked_stack.return_value = torch.FloatTensor([0.955, 0.001])
mocked_sort.return_value = (None, torch.tensor([0, 1]))
embedder = TransformersSimilarityRanker(model="model", scale_score=False, score_threshold=0.1)
embedder.model = MagicMock()
Expand Down Expand Up @@ -359,6 +367,48 @@ def test_run(self, query, docs_before_texts, expected_first_text, scores):
assert docs_after[1].score == pytest.approx(sorted_scores[1], abs=1e-6)
assert docs_after[2].score == pytest.approx(sorted_scores[2], abs=1e-6)

@pytest.mark.integration
@pytest.mark.parametrize(
"query,docs_before_texts,expected_first_text,scores",
[
(
"City in Bosnia and Herzegovina",
["Berlin", "Belgrade", "Sarajevo"],
"Sarajevo",
[2.2864143829792738e-05, 0.00012495707778725773, 0.009869757108390331],
),
(
"Machine learning",
["Python", "Bakery in Paris", "Tesla Giga Berlin"],
"Python",
[1.9063229046878405e-05, 1.434577916370472e-05, 1.3049247172602918e-05],
),
(
"Cubist movement",
["Nirvana", "Pablo Picasso", "Coffee"],
"Pablo Picasso",
[1.3313065210240893e-05, 9.90335684036836e-05, 1.3518535524781328e-05],
),
],
)
def test_run_small_batch_size(self, query, docs_before_texts, expected_first_text, scores):
"""
Test if the component ranks documents correctly.
"""
ranker = TransformersSimilarityRanker(model="cross-encoder/ms-marco-MiniLM-L-6-v2", batch_size=2)
ranker.warm_up()
docs_before = [Document(content=text) for text in docs_before_texts]
output = ranker.run(query=query, documents=docs_before)
docs_after = output["documents"]

assert len(docs_after) == 3
assert docs_after[0].content == expected_first_text

sorted_scores = sorted(scores, reverse=True)
assert docs_after[0].score == pytest.approx(sorted_scores[0], abs=1e-6)
assert docs_after[1].score == pytest.approx(sorted_scores[1], abs=1e-6)
assert docs_after[2].score == pytest.approx(sorted_scores[2], abs=1e-6)

# Returns an empty list if no documents are provided
@pytest.mark.integration
def test_returns_empty_list_if_no_documents_are_provided(self):
Expand Down

0 comments on commit 7227bcf

Please sign in to comment.