Add OllamaLLM and OllamaEmbeddings classes (#231)
* Add OllamaLLM and OllamaEmbeddings classes using the ollama python client

* Try removing import

* :(

* Add tests + reformat import in ollama embeddings for consistency with all other imports

* Fix after merge
stellasia authored Dec 12, 2024
1 parent ff6862e commit 140a057
Showing 16 changed files with 293 additions and 64 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -6,6 +6,7 @@
- Integrated json-repair package to handle and repair invalid JSON generated by LLMs.
- Introduced InvalidJSONError exception for handling cases where JSON repair fails.
- Ability to create a Pipeline or SimpleKGPipeline from a config file. See [the example](examples/build_graph/from_config_files/simple_kg_pipeline_from_config_file.py).
- Added `OllamaLLM` and `OllamaEmbeddings` classes to make Ollama support more explicit. Implementations using the `OpenAILLM` and `OpenAIEmbeddings` classes will still work.

## Changed
- Updated LLM prompts to include stricter instructions for generating valid JSON.
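As a rough sketch of what this changelog entry means in practice (both snippets are assembled from the documentation and example changes later in this commit; the model name comes from the docs), the new explicit class and the still-supported OpenAI-compatible route look like:

    from neo4j_graphrag.llm import OllamaLLM, OpenAILLM

    # New in this commit: explicit Ollama support
    llm = OllamaLLM(model_name="orca-mini")
    print(llm.invoke("say something").content)

    # Pre-existing route that keeps working: OpenAILLM pointed at Ollama's
    # OpenAI-compatible endpoint (the api_key is unused but must be provided)
    llm = OpenAILLM(
        api_key="ollama",
        base_url="http://127.0.0.1:11434/v1",
        model_name="orca-mini",
    )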
1 change: 1 addition & 0 deletions README.md
@@ -31,6 +31,7 @@ This package has some optional features that can be enabled using
the extra dependencies described below:

- LLM providers (at least one is required for RAG and KG Builder Pipeline):
- **ollama**: LLMs from Ollama
- **openai**: LLMs from OpenAI (including AzureOpenAI)
- **google**: LLMs from Vertex AI
- **cohere**: LLMs from Cohere
12 changes: 12 additions & 0 deletions docs/source/api.rst
@@ -239,6 +239,12 @@ AzureOpenAIEmbeddings
.. autoclass:: neo4j_graphrag.embeddings.openai.AzureOpenAIEmbeddings
:members:

OllamaEmbeddings
================

.. autoclass:: neo4j_graphrag.embeddings.ollama.OllamaEmbeddings
:members:

VertexAIEmbeddings
==================

@@ -286,6 +292,12 @@ AzureOpenAILLM
:members:
:undoc-members: get_messages, client_class, async_client_class

OllamaLLM
---------

.. autoclass:: neo4j_graphrag.llm.ollama_llm.OllamaLLM
:members:


VertexAILLM
-----------
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -88,6 +88,7 @@ Extra dependencies can be installed with:
List of extra dependencies:

- LLM providers (at least one is required for RAG and KG Builder Pipeline):
- **ollama**: LLMs from Ollama
- **openai**: LLMs from OpenAI (including AzureOpenAI)
- **google**: LLMs from Vertex AI
- **cohere**: LLMs from Cohere
34 changes: 5 additions & 29 deletions docs/source/user_guide_rag.rst
@@ -218,14 +218,13 @@ See :ref:`coherellm`.
Using a Local Model via Ollama
-------------------------------

-Similarly to the official OpenAI Python client, the `OpenAILLM` can be
-used with Ollama. Assuming Ollama is running on the default address `127.0.0.1:11434`,
+Assuming Ollama is running on the default address `127.0.0.1:11434`,
 it can be queried using the following:

 .. code:: python

-    from neo4j_graphrag.llm import OpenAILLM
-    llm = OpenAILLM(api_key="ollama", base_url="http://127.0.0.1:11434/v1", model_name="orca-mini")
+    from neo4j_graphrag.llm import OllamaLLM
+    llm = OllamaLLM(model_name="orca-mini")
     llm.invoke("say something")
@@ -428,6 +427,7 @@ Currently, this package supports the following embedders:
- :ref:`mistralaiembeddings`
- :ref:`cohereembeddings`
- :ref:`azureopenaiembeddings`
- :ref:`ollamaembeddings`

The `OpenAIEmbeddings` was illustrated previously. Here is how to use the `SentenceTransformerEmbeddings`:

@@ -438,31 +438,7 @@ The `OpenAIEmbeddings` was illustrated previously. Here is how to use the `Sente
     embedder = SentenceTransformerEmbeddings(model="all-MiniLM-L6-v2") # Note: this is the default model

-If another embedder is desired, a custom embedder can be created. For example, consider
-the following implementation of an embedder that wraps the `OllamaEmbedding` model from LlamaIndex:
-
-.. code:: python
-
-    from llama_index.embeddings.ollama import OllamaEmbedding
-    from neo4j_graphrag.embeddings.base import Embedder
-
-    class OllamaEmbedder(Embedder):
-        def __init__(self, ollama_embedding):
-            self.embedder = ollama_embedding
-
-        def embed_query(self, text: str) -> list[float]:
-            embedding = self.embedder.get_text_embedding_batch(
-                [text], show_progress=True
-            )
-            return embedding[0]
-
-    ollama_embedding = OllamaEmbedding(
-        model_name="llama3",
-        base_url="http://localhost:11434",
-        ollama_additional_kwargs={"mirostat": 0},
-    )
-    embedder = OllamaEmbedder(ollama_embedding)
-    vector = embedder.embed_query("some text")
+If another embedder is desired, a custom embedder can be created, using the `Embedder` interface.


Other Vector Retriever Configuration
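The guide now points readers at the `Embedder` interface instead of shipping a LlamaIndex wrapper. A minimal sketch of such a custom embedder, assuming only the `embed_query(text: str) -> list[float]` contract visible in this commit, might look like:

    from neo4j_graphrag.embeddings.base import Embedder

    class MyEmbedder(Embedder):
        """Hypothetical wrapper around any model object exposing an encode() method."""

        def __init__(self, model):
            # `model` and its `encode` method are illustrative, not a real API
            self.model = model

        def embed_query(self, text: str) -> list[float]:
            # delegate to the wrapped model and return a plain list of floats
            return list(self.model.encode(text))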
11 changes: 3 additions & 8 deletions examples/customize/embeddings/ollama_embeddings.py
@@ -1,15 +1,10 @@
"""This example demonstrate how to embed a text into a vector
using OpenAI models and API.
using a local model served by Ollama.
"""

from neo4j_graphrag.embeddings import OpenAIEmbeddings
from neo4j_graphrag.embeddings import OllamaEmbeddings

# not used but needs to be provided
api_key = "ollama"

embeder = OpenAIEmbeddings(
base_url="http://localhost:11434/v1",
api_key=api_key,
embeder = OllamaEmbeddings(
model="<model_name>",
)
res = embeder.embed_query("my question")
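Because `OllamaEmbeddings.__init__` (added later in this commit) forwards extra keyword arguments to `ollama.Client`, a non-default server address should presumably be configurable like this (`host` is an `ollama.Client` parameter; the URL is illustrative):

    from neo4j_graphrag.embeddings import OllamaEmbeddings

    embedder = OllamaEmbeddings(
        model="<model_name>",
        host="http://localhost:11434",  # passed through to ollama.Client
    )
    vector = embedder.embed_query("my question")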
11 changes: 5 additions & 6 deletions examples/customize/llms/ollama_llm.py
@@ -1,12 +1,11 @@
-from neo4j_graphrag.llm import LLMResponse, OpenAILLM
+"""This example demonstrate how to invoke an LLM using a local model
+served by Ollama.
+"""

-# not used but needs to be provided
-api_key = "ollama"
+from neo4j_graphrag.llm import LLMResponse, OllamaLLM

-llm = OpenAILLM(
-    base_url="http://localhost:11434/v1",
+llm = OllamaLLM(
     model_name="<model_name>",
-    api_key=api_key,
 )
 res: LLMResponse = llm.invoke("What is the additive color model?")
 print(res.content)
55 changes: 36 additions & 19 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -50,6 +50,7 @@ anthropic = { version = "^0.36.0", optional = true}
sentence-transformers = {version = "^3.0.0", optional = true }
json-repair = "^0.30.2"
types-pyyaml = "^6.0.12.20240917"
ollama = {version = "^0.4.4", optional = true}

[tool.poetry.group.dev.dependencies]
urllib3 = "<2"
@@ -69,6 +70,7 @@ pinecone = ["pinecone-client"]
google = ["google-cloud-aiplatform"]
cohere = ["cohere"]
anthropic = ["anthropic"]
ollama = ["ollama"]
openai = ["openai"]
mistralai = ["mistralai"]
qdrant = ["qdrant-client"]
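With the extra registered above, Ollama support can presumably be installed like the other optional providers, e.g. `pip install "neo4j-graphrag[ollama]"` (distribution name assumed from the repository), or `poetry install --extras ollama` when working in the repo itself.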
2 changes: 2 additions & 0 deletions src/neo4j_graphrag/embeddings/__init__.py
@@ -15,13 +15,15 @@
from .base import Embedder
from .cohere import CohereEmbeddings
from .mistral import MistralAIEmbeddings
from .ollama import OllamaEmbeddings
from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from .sentence_transformers import SentenceTransformerEmbeddings
from .vertexai import VertexAIEmbeddings

__all__ = [
"Embedder",
"SentenceTransformerEmbeddings",
"OllamaEmbeddings",
"OpenAIEmbeddings",
"AzureOpenAIEmbeddings",
"VertexAIEmbeddings",
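After this change the class is importable from the package root as well as from its module; names taken verbatim from the diff above:

    from neo4j_graphrag.embeddings import OllamaEmbeddings  # re-exported via __all__
    # equivalently:
    # from neo4j_graphrag.embeddings.ollama import OllamaEmbeddings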
3 changes: 1 addition & 2 deletions src/neo4j_graphrag/embeddings/mistral.py
@@ -57,8 +57,7 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]:
             **kwargs (Any): Additional keyword arguments to pass to the Mistral AI client.
         """
         embeddings_batch_response = self.mistral_client.embeddings.create(
-            model=self.model,
-            inputs=[text],
+            model=self.model, inputs=[text], **kwargs
         )
         if embeddings_batch_response is None or not embeddings_batch_response.data:
             raise EmbeddingsGenerationError("Failed to retrieve embeddings.")
65 changes: 65 additions & 0 deletions src/neo4j_graphrag/embeddings/ollama.py
@@ -0,0 +1,65 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any

from neo4j_graphrag.embeddings.base import Embedder
from neo4j_graphrag.exceptions import EmbeddingsGenerationError


class OllamaEmbeddings(Embedder):
    """
    Ollama embeddings class.
    This class uses the ollama Python client to generate vector embeddings for text data.

    Args:
        model (str): The name of the Ollama text embedding model to use.
    """

    def __init__(self, model: str, **kwargs: Any) -> None:
        try:
            import ollama
        except ImportError:
            raise ImportError(
                "Could not import ollama python client. "
                "Please install it with `pip install ollama`."
            )
        self.model = model
        self.client = ollama.Client(**kwargs)

    def embed_query(self, text: str, **kwargs: Any) -> list[float]:
        """
        Generate embeddings for a given query using an Ollama text embedding model.

        Args:
            text (str): The text to generate an embedding for.
            **kwargs (Any): Additional keyword arguments to pass to the Ollama client.
        """
        embeddings_response = self.client.embed(
            model=self.model,
            input=text,
            **kwargs,
        )

        if embeddings_response is None or embeddings_response.embeddings is None:
            raise EmbeddingsGenerationError("Failed to retrieve embeddings.")

        embedding = embeddings_response.embeddings
        if not isinstance(embedding, list):
            raise EmbeddingsGenerationError("Embedding is not a list of floats.")

        return embedding
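A short usage sketch for the new class, assuming a local Ollama server and using only names visible in this commit (the model name is a placeholder):

    from neo4j_graphrag.embeddings import OllamaEmbeddings
    from neo4j_graphrag.exceptions import EmbeddingsGenerationError

    embedder = OllamaEmbeddings(model="<model_name>")
    try:
        vector = embedder.embed_query("my question")
    except EmbeddingsGenerationError:
        # raised above when the response or its embeddings are missing or malformed
        raise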
