From 88ffdf0d9690f6a90e880acdb8e27b330d9faeb3 Mon Sep 17 00:00:00 2001 From: willtai Date: Mon, 13 May 2024 10:22:14 +0100 Subject: [PATCH] Add try catch for create_index and rename imports of neo4j (#30) --- src/neo4j_genai/indexes.py | 83 +++++++++++++++++++--------- src/neo4j_genai/retrievers/base.py | 4 +- src/neo4j_genai/retrievers/hybrid.py | 14 ++--- src/neo4j_genai/retrievers/vector.py | 10 ++-- src/neo4j_genai/types.py | 4 +- tests/e2e/conftest.py | 10 +++- tests/e2e/test_hybrid_e2e.py | 12 ++-- tests/unit/conftest.py | 4 +- tests/unit/test_indexes.py | 34 ++++++++++-- 9 files changed, 118 insertions(+), 57 deletions(-) diff --git a/src/neo4j_genai/indexes.py b/src/neo4j_genai/indexes.py index 132cc1448..32ace5784 100644 --- a/src/neo4j_genai/indexes.py +++ b/src/neo4j_genai/indexes.py @@ -13,13 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from neo4j import Driver +import neo4j from pydantic import ValidationError from .types import VectorIndexModel, FulltextIndexModel +import logging + + +logger = logging.getLogger(__name__) def create_vector_index( - driver: Driver, + driver: neo4j.Driver, name: str, label: str, property: str, @@ -32,8 +36,11 @@ def create_vector_index( See Cypher manual on [Create vector index](https://neo4j.com/docs/cypher-manual/current/indexes/semantic-indexes/vector-indexes/#indexes-vector-create) + Important: This operation will fail if an index with the same name already exists. + Ensure that the index name provided is unique within the database context. + Args: - driver (Driver): Neo4j Python driver instance. + driver (neo4j.Driver): Neo4j Python driver instance. name (str): The unique name of the index. label (str): The node label to be indexed. property (str): The property key of a node which contains embedding values. @@ -43,6 +50,7 @@ def create_vector_index( Raises: ValueError: If validation of the input arguments fail. + neo4j.exceptions.ClientError: If creation of vector index fails. """ try: VectorIndexModel( @@ -58,17 +66,23 @@ def create_vector_index( except ValidationError as e: raise ValueError(f"Error for inputs to create_vector_index {str(e)}") - query = ( - f"CREATE VECTOR INDEX $name FOR (n:{label}) ON n.{property} OPTIONS " - "{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }" - ) - driver.execute_query( - query, {"name": name, "dimensions": dimensions, "similarity_fn": similarity_fn} - ) + try: + query = ( + f"CREATE VECTOR INDEX $name FOR (n:{label}) ON n.{property} OPTIONS " + "{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }" + ) + logger.info(f"Creating vector index named '{name}'") + driver.execute_query( + query, + {"name": name, "dimensions": dimensions, "similarity_fn": similarity_fn}, + ) + except neo4j.exceptions.ClientError as e: + logger.error(f"Neo4j vector index creation failed {e}") + raise def create_fulltext_index( - driver: Driver, name: str, label: str, node_properties: list[str] + driver: neo4j.Driver, name: str, label: str, node_properties: list[str] ) -> None: """ This method constructs a Cypher query and executes it @@ -76,14 +90,18 @@ def create_fulltext_index( See Cypher manual on [Create fulltext index](https://neo4j.com/docs/cypher-manual/current/indexes/semantic-indexes/full-text-indexes/#create-full-text-indexes) + Important: This operation will fail if an index with the same name already exists. + Ensure that the index name provided is unique within the database context. + Args: - driver (Driver): Neo4j Python driver instance. + driver (neo4j.Driver): Neo4j Python driver instance. name (str): The unique name of the index. label (str): The node label to be indexed. node_properties (list[str]): The node properties to create the fulltext index on. Raises: ValueError: If validation of the input arguments fail. + neo4j.exceptions.ClientError: If creation of fulltext index fails. """ try: FulltextIndexModel( @@ -97,26 +115,39 @@ def create_fulltext_index( except ValidationError as e: raise ValueError(f"Error for inputs to create_fulltext_index {str(e)}") - query = ( - "CREATE FULLTEXT INDEX $name " - f"FOR (n:`{label}`) ON EACH " - f"[{', '.join(['n.`' + prop + '`' for prop in node_properties])}]" - ) - driver.execute_query(query, {"name": name}) + try: + query = ( + "CREATE FULLTEXT INDEX $name " + f"FOR (n:`{label}`) ON EACH " + f"[{', '.join(['n.`' + prop + '`' for prop in node_properties])}]" + ) + logger.info(f"Creating fulltext index named '{name}'") + driver.execute_query(query, {"name": name}) + except neo4j.exceptions.ClientError as e: + logger.error(f"Neo4j fulltext index creation failed {e}") + raise -def drop_index(driver: Driver, name: str) -> None: +def drop_index_if_exists(driver: neo4j.Driver, name: str) -> None: """ This method constructs a Cypher query and executes it - to drop a vector index in Neo4j. + to drop an index in Neo4j, if the index exists. See Cypher manual on [Drop vector indexes](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-drop) Args: - driver (Driver): Neo4j Python driver instance. + driver (neo4j.Driver): Neo4j Python driver instance. name (str): The name of the index to delete. + + Raises: + neo4j.exceptions.ClientError: If dropping of index fails. """ - query = "DROP INDEX $name IF EXISTS" - parameters = { - "name": name, - } - driver.execute_query(query, parameters) + try: + query = "DROP INDEX $name IF EXISTS" + parameters = { + "name": name, + } + logger.info(f"Dropping index named '{name}'") + driver.execute_query(query, parameters) + except neo4j.exceptions.ClientError as e: + logger.error(f"Dropping Neo4j index failed {e}") + raise diff --git a/src/neo4j_genai/retrievers/base.py b/src/neo4j_genai/retrievers/base.py index dc483eb62..c3ca671d8 100644 --- a/src/neo4j_genai/retrievers/base.py +++ b/src/neo4j_genai/retrievers/base.py @@ -15,7 +15,7 @@ from abc import ABC, abstractmethod from typing import Any -from neo4j import Driver +import neo4j class Retriever(ABC): @@ -23,7 +23,7 @@ class Retriever(ABC): Abstract class for Neo4j retrievers """ - def __init__(self, driver: Driver): + def __init__(self, driver: neo4j.Driver): self.driver = driver self._verify_version() diff --git a/src/neo4j_genai/retrievers/hybrid.py b/src/neo4j_genai/retrievers/hybrid.py index 0690555a2..a2163b343 100644 --- a/src/neo4j_genai/retrievers/hybrid.py +++ b/src/neo4j_genai/retrievers/hybrid.py @@ -14,7 +14,7 @@ # limitations under the License. from typing import Optional, Any -from neo4j import Record, Driver +import neo4j from pydantic import ValidationError from neo4j_genai.embedder import Embedder @@ -29,7 +29,7 @@ class HybridRetriever(Retriever): def __init__( self, - driver: Driver, + driver: neo4j.Driver, vector_index_name: str, fulltext_index_name: str, embedder: Optional[Embedder] = None, @@ -46,7 +46,7 @@ def search( query_text: str, query_vector: Optional[list[float]] = None, top_k: int = 5, - ) -> list[Record]: + ) -> list[neo4j.Record]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. Both query_vector and query_text can be provided. If query_vector is provided, then it will be preferred over the embedded query_text @@ -63,7 +63,7 @@ def search( ValueError: If validation of the input arguments fail. ValueError: If no embedder is provided. Returns: - list[Record]: The results of the search query + list[neo4j.Record]: The results of the search query """ try: validated_data = HybridSearchModel( @@ -96,7 +96,7 @@ def search( class HybridCypherRetriever(Retriever): def __init__( self, - driver: Driver, + driver: neo4j.Driver, vector_index_name: str, fulltext_index_name: str, retrieval_query: str, @@ -114,7 +114,7 @@ def search( query_vector: Optional[list[float]] = None, top_k: int = 5, query_params: Optional[dict[str, Any]] = None, - ) -> list[Record]: + ) -> list[neo4j.Record]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. Both query_vector and query_text can be provided. If query_vector is provided, then it will be preferred over the embedded query_text @@ -132,7 +132,7 @@ def search( ValueError: If validation of the input arguments fail. ValueError: If no embedder is provided. Returns: - list[Record]: The results of the search query + list[neo4j.Record]: The results of the search query """ try: validated_data = HybridCypherSearchModel( diff --git a/src/neo4j_genai/retrievers/vector.py b/src/neo4j_genai/retrievers/vector.py index 954cd04ed..af3a60685 100644 --- a/src/neo4j_genai/retrievers/vector.py +++ b/src/neo4j_genai/retrievers/vector.py @@ -14,7 +14,7 @@ # limitations under the License. from typing import Optional, Any -from neo4j import Driver, Record +import neo4j from neo4j_genai.retrievers.base import Retriever from pydantic import ValidationError @@ -39,7 +39,7 @@ class VectorRetriever(Retriever): def __init__( self, - driver: Driver, + driver: neo4j.Driver, index_name: str, embedder: Optional[Embedder] = None, return_properties: Optional[list[str]] = None, @@ -120,7 +120,7 @@ class VectorCypherRetriever(Retriever): def __init__( self, - driver: Driver, + driver: neo4j.Driver, index_name: str, retrieval_query: str, embedder: Optional[Embedder] = None, @@ -136,7 +136,7 @@ def search( query_text: Optional[str] = None, top_k: int = 5, query_params: Optional[dict[str, Any]] = None, - ) -> list[Record]: + ) -> list[neo4j.Record]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -154,7 +154,7 @@ def search( ValueError: If no embedder is provided. Returns: - list[Record]: The results of the search query + list[neo4j.Record]: The results of the search query """ try: validated_data = VectorCypherSearchModel( diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index 67a311752..fc9a3e5a3 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -15,7 +15,7 @@ from enum import Enum from typing import Any, Literal, Optional from pydantic import BaseModel, PositiveInt, model_validator, field_validator -from neo4j import Driver +import neo4j class VectorSearchRecord(BaseModel): @@ -28,7 +28,7 @@ class IndexModel(BaseModel): @field_validator("driver") def check_driver_is_valid(cls, v): - if not isinstance(v, Driver): + if not isinstance(v, neo4j.Driver): raise ValueError("driver must be an instance of neo4j.Driver") return v diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 64cd65042..0ce0c2c87 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -19,7 +19,11 @@ import pytest from neo4j import GraphDatabase from neo4j_genai.embedder import Embedder -from neo4j_genai.indexes import drop_index, create_vector_index, create_fulltext_index +from neo4j_genai.indexes import ( + drop_index_if_exists, + create_vector_index, + create_fulltext_index, +) @pytest.fixture(scope="module") @@ -47,8 +51,8 @@ def setup_neo4j(driver): # Delete data and drop indexes to prevent data leakage driver.execute_query("MATCH (n) DETACH DELETE n") - drop_index(driver, vector_index_name) - drop_index(driver, fulltext_index_name) + drop_index_if_exists(driver, vector_index_name) + drop_index_if_exists(driver, fulltext_index_name) # Create a vector index create_vector_index( diff --git a/tests/e2e/test_hybrid_e2e.py b/tests/e2e/test_hybrid_e2e.py index f8f54466e..3ba48c62a 100644 --- a/tests/e2e/test_hybrid_e2e.py +++ b/tests/e2e/test_hybrid_e2e.py @@ -16,7 +16,7 @@ import pytest -from neo4j import Record +import neo4j from neo4j_genai import ( HybridRetriever, @@ -36,7 +36,7 @@ def test_hybrid_retriever_search_text(driver, custom_embedder): assert isinstance(results, list) assert len(results) == 5 for result in results: - assert isinstance(result, Record) + assert isinstance(result, neo4j.Record) @pytest.mark.usefixtures("setup_neo4j") @@ -58,7 +58,7 @@ def test_hybrid_cypher_retriever_search_text(driver, custom_embedder): assert isinstance(results, list) assert len(results) == 5 for record in results: - assert isinstance(record, Record) + assert isinstance(record, neo4j.Record) assert "author.name" in record.keys() @@ -80,7 +80,7 @@ def test_hybrid_retriever_search_vector(driver): assert isinstance(results, list) assert len(results) == 5 for result in results: - assert isinstance(result, Record) + assert isinstance(result, neo4j.Record) @pytest.mark.usefixtures("setup_neo4j") @@ -105,7 +105,7 @@ def test_hybrid_cypher_retriever_search_vector(driver): assert isinstance(results, list) assert len(results) == 5 for record in results: - assert isinstance(record, Record) + assert isinstance(record, neo4j.Record) assert "author.name" in record.keys() @@ -129,4 +129,4 @@ def test_hybrid_retriever_return_properties(driver): assert isinstance(results, list) assert len(results) == 5 for result in results: - assert isinstance(result, Record) + assert isinstance(result, neo4j.Record) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index b22e58fc6..75e0419f8 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -14,14 +14,14 @@ # limitations under the License. import pytest +import neo4j from neo4j_genai import VectorRetriever, VectorCypherRetriever, HybridRetriever -from neo4j import Driver from unittest.mock import MagicMock, patch @pytest.fixture(scope="function") def driver(): - return MagicMock(spec=Driver) + return MagicMock(spec=neo4j.Driver) @pytest.fixture(scope="function") diff --git a/tests/unit/test_indexes.py b/tests/unit/test_indexes.py index 841226845..c5509da9a 100644 --- a/tests/unit/test_indexes.py +++ b/tests/unit/test_indexes.py @@ -12,12 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import neo4j.exceptions import pytest from neo4j_genai.indexes import ( create_vector_index, - drop_index, + drop_index_if_exists, create_fulltext_index, ) @@ -68,16 +68,33 @@ def test_create_vector_index_validation_error_dimensions(driver): assert "Error for inputs to create_vector_index" in str(excinfo) +def test_create_vector_index_raises_error_with_neo4j_client_error(driver): + driver.execute_query.side_effect = neo4j.exceptions.ClientError + with pytest.raises(neo4j.exceptions.ClientError): + create_vector_index(driver, "my-index", "People", "name", 2048, "cosine") + + def test_create_vector_index_validation_error_similarity_fn(driver): with pytest.raises(ValueError) as excinfo: create_vector_index(driver, "my-index", "People", "name", 1536, "algebra") assert "Error for inputs to create_vector_index" in str(excinfo) -def test_drop_index(driver): +def test_drop_index_if_exists(driver): drop_query = "DROP INDEX $name IF EXISTS" - drop_index(driver, "my-index") + drop_index_if_exists(driver, "my-index") + + driver.execute_query.assert_called_once_with( + drop_query, + {"name": "my-index"}, + ) + + +def test_drop_index_if_exists_raises_error_with_neo4j_client_error(driver): + drop_query = "DROP INDEX $name IF EXISTS" + + drop_index_if_exists(driver, "my-index") driver.execute_query.assert_called_once_with( drop_query, @@ -102,6 +119,15 @@ def test_create_fulltext_index_happy_path(driver): ) +def test_create_fulltext_index_raises_error_with_neo4j_client_error(driver): + label = "node-label" + text_node_properties = ["property-1", "property-2"] + driver.execute_query.side_effect = neo4j.exceptions.ClientError + + with pytest.raises(neo4j.exceptions.ClientError): + create_fulltext_index(driver, "my-index", label, text_node_properties) + + def test_create_fulltext_index_empty_node_properties(driver): label = "node-label" node_properties = []