Skip to content

Commit

Permalink
Update Github Actions and resolve merge conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
jonbesga authored and willtai committed Mar 11, 2024
1 parent 1c20c75 commit 0c921fe
Show file tree
Hide file tree
Showing 12 changed files with 1,208 additions and 2 deletions.
10 changes: 10 additions & 0 deletions .editorconfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
root = true

[*]
indent_style = space
indent_size = 4
insert_final_newline = true
trim_trailing_whitespace = true
end_of_line = lf
charset = utf-8
max_line_length = 88
34 changes: 34 additions & 0 deletions .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Pull-request pipeline: install the project with Poetry, then run the
# formatter/linter checks and the test suite with a coverage gate.
name: neo4j_genai PR
on: pull_request

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - name: Check out repository code
        uses: actions/checkout@v4
      # The venv cache key below reads steps.setup-python.outputs.python-version,
      # so a step with this exact id must exist; without it the expression
      # evaluates to an empty string and the key ignores the interpreter version.
      - name: Set up Python
        id: setup-python
        uses: actions/setup-python@v5
        with:
          # NOTE(review): 3.9 is the lowest version for which every dev
          # dependency (including pre-commit) is installed; adjust if needed.
          python-version: "3.9"
      - name: Install Poetry
        uses: snok/install-poetry@v1
        with:
          virtualenvs-create: true
          virtualenvs-in-project: true
          installer-parallel: true
      - name: Load cached venv
        id: cached-poetry-dependencies
        uses: actions/cache@v4
        with:
          path: .venv
          key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}
      # Dependencies are only reinstalled when the cached .venv was not restored.
      - name: Install dependencies
        if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
        run: poetry install --no-interaction --no-root
      - name: Install root project
        run: poetry install --no-interaction
      - name: Check format and linting
        run: |
          poetry run ruff format --check .
          poetry run ruff check .
      - name: Run tests and check coverage
        run: |
          poetry run coverage run -m pytest
          poetry run coverage report --fail-under=85
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,9 @@
dist/
**/__pycache__/*
*.py[cod]
.mypy_cache/
.coverage
htmlcov/
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
731 changes: 730 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

32 changes: 31 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,40 @@ readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.7"
neo4j = "^5.17.0"
types-requests = "^2.31.0.20240218"
pydantic = "^2.6.3"

[tool.poetry.group.dev.dependencies]
pylint = "^3.1.0"
mypy = "^1.8.0"
pytest = "^8.0.2"
pytest-mock = "^3.12.0"
pre-commit = { version = "^3.6.2", python = "^3.9" }
coverage = "^7.4.3"
ruff = "^0.3.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

exclude = ["**/tests/"]

[tool.pytest.ini_options]
testpaths = ["tests"]
filterwarnings = [
"",
]

[tool.coverage.paths]
source = ["src"]

[tool.pylint."MESSAGES CONTROL"]
disable="C0114,C0115"
4 changes: 4 additions & 0 deletions src/neo4j_genai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .client import GenAIClient


__all__ = ["GenAIClient"]
184 changes: 184 additions & 0 deletions src/neo4j_genai/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
from typing import List, Dict, Any, Optional
from pydantic import ValidationError
from neo4j import Driver
from neo4j.exceptions import CypherSyntaxError
from .embeddings import Embeddings
from .types import CreateIndexModel, SimilaritySearchModel, Neo4jRecord


