Skip to content

Commit

Permalink
FastembedTextEmbedder - remove batch_size (#688)
Browse files: browse the repository at this point in the history
  • Loading branch information
anakin87 authored May 6, 2024
1 parent da46c9c commit d30c0ea
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,18 @@ class FastembedDocumentEmbedder:
# Text taken from PubMed QA Dataset (https://huggingface.co/datasets/pubmed_qa)
document_list = [
Document(
content="Oxidative stress generated within inflammatory joints can produce autoimmune phenomena and joint destruction. Radical species with oxidative activity, including reactive nitrogen species, represent mediators of inflammation and cartilage damage.",
content=("Oxidative stress generated within inflammatory joints can produce autoimmune phenomena and joint "
"destruction. Radical species with oxidative activity, including reactive nitrogen species, "
"represent mediators of inflammation and cartilage damage."),
meta={
"pubid": "25,445,628",
"long_answer": "yes",
},
),
Document(
content="Plasma levels of pancreatic polypeptide (PP) rise upon food intake. Although other pancreatic islet hormones, such as insulin and glucagon, have been extensively investigated, PP secretion and actions are still poorly understood.",
content=("Plasma levels of pancreatic polypeptide (PP) rise upon food intake. Although other pancreatic "
"islet hormones, such as insulin and glucagon, have been extensively investigated, PP secretion "
"and actions are still poorly understood."),
meta={
"pubid": "25,445,712",
"long_answer": "yes",
Expand All @@ -49,7 +53,7 @@ class FastembedDocumentEmbedder:
print(f"Document Embedding: {result['documents'][0].embedding}")
print(f"Embedding Dimension: {len(result['documents'][0].embedding)}")
```
""" # noqa: E501
"""

def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,43 +12,44 @@ class FastembedSparseDocumentEmbedder:
Usage example:
```python
# To use this component, install the "fastembed-haystack" package.
# pip install fastembed-haystack
from haystack_integrations.components.embedders.fastembed import FastembedSparseDocumentEmbedder
from haystack.dataclasses import Document
doc_embedder = FastembedSparseDocumentEmbedder(
sparse_doc_embedder = FastembedSparseDocumentEmbedder(
model="prithvida/Splade_PP_en_v1",
batch_size=32,
)
doc_embedder.warm_up()
sparse_doc_embedder.warm_up()
# Text taken from PubMed QA Dataset (https://huggingface.co/datasets/pubmed_qa)
document_list = [
Document(
content="Oxidative stress generated within inflammatory joints can produce autoimmune phenomena and joint destruction. Radical species with oxidative activity, including reactive nitrogen species, represent mediators of inflammation and cartilage damage.",
content=("Oxidative stress generated within inflammatory joints can produce autoimmune phenomena and joint "
"destruction. Radical species with oxidative activity, including reactive nitrogen species, "
"represent mediators of inflammation and cartilage damage."),
meta={
"pubid": "25,445,628",
"long_answer": "yes",
},
),
Document(
content="Plasma levels of pancreatic polypeptide (PP) rise upon food intake. Although other pancreatic islet hormones, such as insulin and glucagon, have been extensively investigated, PP secretion and actions are still poorly understood.",
content=("Plasma levels of pancreatic polypeptide (PP) rise upon food intake. Although other pancreatic "
"islet hormones, such as insulin and glucagon, have been extensively investigated, PP secretion "
"and actions are still poorly understood."),
meta={
"pubid": "25,445,712",
"long_answer": "yes",
},
),
]
result = doc_embedder.run(document_list)
result = sparse_doc_embedder.run(document_list)
print(f"Document Text: {result['documents'][0].content}")
print(f"Document Embedding: {result['documents'][0].sparse_embedding}")
print(f"Embedding Dimension: {len(result['documents'][0].sparse_embedding)}")
print(f"Document Sparse Embedding: {result['documents'][0].sparse_embedding}")
print(f"Sparse Embedding Dimension: {len(result['documents'][0].sparse_embedding)}")
```
""" # noqa: E501
"""

def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,25 @@ class FastembedSparseTextEmbedder:
Usage example:
```python
# To use this component, install the "fastembed-haystack" package.
# pip install fastembed-haystack
from haystack_integrations.components.embedders.fastembed import FastembedSparseTextEmbedder
text = "It clearly says online this will work on a Mac OS system. The disk comes and it does not, only Windows. Do Not order this if you have a Mac!!"
text = ("It clearly says online this will work on a Mac OS system. "
"The disk comes and it does not, only Windows. Do Not order this if you have a Mac!!")
text_embedder = FastembedSparseTextEmbedder(
sparse_text_embedder = FastembedSparseTextEmbedder(
model="prithvida/Splade_PP_en_v1"
)
text_embedder.warm_up()
sparse_text_embedder.warm_up()
embedding = text_embedder.run(text)["embedding"]
sparse_embedding = sparse_text_embedder.run(text)["sparse_embedding"]
```
""" # noqa: E501
"""

def __init__(
self,
model: str = "prithvida/Splade_PP_en_v1",
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
batch_size: int = 32,
progress_bar: bool = True,
parallel: Optional[int] = None,
):
Expand All @@ -46,7 +43,6 @@ def __init__(
Can be set using the `FASTEMBED_CACHE_PATH` env variable.
Defaults to `fastembed_cache` in the system's temp directory.
:param threads: The number of threads a single onnxruntime session can use. Defaults to None.
:param batch_size: Number of strings to encode at once.
:param progress_bar: If true, displays progress bar during embedding.
:param parallel:
If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
Expand All @@ -57,7 +53,6 @@ def __init__(
self.model_name = model
self.cache_dir = cache_dir
self.threads = threads
self.batch_size = batch_size
self.progress_bar = progress_bar
self.parallel = parallel

Expand All @@ -73,7 +68,6 @@ def to_dict(self) -> Dict[str, Any]:
model=self.model_name,
cache_dir=self.cache_dir,
threads=self.threads,
batch_size=self.batch_size,
progress_bar=self.progress_bar,
parallel=self.parallel,
)
Expand Down Expand Up @@ -110,7 +104,6 @@ def run(self, text: str):

embedding = self.embedding_backend.embed(
[text],
batch_size=self.batch_size,
show_progress_bar=self.progress_bar,
parallel=self.parallel,
)[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@ class FastembedTextEmbedder:
Usage example:
```python
# To use this component, install the "fastembed-haystack" package.
# pip install fastembed-haystack
from haystack_integrations.components.embedders.fastembed import FastembedTextEmbedder
text = "It clearly says online this will work on a Mac OS system. The disk comes and it does not, only Windows. Do Not order this if you have a Mac!!"
text = ("It clearly says online this will work on a Mac OS system. "
"The disk comes and it does not, only Windows. Do Not order this if you have a Mac!!")
text_embedder = FastembedTextEmbedder(
model="BAAI/bge-small-en-v1.5"
Expand All @@ -26,7 +24,7 @@ class FastembedTextEmbedder:
embedding = text_embedder.run(text)["embedding"]
```
""" # noqa: E501
"""

def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def test_init_default(self):
assert embedder.model_name == "prithvida/Splade_PP_en_v1"
assert embedder.cache_dir is None
assert embedder.threads is None
assert embedder.batch_size == 32
assert embedder.progress_bar is True
assert embedder.parallel is None

Expand All @@ -30,14 +29,12 @@ def test_init_with_parameters(self):
model="prithvida/Splade_PP_en_v1",
cache_dir="fake_dir",
threads=2,
batch_size=64,
progress_bar=False,
parallel=1,
)
assert embedder.model_name == "prithvida/Splade_PP_en_v1"
assert embedder.cache_dir == "fake_dir"
assert embedder.threads == 2
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.parallel == 1

Expand All @@ -53,7 +50,6 @@ def test_to_dict(self):
"model": "prithvida/Splade_PP_en_v1",
"cache_dir": None,
"threads": None,
"batch_size": 32,
"progress_bar": True,
"parallel": None,
},
Expand All @@ -67,7 +63,6 @@ def test_to_dict_with_custom_init_parameters(self):
model="prithvida/Splade_PP_en_v1",
cache_dir="fake_dir",
threads=2,
batch_size=64,
progress_bar=False,
parallel=1,
)
Expand All @@ -78,7 +73,6 @@ def test_to_dict_with_custom_init_parameters(self):
"model": "prithvida/Splade_PP_en_v1",
"cache_dir": "fake_dir",
"threads": 2,
"batch_size": 64,
"progress_bar": False,
"parallel": 1,
},
Expand All @@ -94,7 +88,6 @@ def test_from_dict(self):
"model": "prithvida/Splade_PP_en_v1",
"cache_dir": None,
"threads": None,
"batch_size": 32,
"progress_bar": True,
"parallel": None,
},
Expand All @@ -103,7 +96,6 @@ def test_from_dict(self):
assert embedder.model_name == "prithvida/Splade_PP_en_v1"
assert embedder.cache_dir is None
assert embedder.threads is None
assert embedder.batch_size == 32
assert embedder.progress_bar is True
assert embedder.parallel is None

Expand All @@ -117,7 +109,6 @@ def test_from_dict_with_custom_init_parameters(self):
"model": "prithvida/Splade_PP_en_v1",
"cache_dir": "fake_dir",
"threads": 2,
"batch_size": 64,
"progress_bar": False,
"parallel": 1,
},
Expand All @@ -126,7 +117,6 @@ def test_from_dict_with_custom_init_parameters(self):
assert embedder.model_name == "prithvida/Splade_PP_en_v1"
assert embedder.cache_dir == "fake_dir"
assert embedder.threads == 2
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.parallel == 1

Expand Down

0 comments on commit d30c0ea

Please sign in to comment.