-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update Github Actions and resolve merge conflicts
- Loading branch information
Showing
12 changed files
with
1,208 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
root = true | ||
|
||
[*] | ||
indent_style = space | ||
indent_size = 4 | ||
insert_final_newline = true | ||
trim_trailing_whitespace = true | ||
end_of_line = lf | ||
charset = utf-8 | ||
max_line_length = 88 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
name: neo4j_genai PR
on: pull_request

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - name: Check out repository code
        uses: actions/checkout@v4
      # The venv cache key below must include the interpreter version so a
      # cached .venv is never reused across Python upgrades on the runner.
      # The original key referenced `steps.setup-python`, a step id that did
      # not exist, so that part of the key was always empty.
      # NOTE(review): assumes Poetry builds the venv from the runner's
      # `python3`; confirm, or pin a version with actions/setup-python.
      - name: Detect Python version
        id: detect-python
        run: echo "python-version=$(python3 -c 'import platform; print(platform.python_version())')" >> "$GITHUB_OUTPUT"
      - name: Install Poetry
        uses: snok/install-poetry@v1
        with:
          virtualenvs-create: true
          virtualenvs-in-project: true
          installer-parallel: true
      - name: Load cached venv
        id: cached-poetry-dependencies
        uses: actions/cache@v4
        with:
          path: .venv
          key: venv-${{ runner.os }}-${{ steps.detect-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}
      # Dependencies are only reinstalled on a cache miss; the root project
      # is always (re)installed because it is not part of the cached venv key.
      - name: Install dependencies
        if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
        run: poetry install --no-interaction --no-root
      - name: Install root project
        run: poetry install --no-interaction
      - name: Check format and linting
        run: |
          poetry run ruff format --check .
          poetry run ruff check .
      - name: Run tests and check coverage
        run: |
          poetry run coverage run -m pytest
          poetry run coverage report --fail-under=85
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,9 @@ | ||
dist/ | ||
**/__pycache__/* | ||
*.py[cod] | ||
.mypy_cache/ | ||
.coverage | ||
htmlcov/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
repos: | ||
- repo: https://github.com/pre-commit/pre-commit-hooks | ||
rev: v4.5.0 | ||
hooks: | ||
- id: trailing-whitespace | ||
- id: end-of-file-fixer |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .client import GenAIClient | ||
|
||
|
||
__all__ = ["GenAIClient"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
from typing import List, Dict, Any, Optional | ||
from pydantic import ValidationError | ||
from neo4j import Driver | ||
from neo4j.exceptions import CypherSyntaxError | ||
from .embeddings import Embeddings | ||
from .types import CreateIndexModel, SimilaritySearchModel, Neo4jRecord | ||
|
||
|
||
class GenAIClient:
    """
    Provides functionality to use Neo4j's GenAI features.

    Wraps a Neo4j driver to create and drop vector indexes and to run
    similarity searches against them.

    Args:
        driver (Driver): Neo4j driver used to open sessions for every query.
        embeddings (Optional[Embeddings], optional): Embedder used to turn
            query text into a vector. Only needed for text-based similarity
            search. Defaults to None.

    Raises:
        ValueError: If the connected Neo4j version is older than 5.11.
    """

    def __init__(
        self,
        driver: Driver,
        embeddings: Optional[Embeddings] = None,
    ) -> None:
        self.driver = driver
        # Fail fast at construction time if the server cannot do vector search.
        self._verify_version()
        self.embeddings = embeddings

    def _verify_version(self) -> None:
        """
        Check if the connected Neo4j database version supports vector indexing.

        Queries the Neo4j database to retrieve its version and compares it
        against a target version (5.11.0) that is known to support vector
        indexing.

        Raises:
            ValueError: If the connected Neo4j version is not supported.
        """
        components = self.database_query("CALL dbms.components()")
        version = components[0]["versions"][0]

        if "aura" in version:
            # Aura version strings carry a "-..." suffix after the numeric
            # part; strip it and pad with a trailing 0 component so the tuple
            # compares cleanly against the three-part target.
            numeric_part = version.split("-")[0]
            version_tuple = (*map(int, numeric_part.split(".")), 0)
        else:
            version_tuple = tuple(map(int, version.split(".")))

        target_version = (5, 11, 0)

        if version_tuple < target_version:
            # NOTE(review): "Version index" presumably means "Vector index";
            # the text is kept unchanged because callers/tests may match on it.
            raise ValueError(
                "Version index is only supported in Neo4j version 5.11 or greater"
            )

    def database_query(
        self, query: str, params: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        This method sends a Cypher query to the connected Neo4j database
        and returns the results as a list of dictionaries.

        Args:
            query (str): The Cypher query to execute.
            params (Optional[Dict[str, Any]], optional): Dictionary of query
                parameters. Defaults to None, which is sent as an empty dict.

        Returns:
            List[Dict[str, Any]]: List of dictionaries containing the query results.

        Raises:
            ValueError: If the database rejects the query as invalid Cypher.
        """
        # A fresh dict per call avoids sharing one mutable default between calls.
        if params is None:
            params = {}

        with self.driver.session() as session:
            try:
                result = session.run(query, params)
                return [record.data() for record in result]
            except CypherSyntaxError as e:
                raise ValueError(f"Cypher Statement is not valid\n{e}") from e

    def create_index(
        self,
        name: str,
        label: str,
        property: str,
        dimensions: int,
        similarity_fn: str,
    ) -> None:
        """
        This method constructs a Cypher query and executes it
        to create a new vector index in Neo4j.

        See Cypher manual on [Create node index](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_createNodeIndex)

        Args:
            name (str): The unique name of the index.
            label (str): The node label to be indexed.
            property (str): The property key of a node which contains embedding values.
            dimensions (int): Vector embedding dimension.
            similarity_fn (str): The vector similarity function, either
                ``euclidean`` or ``cosine`` (validated by ``CreateIndexModel``).

        Raises:
            ValueError: If validation of the input arguments fail.
        """
        try:
            validated = CreateIndexModel(
                name=name,
                label=label,
                property=property,
                dimensions=dimensions,
                similarity_fn=similarity_fn,
            )
        except ValidationError as e:
            raise ValueError(f"Error for inputs to create_index {str(e)}") from e

        query = (
            "CALL db.index.vector.createNodeIndex("
            "$name,"
            "$label,"
            "$property,"
            "toInteger($dimensions),"
            "$similarity_fn )"
        )
        self.database_query(query, params=validated.model_dump())

    def drop_index(self, name: str) -> None:
        """
        This method constructs a Cypher query and executes it
        to drop a vector index in Neo4j.

        See Cypher manual on [Drop vector indexes](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-drop)

        Args:
            name (str): The name of the index to delete.
        """
        query = "DROP INDEX $name"
        parameters = {
            "name": name,
        }
        self.database_query(query, params=parameters)

    def similarity_search(
        self,
        name: str,
        query_vector: Optional[List[float]] = None,
        query_text: Optional[str] = None,
        top_k: int = 5,
    ) -> List[Neo4jRecord]:
        """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.

        See the following documentation for more details:

        - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query)
        - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes)

        Args:
            name (str): Refers to the unique name of the vector index to query.
            query_vector (Optional[List[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
            query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None.
            top_k (int, optional): The number of neighbors to return. Defaults to 5.

        Raises:
            ValueError: If validation of the input arguments fail.
            ValueError: If no embeddings is provided.

        Returns:
            List[Neo4jRecord]: The `top_k` neighbors found in vector search with their nodes and scores.
        """
        try:
            validated_data = SimilaritySearchModel(
                index_name=name,
                top_k=top_k,
                query_vector=query_vector,
                query_text=query_text,
            )
        except ValidationError as e:
            error_details = e.errors()
            raise ValueError(f"Validation failed: {error_details}") from e

        # exclude_none drops whichever of query_vector / query_text was not given.
        parameters = validated_data.model_dump(exclude_none=True)

        if query_text:
            if not self.embeddings:
                raise ValueError("Embedding method required for text query.")
            # Text queries are embedded client-side, then searched as a vector.
            query_vector = self.embeddings.embed_query(query_text)
            parameters["query_vector"] = query_vector

        db_query_string = """
        CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
        YIELD node, score
        """
        records = self.database_query(db_query_string, params=parameters)

        try:
            return [
                Neo4jRecord(node=record["node"], score=record["score"])
                for record in records
            ]
        except ValidationError as e:
            error_details = e.errors()
            raise ValueError(
                f"Validation failed while constructing output: {error_details}"
            ) from e
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from abc import ABC, abstractmethod | ||
from .types import EmbeddingVector | ||
|
||
|
||
class Embeddings(ABC):
    """Abstract base class that every embedding model must implement."""

    @abstractmethod
    def embed_query(self, text: str) -> EmbeddingVector:
        """Turn a piece of query text into its vector embedding.

        Concrete subclasses must override this method; the base class
        cannot be instantiated without it.

        Args:
            text (str): The query text to embed.

        Returns:
            EmbeddingVector: The embedding computed for ``text``.
        """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from typing import List, Any, Literal, Optional | ||
from pydantic import BaseModel, PositiveInt, Field, model_validator | ||
|
||
|
||
class Neo4jRecord(BaseModel):
    """One result row of a vector search: a node together with its score."""

    # The matched node as returned by the database; left untyped (Any).
    node: Any
    # Score reported by the search for this node.
    score: float
|
||
|
||
class EmbeddingVector(BaseModel):
    """Validated container for an embedding expressed as a list of floats."""

    # The embedding components.
    vector: List[float]
|
||
|
||
class CreateIndexModel(BaseModel):
    """Validated set of arguments describing a vector index to create."""

    # Name of the index.
    name: str
    # Node label the index applies to.
    label: str
    # Node property key holding the embedding values.
    property: str
    # Embedding dimension; must lie in the inclusive range 1..2048.
    dimensions: int = Field(ge=1, le=2048)
    # Similarity function; only these two exact (lowercase) values validate.
    similarity_fn: Literal["euclidean", "cosine"]
|
||
|
||
class SimilaritySearchModel(BaseModel):
    """Validated arguments for a vector similarity search.

    Exactly one of ``query_vector`` or ``query_text`` must be supplied.
    """

    # Name of the vector index to query.
    index_name: str
    # Number of neighbours to return; must be a positive integer.
    top_k: PositiveInt = 5
    # Vector to search with (mutually exclusive with query_text).
    query_vector: Optional[List[float]] = None
    # Text to search with (mutually exclusive with query_vector).
    query_text: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def check_query(cls, values: Any) -> Any:
        """
        Validates that one of either query_vector or query_text is provided exclusively.

        Runs before field validation, on the raw input mapping.

        Args:
            values: The raw input data passed to the model.

        Returns:
            The unchanged input data, when the check passes.

        Raises:
            ValueError: If both or neither of query_vector and query_text are
                provided (empty values count as not provided).
        """
        query_vector, query_text = values.get("query_vector"), values.get("query_text")
        # XOR on truthiness: exactly one of the two must be set and non-empty.
        if not (bool(query_vector) ^ bool(query_text)):
            raise ValueError(
                "You must provide exactly one of query_vector or query_text."
            )
        return values
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import pytest | ||
from neo4j_genai import GenAIClient | ||
from unittest.mock import Mock, patch | ||
|
||
|
||
@pytest.fixture
def driver():
    """Provide a stand-in Neo4j driver backed by ``unittest.mock.Mock``."""
    fake_driver = Mock()
    return fake_driver
|
||
|
||
@pytest.fixture
def client(driver):
    """Build a ``GenAIClient`` on the mock driver with the version check patched out."""
    # The patch only needs to be active while the constructor runs, because
    # that is where _verify_version is invoked.
    with patch("neo4j_genai.GenAIClient._verify_version"):
        return GenAIClient(driver)
Oops, something went wrong.