diff --git a/src/neo4j_genai/client.py b/src/neo4j_genai/client.py index 085bc38a..77ff9860 100644 --- a/src/neo4j_genai/client.py +++ b/src/neo4j_genai/client.py @@ -24,7 +24,7 @@ def _verify_version(self) -> None: Check if the connected Neo4j database version supports vector indexing. Queries the Neo4j database to retrieve its version and compares it - against a target version (5.11.0) that is known to support vector + against a target version (5.18.1) that is known to support vector indexing. Raises a ValueError if the connected Neo4j version is not supported. """ @@ -36,14 +36,15 @@ def _verify_version(self) -> None: *tuple(map(int, version.split("-")[0].split("."))), 0, ) + target_version = (5, 18, 0) else: version_tuple = tuple(map(int, version.split("."))) + target_version = (5, 18, 1) - target_version = (5, 11, 0) if version_tuple < target_version: raise ValueError( - "Version index is only supported in Neo4j version 5.11 or greater" + "This package only supports Neo4j version 5.18.1 or greater" ) def create_index( @@ -71,27 +72,22 @@ def create_index( Raises: ValueError: If validation of the input arguments fail. """ - index_data = { - "name": name, - "label": label, - "property": property, - "dimensions": dimensions, - "similarity_fn": similarity_fn, - } try: - index_data = CreateIndexModel(**index_data) + CreateIndexModel(**{ + "name": name, + "label": label, + "property": property, + "dimensions": dimensions, + "similarity_fn": similarity_fn, + }) except ValidationError as e: raise ValueError(f"Error for inputs to create_index {str(e)}") query = ( - "CALL db.index.vector.createNodeIndex(" - "$name," - "$label," - "$property," - "toInteger($dimensions)," - "$similarity_fn )" + f"CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:{label}) ON n.{property} OPTIONS " + "{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }" ) - self.driver.execute_query(query, index_data.model_dump()) + self.driver.execute_query(query, {"name": name, "dimensions": dimensions, "similarity_fn": similarity_fn}) def drop_index(self, name: str) -> None: """ diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index 6649f782..e9fc80d4 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -10,7 +10,7 @@ class CreateIndexModel(BaseModel): name: str label: str property: str - dimensions: int = PositiveInt + dimensions: PositiveInt similarity_fn: Literal["euclidean", "cosine"] diff --git a/tests/test_client.py b/tests/test_client.py index 9da6b6e7..faab9da1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -5,7 +5,7 @@ def test_genai_client_supported_aura_version(driver): - driver.execute_query.return_value = [[{"versions": ["5.11-aura"]}], None, None] + driver.execute_query.return_value = [[{"versions": ["5.18-aura"]}], None, None] GenAIClient(driver=driver) @@ -16,13 +16,13 @@ def test_genai_client_no_supported_aura_version(driver): with pytest.raises(ValueError) as excinfo: GenAIClient(driver=driver) - assert "Version index is only supported in Neo4j version 5.11 or greater" in str( + assert "This package only supports Neo4j version 5.18.1 or greater" in str( excinfo ) def test_genai_client_supported_version(driver): - driver.execute_query.return_value = [[{"versions": ["5.11.5"]}], None, None] + driver.execute_query.return_value = [[{"versions": ["5.19.0"]}], None, None] GenAIClient(driver=driver) @@ -33,34 +33,34 @@ def test_genai_client_no_supported_version(driver): with pytest.raises(ValueError) as excinfo: GenAIClient(driver=driver) - assert "Version index is only supported in Neo4j version 5.11 or greater" in str( + assert "This package only supports Neo4j version 5.18.1 or greater" in str( excinfo ) def test_create_index_happy_path(driver, client): driver.execute_query.return_value = [None, None, None] - create_query = ( - "CALL db.index.vector.createNodeIndex(" - "$name," - "$label," - "$property," - "toInteger($dimensions)," - "$similarity_fn )" - ) + create_query = ("CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:People) ON n.name OPTIONS " + "{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }") client.create_index("my-index", "People", "name", 2048, "cosine") - driver.execute_query.assert_called_once_with( - create_query, - { - "name": "my-index", - "label": "People", - "property": "name", - "dimensions": 2048, - "similarity_fn": "cosine", - }, - ) + driver.execute_query.assert_called_once_with(create_query, {"name": "my-index", "dimensions": 2048, "similarity_fn": "cosine"}) + +def test_create_index_ensure_escaping(driver, client): + driver.execute_query.return_value = [None, None, None] + create_query = ("CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:People) ON n.name OPTIONS " + "{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }") + + client.create_index("my-complicated-`-index", "People", "name", 2048, "cosine") + + driver.execute_query.assert_called_once_with(create_query, {"name": "my-complicated-`-index", "dimensions": 2048, "similarity_fn": "cosine"}) + + +def test_create_index_validation_error_dimensions_negative_integer(client): + with pytest.raises(ValueError) as excinfo: + client.create_index("my-index", "People", "name", -5, "cosine") + assert "Error for inputs to create_index" in str(excinfo) def test_create_index_validation_error_dimensions(client):