diff --git a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py
index baebe7a02..8d2f5f505 100644
--- a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py
+++ b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py
@@ -30,13 +30,13 @@ def __init__(
         model: str = "nomic-embed-text",
         url: str = "http://localhost:11434",
         generation_kwargs: Optional[Dict[str, Any]] = None,
-        batch_size = 32,
         timeout: int = 120,
         prefix: str = "",
         suffix: str = "",
         progress_bar: bool = True,
         meta_fields_to_embed: Optional[List[str]] = None,
         embedding_separator: str = "\n",
+        batch_size: int = 32,
     ):
         """
         :param model:
@@ -49,6 +49,18 @@ def __init__(
             [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
         :param timeout:
             The number of seconds before throwing a timeout error from the Ollama API.
+        :param prefix:
+            A string to add at the beginning of each text.
+        :param suffix:
+            A string to add at the end of each text.
+        :param progress_bar:
+            If `True`, shows a progress bar when running.
+        :param meta_fields_to_embed:
+            List of metadata fields to embed along with the document text.
+        :param embedding_separator:
+            Separator used to concatenate the metadata fields to the document text.
+        :param batch_size:
+            Number of documents to process at once.
         """
         self.timeout = timeout
         self.generation_kwargs = generation_kwargs or {}
@@ -89,13 +101,10 @@ def _embed_batch(
         self, texts_to_embed: List[str], batch_size: int, generation_kwargs: Optional[Dict[str, Any]] = None
     ):
         """
-        Ollama Embedding only allows single uploads, not batching. Currently the batch size is set to 1.
-        If this changes in the future, line 86 (the first line within the for loop), can contain:
-        batch = texts_to_embed[i + i + batch_size]
+        Internal method to embed a batch of texts.
         """

         all_embeddings = []
-        meta: Dict[str, Any] = {"model": ""}

         for i in tqdm(
             range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
@@ -104,9 +113,7 @@ def _embed_batch(
             result = self._client.embed(model=self.model, input=batch, options=generation_kwargs)
             all_embeddings.extend(result["embeddings"])

-        meta["model"] = self.model
-
-        return all_embeddings, meta
+        return all_embeddings

     @component.output_types(documents=List[Document], meta=Dict[str, Any])
     def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None):
@@ -130,12 +137,14 @@ def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, A
             )
             raise TypeError(msg)

+        generation_kwargs = generation_kwargs or self.generation_kwargs
+
         texts_to_embed = self._prepare_texts_to_embed(documents=documents)
-        embeddings, meta = self._embed_batch(
+        embeddings = self._embed_batch(
             texts_to_embed=texts_to_embed, batch_size=self.batch_size, generation_kwargs=generation_kwargs
         )

         for doc, emb in zip(documents, embeddings):
             doc.embedding = emb

-        return {"documents": documents, "meta": meta}
+        return {"documents": documents, "meta": {"model": self.model}}
diff --git a/integrations/ollama/tests/test_document_embedder.py b/integrations/ollama/tests/test_document_embedder.py
index f4a651741..7d972e898 100644
--- a/integrations/ollama/tests/test_document_embedder.py
+++ b/integrations/ollama/tests/test_document_embedder.py
@@ -43,23 +43,14 @@ def import_text_in_embedder(self):

     @pytest.mark.integration
     def test_run(self):
-        embedder = OllamaDocumentEmbedder(model="nomic-embed-text", batch_size=1)
-        list_of_docs = [Document(content="This is a document containing some text.")]
-        reply = embedder.run(list_of_docs)
-
-        assert isinstance(reply, dict)
-        assert all(isinstance(element, float) for element in reply["documents"][0].embedding)
-        assert reply["meta"]["model"] == "nomic-embed-text"
-
-    @pytest.mark.integration
-    def test_bulk_run(self):
-        embedder = OllamaDocumentEmbedder(model="nomic-embed-text", batch_size=3)
+        embedder = OllamaDocumentEmbedder(model="nomic-embed-text", batch_size=2)
         list_of_docs = [
             Document(content="Llamas are amazing animals known for their soft wool and gentle demeanor."),
             Document(content="The Andes mountains are the natural habitat of many llamas."),
             Document(content="Llamas have been used as pack animals for centuries, especially in South America."),
         ]
-        reply = embedder.run(list_of_docs)
-        assert isinstance(reply, dict)
-        assert all(isinstance(element, float) for element in reply["documents"][0].embedding)
-        assert reply["meta"]["model"] == "nomic-embed-text"
+        result = embedder.run(list_of_docs)
+        assert result["meta"]["model"] == "nomic-embed-text"
+        documents = result["documents"]
+        assert len(documents) == 3
+        assert all(isinstance(element, float) for document in documents for element in document.embedding)