Skip to content

Commit

Permalink
Merge branch 'main' into willtai/add-custom-retrieval-query
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai authored Mar 26, 2024
2 parents 259c2d6 + 5dccc2e commit 537d08f
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 41 deletions.
32 changes: 14 additions & 18 deletions src/neo4j_genai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _verify_version(self) -> None:
Check if the connected Neo4j database version supports vector indexing.
Queries the Neo4j database to retrieve its version and compares it
against a target version (5.11.0) that is known to support vector
against a target version (5.18.1) that is known to support vector
indexing. Raises a ValueError if the connected Neo4j version is
not supported.
"""
Expand All @@ -36,14 +36,15 @@ def _verify_version(self) -> None:
*tuple(map(int, version.split("-")[0].split("."))),
0,
)
target_version = (5, 18, 0)
else:
version_tuple = tuple(map(int, version.split(".")))
target_version = (5, 18, 1)

target_version = (5, 11, 0)

if version_tuple < target_version:
raise ValueError(
"Version index is only supported in Neo4j version 5.11 or greater"
"This package only supports Neo4j version 5.18.1 or greater"
)

def create_index(
Expand Down Expand Up @@ -71,27 +72,22 @@ def create_index(
Raises:
ValueError: If validation of the input arguments fail.
"""
index_data = {
"name": name,
"label": label,
"property": property,
"dimensions": dimensions,
"similarity_fn": similarity_fn,
}
try:
index_data = CreateIndexModel(**index_data)
CreateIndexModel(**{
"name": name,
"label": label,
"property": property,
"dimensions": dimensions,
"similarity_fn": similarity_fn,
})
except ValidationError as e:
raise ValueError(f"Error for inputs to create_index {str(e)}")

query = (
"CALL db.index.vector.createNodeIndex("
"$name,"
"$label,"
"$property,"
"toInteger($dimensions),"
"$similarity_fn )"
f"CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:{label}) ON n.{property} OPTIONS "
"{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }"
)
self.driver.execute_query(query, index_data.model_dump())
self.driver.execute_query(query, {"name": name, "dimensions": dimensions, "similarity_fn": similarity_fn})

def drop_index(self, name: str) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class CreateIndexModel(BaseModel):
name: str
label: str
property: str
dimensions: int = PositiveInt
dimensions: PositiveInt
similarity_fn: Literal["euclidean", "cosine"]


Expand Down
44 changes: 22 additions & 22 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


def test_genai_client_supported_aura_version(driver):
driver.execute_query.return_value = [[{"versions": ["5.11-aura"]}], None, None]
driver.execute_query.return_value = [[{"versions": ["5.18-aura"]}], None, None]

GenAIClient(driver=driver)

Expand All @@ -16,13 +16,13 @@ def test_genai_client_no_supported_aura_version(driver):
with pytest.raises(ValueError) as excinfo:
GenAIClient(driver=driver)

assert "Version index is only supported in Neo4j version 5.11 or greater" in str(
assert "This package only supports Neo4j version 5.18.1 or greater" in str(
excinfo
)


def test_genai_client_supported_version(driver):
driver.execute_query.return_value = [[{"versions": ["5.11.5"]}], None, None]
driver.execute_query.return_value = [[{"versions": ["5.19.0"]}], None, None]

GenAIClient(driver=driver)

Expand All @@ -33,34 +33,34 @@ def test_genai_client_no_supported_version(driver):
with pytest.raises(ValueError) as excinfo:
GenAIClient(driver=driver)

assert "Version index is only supported in Neo4j version 5.11 or greater" in str(
assert "This package only supports Neo4j version 5.18.1 or greater" in str(
excinfo
)


def test_create_index_happy_path(driver, client):
driver.execute_query.return_value = [None, None, None]
create_query = (
"CALL db.index.vector.createNodeIndex("
"$name,"
"$label,"
"$property,"
"toInteger($dimensions),"
"$similarity_fn )"
)
create_query = ("CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:People) ON n.name OPTIONS "
"{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }")

client.create_index("my-index", "People", "name", 2048, "cosine")

driver.execute_query.assert_called_once_with(
create_query,
{
"name": "my-index",
"label": "People",
"property": "name",
"dimensions": 2048,
"similarity_fn": "cosine",
},
)
driver.execute_query.assert_called_once_with(create_query, {"name": "my-index", "dimensions": 2048, "similarity_fn": "cosine"})

def test_create_index_ensure_escaping(driver, client):
driver.execute_query.return_value = [None, None, None]
create_query = ("CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:People) ON n.name OPTIONS "
"{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }")

client.create_index("my-complicated-`-index", "People", "name", 2048, "cosine")

driver.execute_query.assert_called_once_with(create_query, {"name": "my-complicated-`-index", "dimensions": 2048, "similarity_fn": "cosine"})


def test_create_index_validation_error_dimensions_negative_integer(client):
with pytest.raises(ValueError) as excinfo:
client.create_index("my-index", "People", "name", -5, "cosine")
assert "Error for inputs to create_index" in str(excinfo)


def test_create_index_validation_error_dimensions(client):
Expand Down

0 comments on commit 537d08f

Please sign in to comment.