class GenAIClient:
    """
    Provides functionality to use Neo4j's GenAI features

    Wraps a Neo4j driver to create, drop and query vector indexes. An
    embeddings implementation is optional and only needed when
    ``similarity_search`` is called with ``query_text``.
    """

    def __init__(
        self,
        driver: Driver,
        embeddings: Optional[Embeddings] = None,
    ) -> None:
        """
        Args:
            driver (Driver): Neo4j driver used for every database call.
            embeddings (Optional[Embeddings], optional): Embedding model used to
                turn query text into a vector. Defaults to None.
        Raises:
            ValueError: If the connected Neo4j version is older than 5.11.
        """
        self.driver = driver
        # Fail at construction time rather than on the first index call.
        self._verify_version()
        self.embeddings = embeddings

    @staticmethod
    def _parse_version(version: str) -> tuple:
        """
        Convert a Neo4j version string into a comparable tuple of integers.
        Anything after the first ``-`` (e.g. ``-aura``) is ignored and the
        result is zero-padded to at least three components, so ``"5.18-aura"``
        becomes ``(5, 18, 0)`` and ``"5.11.0"`` becomes ``(5, 11, 0)``.
        Args:
            version (str): Version string as reported by the database.
        Returns:
            tuple: The numeric version components.
        Raises:
            ValueError: If a component before the suffix is not an integer.
        """
        numeric_part = version.split("-", 1)[0]
        components = [int(piece) for piece in numeric_part.split(".")]
        # Pad so that a two-component version such as "5.11" does not compare
        # as smaller than (5, 11, 0).
        components.extend([0] * (3 - len(components)))
        return tuple(components)

    def _verify_version(self) -> None:
        """
        Check if the connected Neo4j database version supports vector indexing.
        Queries the Neo4j database to retrieve its version and compares it
        against a target version (5.11.0) that is known to support vector
        indexing. Raises a ValueError if the connected Neo4j version is
        not supported.
        """
        components = self.database_query("CALL dbms.components()")
        version = components[0]["versions"][0]
        target_version = (5, 11, 0)

        if self._parse_version(version) < target_version:
            raise ValueError(
                "Version index is only supported in Neo4j version 5.11 or greater"
            )

    def database_query(
        self, query: str, params: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        This method sends a Cypher query to the connected Neo4j database
        and returns the results as a list of dictionaries.
        Args:
            query (str): The Cypher query to execute.
            params (Optional[Dict[str, Any]], optional): Dictionary of query
                parameters. Defaults to None, which is sent as an empty dictionary.
        Returns:
            List[Dict[str, Any]]: List of dictionaries containing the query results.
        Raises:
            ValueError: If the database rejects the query as invalid Cypher.
        """
        # A None default avoids sharing one mutable dict between calls.
        query_params = params or {}
        with self.driver.session() as session:
            try:
                data = session.run(query, query_params)
                return [r.data() for r in data]
            except CypherSyntaxError as e:
                raise ValueError(f"Cypher Statement is not valid\n{e}") from e

    def create_index(
        self,
        name: str,
        label: str,
        property: str,
        dimensions: int,
        similarity_fn: str,
    ) -> None:
        """
        This method constructs a Cypher query and executes it
        to create a new vector index in Neo4j.
        See Cypher manual on [Create node index](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_createNodeIndex)
        Args:
            name (str): The unique name of the index.
            label (str): The node label to be indexed.
            property (str): The property key of a node which contains embedding values.
            dimensions (int): Vector embedding dimension
            similarity_fn (str): The vector similarity function, ``euclidean`` or
                ``cosine``. Input validation only accepts these exact
                lowercase spellings.
        Raises:
            ValueError: If validation of the input arguments fail.
        """
        try:
            index_model = CreateIndexModel(
                name=name,
                label=label,
                property=property,
                dimensions=dimensions,
                similarity_fn=similarity_fn,
            )
        except ValidationError as e:
            raise ValueError(f"Error for inputs to create_index {str(e)}") from e

        query = (
            "CALL db.index.vector.createNodeIndex("
            "$name,"
            "$label,"
            "$property,"
            "toInteger($dimensions),"
            "$similarity_fn )"
        )
        self.database_query(query, params=index_model.model_dump())

    def drop_index(self, name: str) -> None:
        """
        This method constructs a Cypher query and executes it
        to drop a vector index in Neo4j.
        See Cypher manual on [Drop vector indexes](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-drop)
        Args:
            name (str): The name of the index to delete.
        """
        query = "DROP INDEX $name"
        parameters = {
            "name": name,
        }
        self.database_query(query, params=parameters)

    def similarity_search(
        self,
        name: str,
        query_vector: Optional[List[float]] = None,
        query_text: Optional[str] = None,
        top_k: int = 5,
    ) -> List[Neo4jRecord]:
        """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
        See the following documentation for more details:
        - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query)
        - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes)
        Args:
            name (str): Refers to the unique name of the vector index to query.
            query_vector (Optional[List[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
            query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None.
            top_k (int, optional): The number of neighbors to return. Defaults to 5.
        Raises:
            ValueError: If validation of the input arguments fail.
            ValueError: If no embeddings is provided.
            ValueError: If a returned record cannot be converted to a Neo4jRecord.
        Returns:
            List[Neo4jRecord]: The `top_k` neighbors found in vector search with their nodes and scores.
        """
        try:
            validated_data = SimilaritySearchModel(
                index_name=name,
                top_k=top_k,
                query_vector=query_vector,
                query_text=query_text,
            )
        except ValidationError as e:
            error_details = e.errors()
            raise ValueError(f"Validation failed: {error_details}") from e

        parameters = validated_data.model_dump(exclude_none=True)

        if query_text:
            if not self.embeddings:
                raise ValueError("Embedding method required for text query.")
            from .types import EmbeddingVector

            embedded_query = self.embeddings.embed_query(query_text)
            # embed_query is annotated to return an EmbeddingVector model, but
            # the query parameter must be the plain list of floats it wraps.
            if isinstance(embedded_query, EmbeddingVector):
                embedded_query = embedded_query.vector
            parameters["query_vector"] = embedded_query

        db_query_string = """
        CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
        YIELD node, score
        """
        records = self.database_query(db_query_string, params=parameters)

        try:
            return [
                Neo4jRecord(node=record["node"], score=record["score"])
                for record in records
            ]
        except ValidationError as e:
            error_details = e.errors()
            raise ValueError(
                f"Validation failed while constructing output: {error_details}"
            ) from e
17 changes: 17 additions & 0 deletions src/neo4j_genai/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from abc import ABC, abstractmethod
from .types import EmbeddingVector


class Embeddings(ABC):
    """Interface for embedding models.

    Subclasses must implement ``embed_query``; because that method is
    abstract, this class cannot be instantiated directly.
    """

    @abstractmethod
    def embed_query(self, text: str) -> EmbeddingVector:
        """Embed query text.

        Args:
            text (str): Text to convert to vector embedding

        Returns:
            EmbeddingVector: A vector embedding.
        """
38 changes: 38 additions & 0 deletions src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import List, Any, Literal, Optional
from pydantic import BaseModel, PositiveInt, Field, model_validator


class Neo4jRecord(BaseModel):
    """A single result row pairing a node with its numeric score."""

    # The node value; typed as Any, so no structure is validated.
    node: Any
    # The score associated with the node.
    score: float


class EmbeddingVector(BaseModel):
    """A vector embedding held as a list of floats."""

    # The embedding values.
    vector: List[float]


class CreateIndexModel(BaseModel):
    """Validated arguments for creating a vector index."""

    # Name of the index.
    name: str
    # Node label to index.
    label: str
    # Property key holding the embedding values.
    property: str
    # Embedding dimension; must be between 1 and 2048 inclusive.
    dimensions: int = Field(ge=1, le=2048)
    # Similarity function; only these two exact (lowercase) values validate.
    similarity_fn: Literal["euclidean", "cosine"]


class SimilaritySearchModel(BaseModel):
    """Validated arguments for a vector similarity search.

    Exactly one of ``query_vector`` or ``query_text`` must be supplied.
    """

    # Name of the vector index to query.
    index_name: str
    # Number of neighbors to return; must be a positive integer.
    top_k: PositiveInt = 5
    # Query embedding, mutually exclusive with query_text.
    query_vector: Optional[List[float]] = None
    # Query text, mutually exclusive with query_vector.
    query_text: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def check_query(cls, values: Any) -> Any:
        """
        Validates that one of either query_vector or query_text is provided exclusively.

        Args:
            values (Any): Raw input handed to the model before field validation.
        Returns:
            Any: The input, unchanged.
        Raises:
            ValueError: If both or neither of the two query fields are provided.
        """
        # A "before" validator receives the raw input, which is not guaranteed
        # to be a dict; anything else is left to the regular field validation.
        if not isinstance(values, dict):
            return values

        has_vector = bool(values.get("query_vector"))
        has_text = bool(values.get("query_text"))
        # Truthiness is used on purpose: an empty list or empty string counts
        # as "not provided".
        if has_vector == has_text:
            raise ValueError(
                "You must provide exactly one of query_vector or query_text."
            )
        return values
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pytest
from neo4j_genai import GenAIClient
from unittest.mock import Mock, patch


@pytest.fixture
def driver():
    """Provide a Mock standing in for a Neo4j driver."""
    mock_driver = Mock()
    return mock_driver


@pytest.fixture
@patch("neo4j_genai.GenAIClient._verify_version")
def client(_verify_version_mock, driver):
    """Provide a GenAIClient built on the mock driver.

    The version check is patched out while the client is constructed, so no
    database query is needed to build it.
    """
    genai_client = GenAIClient(driver)
    return genai_client
Loading

0 comments on commit 0c921fe

Please sign in to comment.