Skip to content

Commit

Permalink
refinements
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Nov 28, 2024
1 parent 47d1db3 commit dfb116f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ def __init__(
model: str = "nomic-embed-text",
url: str = "http://localhost:11434",
generation_kwargs: Optional[Dict[str, Any]] = None,
batch_size = 32,
timeout: int = 120,
prefix: str = "",
suffix: str = "",
progress_bar: bool = True,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
batch_size: int = 32,
):
"""
:param model:
Expand All @@ -49,6 +49,18 @@ def __init__(
[Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
:param timeout:
The number of seconds before throwing a timeout error from the Ollama API.
:param prefix:
A string to add at the beginning of each text.
:param suffix:
A string to add at the end of each text.
:param progress_bar:
If `True`, shows a progress bar when running.
:param meta_fields_to_embed:
List of metadata fields to embed along with the document text.
:param embedding_separator:
Separator used to concatenate the metadata fields to the document text.
:param batch_size:
Number of documents to process at once.
"""
self.timeout = timeout
self.generation_kwargs = generation_kwargs or {}
Expand Down Expand Up @@ -89,13 +101,10 @@ def _embed_batch(
self, texts_to_embed: List[str], batch_size: int, generation_kwargs: Optional[Dict[str, Any]] = None
):
"""
Ollama Embedding only allows single uploads, not batching. Currently the batch size is set to 1.
If this changes in the future, line 86 (the first line within the for loop), can contain:
batch = texts_to_embed[i + i + batch_size]
Internal method to embed a batch of texts.
"""

all_embeddings = []
meta: Dict[str, Any] = {"model": ""}

for i in tqdm(
range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
Expand All @@ -104,9 +113,7 @@ def _embed_batch(
result = self._client.embed(model=self.model, input=batch, options=generation_kwargs)
all_embeddings.extend(result["embeddings"])

meta["model"] = self.model

return all_embeddings, meta
return all_embeddings

@component.output_types(documents=List[Document], meta=Dict[str, Any])
def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None):
Expand All @@ -130,12 +137,14 @@ def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, A
)
raise TypeError(msg)

generation_kwargs = generation_kwargs or self.generation_kwargs

texts_to_embed = self._prepare_texts_to_embed(documents=documents)
embeddings, meta = self._embed_batch(
embeddings = self._embed_batch(
texts_to_embed=texts_to_embed, batch_size=self.batch_size, generation_kwargs=generation_kwargs
)

for doc, emb in zip(documents, embeddings):
doc.embedding = emb

return {"documents": documents, "meta": meta}
return {"documents": documents, "meta": {"model": self.model}}
21 changes: 6 additions & 15 deletions integrations/ollama/tests/test_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,23 +43,14 @@ def import_text_in_embedder(self):

@pytest.mark.integration
def test_run(self):
embedder = OllamaDocumentEmbedder(model="nomic-embed-text", batch_size=1)
list_of_docs = [Document(content="This is a document containing some text.")]
reply = embedder.run(list_of_docs)

assert isinstance(reply, dict)
assert all(isinstance(element, float) for element in reply["documents"][0].embedding)
assert reply["meta"]["model"] == "nomic-embed-text"

@pytest.mark.integration
def test_bulk_run(self):
embedder = OllamaDocumentEmbedder(model="nomic-embed-text", batch_size=3)
embedder = OllamaDocumentEmbedder(model="nomic-embed-text", batch_size=2)
list_of_docs = [
Document(content="Llamas are amazing animals known for their soft wool and gentle demeanor."),
Document(content="The Andes mountains are the natural habitat of many llamas."),
Document(content="Llamas have been used as pack animals for centuries, especially in South America."),
]
reply = embedder.run(list_of_docs)
assert isinstance(reply, dict)
assert all(isinstance(element, float) for element in reply["documents"][0].embedding)
assert reply["meta"]["model"] == "nomic-embed-text"
result = embedder.run(list_of_docs)
assert result["meta"]["model"] == "nomic-embed-text"
documents = result["documents"]
assert len(documents) == 3
assert all(isinstance(element, float) for document in documents for element in document.embedding)

0 comments on commit dfb116f

Please sign in to comment.