Skip to content

Commit

Permalink
Add text embedder
Browse files Browse the repository at this point in the history
  • Loading branch information
AlistairLR112 committed Jan 14, 2024
1 parent c974c8a commit 0c0d9d5
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Any, Dict, List, Optional

import requests
from haystack import component


@component
class OllamaTextEmbedder:
    """
    A Haystack component that computes text embeddings by calling the
    embeddings endpoint of a running Ollama instance.
    """

    def __init__(
        self,
        model: str = "orca-mini",
        url: str = "http://localhost:11434/api/embeddings",
        generation_kwargs: Optional[Dict[str, Any]] = None,
        timeout: int = 120,
    ):
        """
        :param model: The name of the model to use. The model should be available in the running Ollama instance.
            Default is "orca-mini".
        :param url: The URL of the embeddings endpoint of a running Ollama instance.
            Default is "http://localhost:11434/api/embeddings".
        :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
            top_p, and others. See the available arguments in
            [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
        :param timeout: The number of seconds before throwing a timeout error from the Ollama API.
            Default is 120 seconds.
        """
        self.timeout = timeout
        self.generation_kwargs = generation_kwargs or {}
        self.url = url
        self.model = model

    def _create_json_payload(self, text: str, generation_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Build the JSON body for a POST request to an Ollama embeddings endpoint.

        :param text: Text that is to be converted to an embedding.
        :param generation_kwargs: Per-call options; they override any options set in the constructor.
        :return: A dictionary of arguments for a POST request to an Ollama service.
        """
        # Call-time kwargs take precedence over constructor kwargs (later keys win in the merge).
        return {
            "model": self.model,
            "prompt": text,
            "options": {**self.generation_kwargs, **(generation_kwargs or {})},
        }

    @component.output_types(embedding=List[float])
    def run(self, text: str, generation_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Compute the embedding of a single text with an Ollama model.

        :param text: Text to be converted to an embedding.
        :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
            top_p, etc. See the
            [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
        :raises requests.HTTPError: If the Ollama API responds with an error status code.
        :return: A dictionary with the key "embedding" and a list of floats as the value.
        """
        payload = self._create_json_payload(text, generation_kwargs)

        response = requests.post(url=self.url, json=payload, timeout=self.timeout)

        # Surface HTTP-level failures (e.g. unknown model, server error) to the caller.
        response.raise_for_status()

        return response.json()
32 changes: 32 additions & 0 deletions integrations/ollama/tests/test_text_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from ollama_haystack import OllamaTextEmbedder


class TestOllamaTextEmbedder:
    """Unit and integration tests for OllamaTextEmbedder."""

    def test_init_defaults(self):
        """Constructing with no arguments yields the documented defaults."""
        default_embedder = OllamaTextEmbedder()

        assert default_embedder.model == "orca-mini"
        assert default_embedder.url == "http://localhost:11434/api/embeddings"
        assert default_embedder.generation_kwargs == {}
        assert default_embedder.timeout == 120

    def test_init(self):
        """All constructor arguments are stored verbatim on the instance."""
        custom_embedder = OllamaTextEmbedder(
            model="llama2",
            url="http://my-custom-endpoint:11434/api/embeddings",
            generation_kwargs={"temperature": 0.5},
            timeout=3000,
        )

        assert custom_embedder.model == "llama2"
        assert custom_embedder.url == "http://my-custom-endpoint:11434/api/embeddings"
        assert custom_embedder.generation_kwargs == {"temperature": 0.5}
        assert custom_embedder.timeout == 3000

    def test_run(self):
        """Integration test: requires a running Ollama instance with the default model."""
        embedder = OllamaTextEmbedder()

        result = embedder.run("hello")

        assert isinstance(result, dict)
        assert all(isinstance(value, float) for value in result["embedding"])

0 comments on commit 0c0d9d5

Please sign in to comment.