diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py
index b1e9309c2..03dc301b9 100644
--- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py
+++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py
@@ -53,6 +53,8 @@ class FastembedDocumentEmbedder:
     def __init__(
         self,
         model: str = "BAAI/bge-small-en-v1.5",
+        prefix: str = "",
+        suffix: str = "",
         batch_size: int = 256,
         progress_bar: bool = True,
         meta_fields_to_embed: Optional[List[str]] = None,
@@ -63,6 +65,8 @@ def __init__(
 
         :param model: Local path or name of the model in Hugging Face's model hub,
             such as ``'BAAI/bge-small-en-v1.5'``.
+        :param prefix: A string to add to the beginning of each text.
+        :param suffix: A string to add to the end of each text.
         :param batch_size: Number of strings to encode at once.
         :param progress_bar: If true, displays progress bar during embedding.
         :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
@@ -70,6 +74,8 @@ def __init__(
         """
         self.model_name = model
+        self.prefix = prefix
+        self.suffix = suffix
         self.batch_size = batch_size
         self.progress_bar = progress_bar
         self.meta_fields_to_embed = meta_fields_to_embed or []
@@ -82,6 +88,8 @@ def to_dict(self) -> Dict[str, Any]:
         return default_to_dict(
             self,
             model=self.model_name,
+            prefix=self.prefix,
+            suffix=self.suffix,
             batch_size=self.batch_size,
             progress_bar=self.progress_bar,
             meta_fields_to_embed=self.meta_fields_to_embed,
@@ -95,6 +103,19 @@ def warm_up(self):
         if not hasattr(self, "embedding_backend"):
             self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(model_name=self.model_name)
 
+    def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
+        texts_to_embed = []
+        for doc in documents:
+            meta_values_to_embed = [
+                str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
+            ]
+            text_to_embed = [
+                self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix,
+            ]
+
+            texts_to_embed.append(text_to_embed[0])
+        return texts_to_embed
+
     @component.output_types(documents=List[Document])
     def run(self, documents: List[Document]):
         """
@@ -113,16 +134,7 @@ def run(self, documents: List[Document]):
 
         # TODO: once non textual Documents are properly supported, we should also prepare them for embedding here
 
-        texts_to_embed = []
-        for doc in documents:
-            meta_values_to_embed = [
-                str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
-            ]
-            text_to_embed = [
-                self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]),
-            ]
-
-            texts_to_embed.append(text_to_embed[0])
+        texts_to_embed = self._prepare_texts_to_embed(documents=documents)
 
         embeddings = self.embedding_backend.embed(
             texts_to_embed,
             batch_size=self.batch_size,
diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py
index 3446f80d7..455a1f94b 100644
--- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py
+++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py
@@ -31,6 +31,8 @@ class FastembedTextEmbedder:
     def __init__(
         self,
         model: str = "BAAI/bge-small-en-v1.5",
+        prefix: str = "",
+        suffix: str = "",
         batch_size: int = 256,
         progress_bar: bool = True,
     ):
@@ -40,11 +42,15 @@ def __init__(
 
         :param model: Local path or name of the model in Fastembed's model hub, such as ``'BAAI/bge-small-en-v1.5'``.
         :param batch_size: Number of strings to encode at once.
+        :param prefix: A string to add to the beginning of each text.
+        :param suffix: A string to add to the end of each text.
         """
 
         # TODO add parallel
 
         self.model_name = model
+        self.prefix = prefix
+        self.suffix = suffix
         self.batch_size = batch_size
         self.progress_bar = progress_bar
 
@@ -55,6 +61,8 @@ def to_dict(self) -> Dict[str, Any]:
         return default_to_dict(
             self,
             model=self.model_name,
+            prefix=self.prefix,
+            suffix=self.suffix,
             batch_size=self.batch_size,
             progress_bar=self.progress_bar,
         )
@@ -79,7 +87,7 @@ def run(self, text: str):
             msg = "The embedding model has not been loaded. Please call warm_up() before running."
             raise RuntimeError(msg)
 
-        text_to_embed = [text]
+        text_to_embed = [self.prefix + text + self.suffix]
         embedding = list(
             self.embedding_backend.embed(
                 text_to_embed,
diff --git a/integrations/fastembed/tests/test_fastembed_document_embedder.py b/integrations/fastembed/tests/test_fastembed_document_embedder.py
index be182183c..6dd1b6e52 100644
--- a/integrations/fastembed/tests/test_fastembed_document_embedder.py
+++ b/integrations/fastembed/tests/test_fastembed_document_embedder.py
@@ -15,6 +15,8 @@ def test_init_default(self):
         """
         embedder = FastembedDocumentEmbedder(model="BAAI/bge-small-en-v1.5")
         assert embedder.model_name == "BAAI/bge-small-en-v1.5"
+        assert embedder.prefix == ""
+        assert embedder.suffix == ""
         assert embedder.batch_size == 256
         assert embedder.progress_bar is True
         assert embedder.meta_fields_to_embed == []
@@ -26,12 +28,16 @@ def test_init_with_parameters(self):
         """
         embedder = FastembedDocumentEmbedder(
             model="BAAI/bge-small-en-v1.5",
+            prefix="prefix",
+            suffix="suffix",
             batch_size=64,
             progress_bar=False,
             meta_fields_to_embed=["test_field"],
             embedding_separator=" | ",
         )
         assert embedder.model_name == "BAAI/bge-small-en-v1.5"
+        assert embedder.prefix == "prefix"
+        assert embedder.suffix == "suffix"
         assert embedder.batch_size == 64
         assert embedder.progress_bar is False
         assert embedder.meta_fields_to_embed == ["test_field"]
@@ -47,6 +53,8 @@ def test_to_dict(self):
             "type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder",  # noqa
             "init_parameters": {
                 "model": "BAAI/bge-small-en-v1.5",
+                "prefix": "",
+                "suffix": "",
                 "batch_size": 256,
                 "progress_bar": True,
                 "embedding_separator": "\n",
@@ -60,6 +68,8 @@ def test_to_dict_with_custom_init_parameters(self):
         """
         embedder = FastembedDocumentEmbedder(
             model="BAAI/bge-small-en-v1.5",
+            prefix="prefix",
+            suffix="suffix",
             batch_size=64,
             progress_bar=False,
             meta_fields_to_embed=["test_field"],
@@ -70,6 +80,8 @@ def test_to_dict_with_custom_init_parameters(self):
             "type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder",  # noqa
             "init_parameters": {
                 "model": "BAAI/bge-small-en-v1.5",
+                "prefix": "prefix",
+                "suffix": "suffix",
                 "batch_size": 64,
                 "progress_bar": False,
                 "meta_fields_to_embed": ["test_field"],
@@ -85,6 +97,8 @@ def test_from_dict(self):
             "type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder",  # noqa
             "init_parameters": {
                 "model": "BAAI/bge-small-en-v1.5",
+                "prefix": "",
+                "suffix": "",
                 "batch_size": 256,
                 "progress_bar": True,
                 "meta_fields_to_embed": [],
@@ -93,6 +107,8 @@ def test_from_dict(self):
         }
         embedder = default_from_dict(FastembedDocumentEmbedder, embedder_dict)
         assert embedder.model_name == "BAAI/bge-small-en-v1.5"
+        assert embedder.prefix == ""
+        assert embedder.suffix == ""
         assert embedder.batch_size == 256
         assert embedder.progress_bar is True
         assert embedder.meta_fields_to_embed == []
@@ -106,6 +122,8 @@ def test_from_dict_with_custom_init_parameters(self):
             "type": "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder.FastembedDocumentEmbedder",  # noqa
             "init_parameters": {
                 "model": "BAAI/bge-small-en-v1.5",
+                "prefix": "prefix",
+                "suffix": "suffix",
                 "batch_size": 64,
                 "progress_bar": False,
                 "meta_fields_to_embed": ["test_field"],
@@ -114,6 +132,8 @@ def test_from_dict_with_custom_init_parameters(self):
         }
         embedder = default_from_dict(FastembedDocumentEmbedder, embedder_dict)
         assert embedder.model_name == "BAAI/bge-small-en-v1.5"
+        assert embedder.prefix == "prefix"
+        assert embedder.suffix == "suffix"
         assert embedder.batch_size == 64
         assert embedder.progress_bar is False
         assert embedder.meta_fields_to_embed == ["test_field"]
diff --git a/integrations/fastembed/tests/test_fastembed_text_embedder.py b/integrations/fastembed/tests/test_fastembed_text_embedder.py
index 6327532e1..465f17976 100644
--- a/integrations/fastembed/tests/test_fastembed_text_embedder.py
+++ b/integrations/fastembed/tests/test_fastembed_text_embedder.py
@@ -15,6 +15,8 @@ def test_init_default(self):
         """
         embedder = FastembedTextEmbedder(model="BAAI/bge-small-en-v1.5")
         assert embedder.model_name == "BAAI/bge-small-en-v1.5"
+        assert embedder.prefix == ""
+        assert embedder.suffix == ""
         assert embedder.batch_size == 256
         assert embedder.progress_bar is True
 
@@ -24,10 +26,14 @@ def test_init_with_parameters(self):
         """
         embedder = FastembedTextEmbedder(
             model="BAAI/bge-small-en-v1.5",
+            prefix="prefix",
+            suffix="suffix",
             batch_size=64,
             progress_bar=False,
         )
         assert embedder.model_name == "BAAI/bge-small-en-v1.5"
+        assert embedder.prefix == "prefix"
+        assert embedder.suffix == "suffix"
         assert embedder.batch_size == 64
         assert embedder.progress_bar is False
 
@@ -41,6 +47,8 @@ def test_to_dict(self):
             "type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder",  # noqa
             "init_parameters": {
                 "model": "BAAI/bge-small-en-v1.5",
+                "prefix": "",
+                "suffix": "",
                 "batch_size": 256,
                 "progress_bar": True,
             },
@@ -52,6 +60,8 @@ def test_to_dict_with_custom_init_parameters(self):
         """
         embedder = FastembedTextEmbedder(
             model="BAAI/bge-small-en-v1.5",
+            prefix="prefix",
+            suffix="suffix",
             batch_size=64,
             progress_bar=False,
         )
@@ -60,6 +70,8 @@ def test_to_dict_with_custom_init_parameters(self):
             "type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder",  # noqa
             "init_parameters": {
                 "model": "BAAI/bge-small-en-v1.5",
+                "prefix": "prefix",
+                "suffix": "suffix",
                 "batch_size": 64,
                 "progress_bar": False,
             },
@@ -73,12 +85,16 @@ def test_from_dict(self):
             "type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder",  # noqa
             "init_parameters": {
                 "model": "BAAI/bge-small-en-v1.5",
+                "prefix": "",
+                "suffix": "",
                 "batch_size": 256,
                 "progress_bar": True,
             },
         }
         embedder = default_from_dict(FastembedTextEmbedder, embedder_dict)
         assert embedder.model_name == "BAAI/bge-small-en-v1.5"
+        assert embedder.prefix == ""
+        assert embedder.suffix == ""
         assert embedder.batch_size == 256
         assert embedder.progress_bar is True
 
@@ -90,12 +106,16 @@ def test_from_dict_with_custom_init_parameters(self):
             "type": "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder.FastembedTextEmbedder",  # noqa
             "init_parameters": {
                 "model": "BAAI/bge-small-en-v1.5",
+                "prefix": "prefix",
+                "suffix": "suffix",
                 "batch_size": 64,
                 "progress_bar": False,
             },
         }
         embedder = default_from_dict(FastembedTextEmbedder, embedder_dict)
         assert embedder.model_name == "BAAI/bge-small-en-v1.5"
+        assert embedder.prefix == "prefix"
+        assert embedder.suffix == "suffix"
         assert embedder.batch_size == 64
         assert embedder.progress_bar is False
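Usage sketch (not part of the diff above): a minimal example of how the new prefix/suffix parameters could be used once these changes land. It assumes the package-level re-exports FastembedDocumentEmbedder and FastembedTextEmbedder, the usual "embedding" output key of the text embedder, and a locally available BAAI/bge-small-en-v1.5 model; the "passage: " / "query: " instruction strings are illustrative (E5-style), not required by this change.

from haystack import Document
from haystack_integrations.components.embedders.fastembed import (
    FastembedDocumentEmbedder,
    FastembedTextEmbedder,
)

# Index documents with an instruction-style prefix. The prefix/suffix are only
# applied to the text sent to the model; the Document content itself is unchanged.
document_embedder = FastembedDocumentEmbedder(
    model="BAAI/bge-small-en-v1.5",
    prefix="passage: ",  # illustrative instruction string
)
document_embedder.warm_up()
docs_with_embeddings = document_embedder.run(
    documents=[Document(content="FastEmbed is a lightweight embedding library.")]
)["documents"]

# Embed a query with the matching query-side instruction prefix.
text_embedder = FastembedTextEmbedder(
    model="BAAI/bge-small-en-v1.5",
    prefix="query: ",
)
text_embedder.warm_up()
query_embedding = text_embedder.run(text="What is FastEmbed?")["embedding"]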