Skip to content

Commit

Permalink
Update Github Actions
Browse files Browse the repository at this point in the history
  • Loading branch information
jonbesga committed Mar 4, 2024
1 parent 9be9e5e commit a009b40
Show file tree
Hide file tree
Showing 12 changed files with 500 additions and 100 deletions.
10 changes: 10 additions & 0 deletions .editorconfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
root = true

[*]
indent_style = space
indent_size = 4
insert_final_newline = true
trim_trailing_whitespace = true
end_of_line = lf
charset = utf-8
max_line_length = 88
34 changes: 34 additions & 0 deletions .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
name: neo4j_genai PR
on: pull_request

jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Install Poetry
uses: snok/install-poetry@v1
with:
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
- name: Load cached venv
id: cached-poetry-dependencies
uses: actions/cache@v4
with:
path: .venv
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}
- name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --no-interaction --no-root
- name: Install root project
run: poetry install --no-interaction
- name: Check format and linting
run: |
poetry run ruff format --check .
poetry run ruff check .
- name: Run tests and check coverage
run: |
poetry run coverage run -m pytest
poetry run coverage report --fail-under=85
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
dist/
**/__pycache__/*
*.py[cod]
.mypy_cache/
.mypy_cache/
.coverage
htmlcov/
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0 # Use the ref you want to point at
rev: v4.5.0
hooks:
- id: trailing-whitespace
# - id: ...
- id: end-of-file-fixer
345 changes: 272 additions & 73 deletions poetry.lock

Large diffs are not rendered by default.

32 changes: 17 additions & 15 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,31 @@ from = "src"
python = "^3.8"
neo4j = "^5.17.0"
types-requests = "^2.31.0.20240218"
pytest = "^8.0.2"
pytest-mock = "^3.12.0"
pydantic = "^2.6.3"

[tool.poetry.group.dev.dependencies]
pylint = "^3.1.0"
mypy = "^1.8.0"
black = "^24.2.0"
pytest = "^8.0.2"
pytest-mock = "^3.12.0"
pre-commit = { version = "^3.6.2", python = "^3.9" }
coverage = "^7.4.3"
ruff = "^0.3.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

exclude = ["**/tests/"]
[tool.black]
line-length = 88
target-version = ['py38']
include = '\.pyi?$'
exclude = '''
/(
.git
| .venv
| build
| dist
)/
'''

[tool.pytest.ini_options]
testpaths = ["tests"]
filterwarnings = [
"",
]

[tool.coverage.paths]
source = ["src"]

[tool.pylint."MESSAGES CONTROL"]
disable="C0114,C0115"
4 changes: 4 additions & 0 deletions src/neo4j_genai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .client import GenAIClient


__all__ = ["GenAIClient"]
13 changes: 8 additions & 5 deletions src/neo4j_genai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from pydantic import ValidationError
from neo4j import Driver
from neo4j.exceptions import CypherSyntaxError
from neo4j_genai.embeddings import Embeddings
from neo4j_genai.types import CreateIndexModel, SimilaritySearchModel, Neo4jRecord
from .embeddings import Embeddings
from .types import CreateIndexModel, SimilaritySearchModel, Neo4jRecord


