Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Ollama Text Embedder #194

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion integrations/ollama/src/ollama_haystack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

from ollama_haystack.chat.chat_generator import OllamaChatGenerator
from ollama_haystack.embeddings.text_embedder import OllamaTextEmbedder
from ollama_haystack.generator import OllamaGenerator

__all__ = ["OllamaGenerator", "OllamaChatGenerator"]
__all__ = ["OllamaGenerator", "OllamaChatGenerator", "OllamaTextEmbedder"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Any, Dict, List, Optional

import requests
from haystack import component


@component
class OllamaTextEmbedder:
    """
    Computes the embedding of a text string using an embedding model served by a
    running Ollama instance, via a POST to its `/api/embeddings` endpoint.
    """

    def __init__(
        self,
        model: str = "orca-mini",
        url: str = "http://localhost:11434/api/embeddings",
        generation_kwargs: Optional[Dict[str, Any]] = None,
        timeout: int = 120,
    ):
        """
        :param model: The name of the model to use. The model should be available in the running Ollama instance.
            Default is "orca-mini".
        :param url: The URL of the embeddings endpoint of a running Ollama instance.
            Default is "http://localhost:11434/api/embeddings".
        :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
            top_p, and others. See the available arguments in
            [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
        :param timeout: The number of seconds before throwing a timeout error from the Ollama API.
            Default is 120 seconds.
        """
        self.timeout = timeout
        self.generation_kwargs = generation_kwargs or {}
        self.url = url
        self.model = model

    def _create_json_payload(self, text: str, generation_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Build the JSON body for a POST request to an Ollama embeddings endpoint.

        :param text: Text that is to be converted to an embedding.
        :param generation_kwargs: Per-call options; these override any options set at init time.
        :return: A dictionary of arguments for a POST request to an Ollama service.
        """
        # Per-call kwargs take precedence over the instance-level defaults.
        return {"model": self.model, "prompt": text, "options": {**self.generation_kwargs, **(generation_kwargs or {})}}

    @component.output_types(embedding=List[float])
    def run(self, text: str, generation_kwargs: Optional[Dict[str, Any]] = None):
        """
        Compute the embedding of a text string with an Ollama embedding model.

        :param text: Text to be converted to an embedding.
        :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
            top_p, etc. See the
            [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
        :return: A dictionary with the key "embedding" and a list of floats as the value.
        :raises requests.HTTPError: If the Ollama API returns an error status code.
        """
        payload = self._create_json_payload(text, generation_kwargs)

        response = requests.post(url=self.url, json=payload, timeout=self.timeout)

        # Surface HTTP-level failures (e.g. unknown model, server down) as exceptions.
        response.raise_for_status()

        # The Ollama embeddings endpoint returns {"embedding": [...]}, which matches
        # the declared output type directly.
        return response.json()
32 changes: 32 additions & 0 deletions integrations/ollama/tests/test_text_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from ollama_haystack import OllamaTextEmbedder


class TestOllamaTextEmbedder:
    """Tests for OllamaTextEmbedder; test_run requires a live local Ollama server."""

    def test_init_defaults(self):
        # Constructing with no arguments must yield the documented defaults.
        text_embedder = OllamaTextEmbedder()

        assert text_embedder.model == "orca-mini"
        assert text_embedder.url == "http://localhost:11434/api/embeddings"
        assert text_embedder.generation_kwargs == {}
        assert text_embedder.timeout == 120

    def test_init(self):
        # Every constructor argument must be stored as-is on the instance.
        text_embedder = OllamaTextEmbedder(
            model="llama2",
            url="http://my-custom-endpoint:11434/api/embeddings",
            generation_kwargs={"temperature": 0.5},
            timeout=3000,
        )

        assert text_embedder.model == "llama2"
        assert text_embedder.url == "http://my-custom-endpoint:11434/api/embeddings"
        assert text_embedder.generation_kwargs == {"temperature": 0.5}
        assert text_embedder.timeout == 3000

    def test_run(self):
        # Integration check against a locally running Ollama instance.
        text_embedder = OllamaTextEmbedder()

        result = text_embedder.run("hello")

        assert isinstance(result, dict)
        assert all(isinstance(value, float) for value in result["embedding"])