Skip to content

Commit

Permalink
Changed test to mock neo4j driver's execute_driver instead of the cli…
Browse files Browse the repository at this point in the history
…ent's _database_query()

Changed test_client.py to have mocks at execute_query level

Changed precommit config to include ruff linting and formatting

Update pre-commit

Corrected cypher query error test and dimensions constraint in CreateIndexModel data type

ruff formatting
  • Loading branch information
willtai committed Mar 13, 2024
1 parent f134f7d commit d1685dc
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 116 deletions.
3 changes: 2 additions & 1 deletion examples/similarity_search_for_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def embed_query(self, text: str) -> List[float]:
# Initialize the client
client = GenAIClient(driver, embedder)

client.drop_index(INDEX_NAME)
# Creating the index
client.create_index(
INDEX_NAME,
Expand All @@ -46,7 +47,7 @@ def embed_query(self, text: str) -> List[float]:
parameters = {
"vector": vector,
}
client.database_query(insert_query, params=parameters)
client._database_query(insert_query, params=parameters)

# Perform the similarity search for a text query
query_text = "hello world"
Expand Down
2 changes: 1 addition & 1 deletion examples/similarity_search_for_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
parameters = {
"vector": vector,
}
client.database_query(insert_query, params=parameters)
client._database_query(insert_query, params=parameters)

# Perform the similarity search for a vector query
query_vector = [random() for _ in range(DIMENSION)]
Expand Down
23 changes: 12 additions & 11 deletions src/neo4j_genai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _verify_version(self) -> None:
indexing. Raises a ValueError if the connected Neo4j version is
not supported.
"""
version = self.database_query("CALL dbms.components()")[0]["versions"][0]
version = self._database_query("CALL dbms.components()")[0]["versions"][0]
if "aura" in version:
version_tuple = (
*tuple(map(int, version.split("-")[0].split("."))),
Expand All @@ -45,7 +45,9 @@ def _verify_version(self) -> None:
"Version index is only supported in Neo4j version 5.11 or greater"
)

def database_query(self, query: str, params: Dict = {}) -> List[Dict[str, Any]]:
def _database_query(
self, query: str, params: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
"""
This method sends a Cypher query to the connected Neo4j database
and returns the results as a list of dictionaries.
Expand All @@ -57,12 +59,11 @@ def database_query(self, query: str, params: Dict = {}) -> List[Dict[str, Any]]:
Returns:
List[Dict[str, Any]]: List of dictionaries containing the query results.
"""
with self.driver.session() as session:
try:
data = session.run(query, params)
return [r.data() for r in data]
except CypherSyntaxError as e:
raise ValueError(f"Cypher Statement is not valid\n{e}")
try:
records, _, _ = self.driver.execute_query(query, params)
return records
except CypherSyntaxError as e:
raise ValueError(f"Cypher Statement is not valid\n{e}")

def create_index(
self,
Expand Down Expand Up @@ -109,7 +110,7 @@ def create_index(
"toInteger($dimensions),"
"$similarity_fn )"
)
self.database_query(query, params=index_data.model_dump())
self._database_query(query, params=index_data.model_dump())

def drop_index(self, name: str) -> None:
"""
Expand All @@ -124,7 +125,7 @@ def drop_index(self, name: str) -> None:
parameters = {
"name": name,
}
self.database_query(query, params=parameters)
self._database_query(query, params=parameters)

def similarity_search(
self,
Expand Down Expand Up @@ -176,7 +177,7 @@ def similarity_search(
CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
YIELD node, score
"""
records = self.database_query(db_query_string, params=parameters)
records = self._database_query(db_query_string, params=parameters)

try:
return [
Expand Down
4 changes: 2 additions & 2 deletions src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import List, Any, Literal, Optional
from pydantic import BaseModel, PositiveInt, Field, model_validator
from pydantic import BaseModel, PositiveInt, model_validator


class Neo4jRecord(BaseModel):
Expand All @@ -15,7 +15,7 @@ class CreateIndexModel(BaseModel):
name: str
label: str
property: str
dimensions: int = Field(ge=1, le=2048)
dimensions: int = PositiveInt
similarity_fn: Literal["euclidean", "cosine"]


Expand Down
14 changes: 10 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import pytest
from neo4j_genai import GenAIClient
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, patch
from typing import List
from neo4j_genai.embedder import Embedder


@pytest.fixture
def driver():
return Mock()
return MagicMock()


@pytest.fixture
Expand All @@ -20,8 +20,14 @@ def client(_verify_version_mock, driver):
@patch("neo4j_genai.GenAIClient._verify_version")
def client_with_embedder(_verify_version_mock, driver):
class CustomEmbedder(Embedder):
def __init__(self):
self.dimension = 1536

def embed_query(self, text: str) -> List[float]:
return [1.0 for _ in range(1536)]
return [1.0 for _ in range(self.dimension)]

def set_dimension(self, dimension: int):
self.dimension = dimension

embedder = CustomEmbedder()
return GenAIClient(driver, embedder)
return GenAIClient(driver, embedder), embedder
Loading

0 comments on commit d1685dc

Please sign in to comment.