class GenAIClient:
Expand Down Expand Up @@ -31,7 +31,10 @@ def _verify_version(self) -> None:
"""
version = self.database_query("CALL dbms.components()")[0]["versions"][0]
if "aura" in version:
version_tuple = (*tuple(map(int, version.split("-")[0].split("."))), 0)
version_tuple = (
*tuple(map(int, version.split("-")[0].split("."))),
0,
)
else:
version_tuple = tuple(map(int, version.split(".")))

Expand Down Expand Up @@ -106,7 +109,7 @@ def create_index(
"toInteger($dimensions),"
"$similarity_fn )"
)
self.database_query(query, params=index_data.dict())
self.database_query(query, params=index_data.model_dump())

def drop_index(self, name: str) -> None:
"""
Expand Down Expand Up @@ -160,7 +163,7 @@ def similarity_search(
error_details = e.errors()
raise ValueError(f"Validation failed: {error_details}")

parameters = validated_data.dict(exclude_none=True)
parameters = validated_data.model_dump(exclude_none=True)

if query_text:
if not self.embeddings:
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_genai/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from neo4j_genai.types import EmbeddingVector
from .types import EmbeddingVector


class Embeddings(ABC):
Expand Down
6 changes: 3 additions & 3 deletions src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import List, Any, Literal, Optional
from pydantic import BaseModel, PositiveInt, Field, root_validator
from pydantic import BaseModel, PositiveInt, Field, model_validator


class Neo4jRecord(BaseModel):
Expand All @@ -15,7 +15,7 @@ class CreateIndexModel(BaseModel):
name: str
label: str
property: str
dimensions: int = Field(ge=1, le=20)
dimensions: int = Field(ge=1, le=2048)
similarity_fn: Literal["euclidean", "cosine"]


Expand All @@ -25,7 +25,7 @@ class SimilaritySearchModel(BaseModel):
query_vector: Optional[List[float]] = None
query_text: Optional[str] = None

@root_validator(pre=True)
@model_validator(mode="before")
def check_query(cls, values):
"""
Validates that one of either query_vector or query_text is provided exclusively.
Expand Down
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pytest
from neo4j_genai import GenAIClient
from unittest.mock import Mock, patch


@pytest.fixture
def driver():
return Mock()


@pytest.fixture
@patch("neo4j_genai.GenAIClient._verify_version")
def client(_verify_version_mock, driver):
return GenAIClient(driver)
132 changes: 132 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import pytest
from neo4j_genai import GenAIClient
from unittest.mock import Mock, patch
from neo4j.exceptions import CypherSyntaxError


@patch(
"neo4j_genai.GenAIClient.database_query",
return_value=[{"versions": ["5.11-aura"]}],
)
def test_genai_client_supported_aura_version(mock_database_query, driver):
GenAIClient(driver)
mock_database_query.assert_called_once()


@patch(
"neo4j_genai.GenAIClient.database_query",
return_value=[{"versions": ["5.3-aura"]}],
)
def test_genai_client_no_supported_aura_version(driver):
with pytest.raises(ValueError):
GenAIClient(driver)


@patch(
"neo4j_genai.GenAIClient.database_query",
return_value=[{"versions": ["5.11.5"]}],
)
def test_genai_client_supported_version(mock_database_query, driver):
GenAIClient(driver)
mock_database_query.assert_called_once()


@patch(
"neo4j_genai.GenAIClient.database_query",
return_value=[{"versions": ["4.3.5"]}],
)
def test_genai_client_no_supported_version(driver):
with pytest.raises(ValueError):
GenAIClient(driver)


@patch("neo4j_genai.GenAIClient.database_query")
def test_create_index_happy_path(mock_database_query, client):
client.create_index("my-index", "People", "name", 2048, "cosine")
query = (
"CALL db.index.vector.createNodeIndex("
"$name,"
"$label,"
"$property,"
"toInteger($dimensions),"
"$similarity_fn )"
)
mock_database_query.assert_called_once_with(
query,
params={
"name": "my-index",
"label": "People",
"property": "name",
"dimensions": 2048,
"similarity_fn": "cosine",
},
)


def test_create_index_too_big_dimension(client):
with pytest.raises(ValueError):
client.create_index("my-index", "People", "name", 5024, "cosine")


def test_create_index_validation_error_dimensions(client):
with pytest.raises(ValueError) as excinfo:
client.create_index("my-index", "People", "name", "no-dim", "cosine")
assert "Error for inputs to create_index" in str(excinfo)


def test_create_index_validation_error_similarity_fn(client):
with pytest.raises(ValueError) as excinfo:
client.create_index("my-index", "People", "name", "no-dim", "algebra")
assert "Error for inputs to create_index" in str(excinfo)


@patch("neo4j_genai.GenAIClient.database_query")
def test_drop_index(mock_database_query, client):
client.drop_index("my-index")

query = "DROP INDEX $name"

mock_database_query.assert_called_with(query, params={"name": "my-index"})


def test_database_query_happy(client, driver):
class Session:
def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
pass

def run(self, query, params):
m_list = []
for i in range(3):
mock = Mock()
mock.data.return_value = i
m_list.append(mock)

return m_list

driver.session = Session
res = client.database_query("MATCH (p:$label) RETURN p", {"label": "People"})
assert res == [0, 1, 2]


def test_database_query_cypher_error(client, driver):
class Session:
def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
pass

def run(self, query, params):
raise CypherSyntaxError

driver.session = Session

with pytest.raises(ValueError):
client.database_query("MATCH (p:$label) RETURN p", {"label": "People"})


def test_similarity_search():
pass

0 comments on commit a009b40

Please sign in to comment.