fastembed fix: added prefix and suffix (#390)
* created project

* added parallel param

* updated test

* version 0.0.1

* renamed folder

* removed print

* updated readme

* added fastembed.yml

* fix typos

* python version to 3.9 for lint

* updated file

* force install black

* return to original file

* try to fix workflow

* retry

* add missing info to pyproject

* add hatch-vcs to check version

* Update pyproject.toml

* fixed typos

* removed python 3.9

* Update fastembed.yml

* Update fastembed_document_embedder.py

* Update fastembed_text_embedder.py

* ignore errors for bool arguments

* fix

* try moving noqa

* move noqa

* formatted with black

* added numpy dependency

* removed numpy

* removed numpy

* make mypy happy

* Update fastembed_backend.py

* removed classvar

* fix

* Update pyproject.toml

* added import numpy lint

* skip docs generation for the time being

* Update README.md

* added config.yml

* generate docs

* Update fastembed.yml

* Update config.yml

* rm unnecessary from_dict

* final touch

* updated labeler.yml

* updated library readme

* fix typos

* fix docstrings/README

* added prefix and suffix

* fixed typos

---------

Co-authored-by: Stefano Fiorucci <[email protected]>
nickprock and anakin87 authored Feb 11, 2024
1 parent d857f52 commit 1a1f5a2
Showing 4 changed files with 71 additions and 11 deletions.
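
For reviewers, a minimal usage sketch of the new parameters. This is illustrative only: the import path assumes the fastembed subpackage re-exports both classes, the prefix/suffix strings are example instruction strings (not required by the model), and the output keys follow the usual Haystack embedder convention.

from haystack import Document
from haystack_integrations.components.embedders.fastembed import (
    FastembedDocumentEmbedder,
    FastembedTextEmbedder,
)

# Document embedder: prefix/suffix wrap each prepared document text.
doc_embedder = FastembedDocumentEmbedder(
    model="BAAI/bge-small-en-v1.5",
    prefix="passage: ",  # example prefix, not a model requirement
    suffix="",
)
doc_embedder.warm_up()
docs = [Document(content="fastembed is supported and maintained by Qdrant.")]
docs_with_embeddings = doc_embedder.run(documents=docs)["documents"]

# Text embedder: prefix/suffix wrap the query string.
text_embedder = FastembedTextEmbedder(
    model="BAAI/bge-small-en-v1.5",
    prefix="query: ",  # example prefix, not a model requirement
)
text_embedder.warm_up()
query_embedding = text_embedder.run(text="Who maintains fastembed?")["embedding"]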
fastembed_document_embedder.py
@@ -53,6 +53,8 @@ class FastembedDocumentEmbedder:
def __init__(
self,
model: str = "BAAI/bge-small-en-v1.5",
prefix: str = "",
suffix: str = "",
batch_size: int = 256,
progress_bar: bool = True,
meta_fields_to_embed: Optional[List[str]] = None,
@@ -63,13 +65,17 @@ def __init__(
:param model: Local path or name of the model in Hugging Face's model hub,
such as ``'BAAI/bge-small-en-v1.5'``.
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
:param batch_size: Number of strings to encode at once.
:param progress_bar: If true, displays progress bar during embedding.
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
:param embedding_separator: Separator used to concatenate the meta fields to the Document content.
"""

self.model_name = model
self.prefix = prefix
self.suffix = suffix
self.batch_size = batch_size
self.progress_bar = progress_bar
self.meta_fields_to_embed = meta_fields_to_embed or []
@@ -82,6 +88,8 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
model=self.model_name,
prefix=self.prefix,
suffix=self.suffix,
batch_size=self.batch_size,
progress_bar=self.progress_bar,
meta_fields_to_embed=self.meta_fields_to_embed,
@@ -95,6 +103,19 @@ def warm_up(self):
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name=self.model_name)

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
texts_to_embed = []
for doc in documents:
meta_values_to_embed = [
str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
]
text_to_embed = [
self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix,
]

texts_to_embed.append(text_to_embed[0])
return texts_to_embed

@component.output_types(documents=List[Document])
def run(self, documents: List[Document]):
"""
@@ -113,16 +134,7 @@ def run(self, documents: List[Document]):

# TODO: once non textual Documents are properly supported, we should also prepare them for embedding here

texts_to_embed = []
for doc in documents:
meta_values_to_embed = [
str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
]
text_to_embed = [
self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]),
]

texts_to_embed.append(text_to_embed[0])
texts_to_embed = self._prepare_texts_to_embed(documents=documents)
embeddings = self.embedding_backend.embed(
texts_to_embed,
batch_size=self.batch_size,
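
To make the extracted _prepare_texts_to_embed concrete, this plain-Python sketch mirrors the string it now builds per document; the prefix, suffix, separator, and meta values below are made up:

prefix = "passage: "
suffix = " [end]"
embedding_separator = "\n"
meta_values_to_embed = ["Qdrant blog"]  # values of meta_fields_to_embed present on the doc
content = "fastembed is a lightweight embedding library."

# Same composition as in _prepare_texts_to_embed:
text_to_embed = prefix + embedding_separator.join([*meta_values_to_embed, content]) + suffix
print(text_to_embed)
# passage: Qdrant blog
# fastembed is a lightweight embedding library. [end]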
fastembed_text_embedder.py
@@ -31,6 +31,8 @@ class FastembedTextEmbedder:
def __init__(
self,
model: str = "BAAI/bge-small-en-v1.5",
prefix: str = "",
suffix: str = "",
batch_size: int = 256,
progress_bar: bool = True,
):
@@ -40,11 +42,15 @@ def __init__(
:param model: Local path or name of the model in Fastembed's model hub,
such as ``'BAAI/bge-small-en-v1.5'``.
:param batch_size: Number of strings to encode at once.
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
"""

# TODO add parallel

self.model_name = model
self.prefix = prefix
self.suffix = suffix
self.batch_size = batch_size
self.progress_bar = progress_bar

@@ -55,6 +61,8 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
model=self.model_name,
prefix=self.prefix,
suffix=self.suffix,
batch_size=self.batch_size,
progress_bar=self.progress_bar,
)
@@ -79,7 +87,7 @@ def run(self, text: str):
msg = "The embedding model has not been loaded. Please call warm_up() before running."
raise RuntimeError(msg)

text_to_embed = [text]
text_to_embed = [self.prefix + text + self.suffix]
embedding = list(
self.embedding_backend.embed(
text_to_embed,
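
The text embedder change is the one-line wrap shown above; with illustrative values, the list handed to the backend now looks like this:

prefix = "query: "
suffix = ""
text = "Who maintains fastembed?"

text_to_embed = [prefix + text + suffix]
print(text_to_embed)  # ['query: Who maintains fastembed?']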
20 changes: 20 additions & 0 deletions integrations/fastembed/tests/test_fastembed_document_embedder.py
@@ -15,6 +15,8 @@ def test_init_default(self):
"""
embedder = FastembedDocumentEmbedder(model="BAAI/bge-small-en-v1.5")
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 256
assert embedder.progress_bar is True
assert embedder.meta_fields_to_embed == []
@@ -26,12 +28,16 @@ def test_init_with_parameters(self):
"""
embedder = FastembedDocumentEmbedder(
model="BAAI/bge-small-en-v1.5",
prefix="prefix",
suffix="suffix",
batch_size=64,
progress_bar=False,
meta_fields_to_embed=["test_field"],
embedding_separator=" | ",
)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.meta_fields_to_embed == ["test_field"]
@@ -47,6 +53,8 @@ def test_to_dict(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"prefix": "",
"suffix": "",
"batch_size": 256,
"progress_bar": True,
"embedding_separator": "\n",
@@ -60,6 +68,8 @@ def test_to_dict_with_custom_init_parameters(self):
"""
embedder = FastembedDocumentEmbedder(
model="BAAI/bge-small-en-v1.5",
prefix="prefix",
suffix="suffix",
batch_size=64,
progress_bar=False,
meta_fields_to_embed=["test_field"],
@@ -70,6 +80,8 @@ def test_to_dict_with_custom_init_parameters(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"meta_fields_to_embed": ["test_field"],
@@ -85,6 +97,8 @@ def test_from_dict(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"prefix": "",
"suffix": "",
"batch_size": 256,
"progress_bar": True,
"meta_fields_to_embed": [],
@@ -93,6 +107,8 @@ def test_from_dict(self):
}
embedder = default_from_dict(FastembedDocumentEmbedder, embedder_dict)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 256
assert embedder.progress_bar is True
assert embedder.meta_fields_to_embed == []
@@ -106,6 +122,8 @@ def test_from_dict_with_custom_init_parameters(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"meta_fields_to_embed": ["test_field"],
@@ -114,6 +132,8 @@ def test_from_dict_with_custom_init_parameters(self):
}
embedder = default_from_dict(FastembedDocumentEmbedder, embedder_dict)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.meta_fields_to_embed == ["test_field"]
20 changes: 20 additions & 0 deletions integrations/fastembed/tests/test_fastembed_text_embedder.py
@@ -15,6 +15,8 @@ def test_init_default(self):
"""
embedder = FastembedTextEmbedder(model="BAAI/bge-small-en-v1.5")
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 256
assert embedder.progress_bar is True

@@ -24,10 +26,14 @@ def test_init_with_parameters(self):
"""
embedder = FastembedTextEmbedder(
model="BAAI/bge-small-en-v1.5",
prefix="prefix",
suffix="suffix",
batch_size=64,
progress_bar=False,
)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
assert embedder.progress_bar is False

@@ -41,6 +47,8 @@ def test_to_dict(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"prefix": "",
"suffix": "",
"batch_size": 256,
"progress_bar": True,
},
@@ -52,6 +60,8 @@ def test_to_dict_with_custom_init_parameters(self):
"""
embedder = FastembedTextEmbedder(
model="BAAI/bge-small-en-v1.5",
prefix="prefix",
suffix="suffix",
batch_size=64,
progress_bar=False,
)
@@ -60,6 +70,8 @@ def test_to_dict_with_custom_init_parameters(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
},
@@ -73,12 +85,16 @@ def test_from_dict(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"prefix": "",
"suffix": "",
"batch_size": 256,
"progress_bar": True,
},
}
embedder = default_from_dict(FastembedTextEmbedder, embedder_dict)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 256
assert embedder.progress_bar is True

@@ -90,12 +106,16 @@ def test_from_dict_with_custom_init_parameters(self):
"type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder", # noqa
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
},
}
embedder = default_from_dict(FastembedTextEmbedder, embedder_dict)
assert embedder.model_name == "BAAI/bge-small-en-v1.5"
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
assert embedder.progress_bar is False

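
As these tests check, the new parameters round-trip through (de)serialization. A minimal sketch, assuming the top-level default_from_dict export from haystack that the tests also rely on:

from haystack import default_from_dict
from haystack_integrations.components.embedders.fastembed.fastembed_text_embedder import (
    FastembedTextEmbedder,
)

embedder = FastembedTextEmbedder(model="BAAI/bge-small-en-v1.5", prefix="query: ", suffix="")
data = embedder.to_dict()
assert data["init_parameters"]["prefix"] == "query: "
assert data["init_parameters"]["suffix"] == ""

restored = default_from_dict(FastembedTextEmbedder, data)
assert restored.prefix == "query: "
assert restored.suffix == ""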
