Skip to content

Commit

Permalink
Add try catch for create_index and rename imports of neo4j
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai committed May 9, 2024
1 parent 091714f commit b2b3571
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 57 deletions.
79 changes: 53 additions & 26 deletions src/neo4j_genai/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from neo4j import Driver
import neo4j
from pydantic import ValidationError
from .types import VectorIndexModel, FulltextIndexModel
import logging


logger = logging.getLogger(__name__)


def create_vector_index(
driver: Driver,
driver: neo4j.Driver,
name: str,
label: str,
property: str,
Expand All @@ -32,8 +36,11 @@ def create_vector_index(
See Cypher manual on [Create vector index](https://neo4j.com/docs/cypher-manual/current/indexes/semantic-indexes/vector-indexes/#indexes-vector-create)
Important: This operation will fail if an index with the same name already exists.
Ensure that the index name provided is unique within the database context.
Args:
driver (Driver): Neo4j Python driver instance.
driver (neo4j.Driver): Neo4j Python driver instance.
name (str): The unique name of the index.
label (str): The node label to be indexed.
property (str): The property key of a node which contains embedding values.
Expand All @@ -43,6 +50,7 @@ def create_vector_index(
Raises:
ValueError: If validation of the input arguments fail.
neo4j.exceptions.ClientError: If creation of vector index fails.
"""
try:
VectorIndexModel(
Expand All @@ -58,26 +66,35 @@ def create_vector_index(
except ValidationError as e:
raise ValueError(f"Error for inputs to create_vector_index {str(e)}")

query = (
f"CREATE VECTOR INDEX $name FOR (n:{label}) ON n.{property} OPTIONS "
"{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }"
)
driver.execute_query(
query, {"name": name, "dimensions": dimensions, "similarity_fn": similarity_fn}
)
try:
query = (
f"CREATE VECTOR INDEX $name FOR (n:{label}) ON n.{property} OPTIONS "
"{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }"
)
logger.info(f"Creating vector index named '{name}'")
driver.execute_query(
query,
{"name": name, "dimensions": dimensions, "similarity_fn": similarity_fn},
)
except neo4j.exceptions.ClientError as e:
logger.error(f"Neo4j vector index creation failed {e}")
raise


def create_fulltext_index(
driver: Driver, name: str, label: str, node_properties: list[str]
driver: neo4j.Driver, name: str, label: str, node_properties: list[str]
) -> None:
"""
This method constructs a Cypher query and executes it
to create a new fulltext index in Neo4j.
See Cypher manual on [Create fulltext index](https://neo4j.com/docs/cypher-manual/current/indexes/semantic-indexes/full-text-indexes/#create-full-text-indexes)
Important: This operation will fail if an index with the same name already exists.
Ensure that the index name provided is unique within the database context.
Args:
driver (Driver): Neo4j Python driver instance.
driver (neo4j.Driver): Neo4j Python driver instance.
name (str): The unique name of the index.
label (str): The node label to be indexed.
node_properties (list[str]): The node properties to create the fulltext index on.
Expand All @@ -97,26 +114,36 @@ def create_fulltext_index(
except ValidationError as e:
raise ValueError(f"Error for inputs to create_fulltext_index {str(e)}")

query = (
"CREATE FULLTEXT INDEX $name "
f"FOR (n:`{label}`) ON EACH "
f"[{', '.join(['n.`' + prop + '`' for prop in node_properties])}]"
)
driver.execute_query(query, {"name": name})
try:
query = (
"CREATE FULLTEXT INDEX $name "
f"FOR (n:`{label}`) ON EACH "
f"[{', '.join(['n.`' + prop + '`' for prop in node_properties])}]"
)
logger.info(f"Creating fulltext index named '{name}'")
driver.execute_query(query, {"name": name})
except neo4j.exceptions.ClientError as e:
logger.error(f"Neo4j fulltext index creation failed {e}")
raise


def drop_index(driver: Driver, name: str) -> None:
def drop_index_if_exists(driver: neo4j.Driver, name: str) -> None:
"""
This method constructs a Cypher query and executes it
to drop a vector index in Neo4j.
to drop a vector index in Neo4j, if the index exists.
See Cypher manual on [Drop vector indexes](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-drop)
Args:
driver (Driver): Neo4j Python driver instance.
driver (neo4j.Driver): Neo4j Python driver instance.
name (str): The name of the index to delete.
"""
query = "DROP INDEX $name IF EXISTS"
parameters = {
"name": name,
}
driver.execute_query(query, parameters)
try:
query = "DROP INDEX $name IF EXISTS"
parameters = {
"name": name,
}
logger.info(f"Dropping index named '{name}'")
driver.execute_query(query, parameters)
except neo4j.exceptions.ClientError as e:
logger.error(f"Neo4j fulltext index creation failed {e}")
raise
4 changes: 2 additions & 2 deletions src/neo4j_genai/retrievers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
from abc import ABC, abstractmethod
from typing import Any

from neo4j import Driver
import neo4j


class Retriever(ABC):
"""
Abstract class for Neo4j retrievers
"""

def __init__(self, driver: Driver):
def __init__(self, driver: neo4j.Driver):
self.driver = driver
self._verify_version()

Expand Down
14 changes: 7 additions & 7 deletions src/neo4j_genai/retrievers/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
from typing import Optional, Any

from neo4j import Record, Driver
import neo4j
from pydantic import ValidationError

from neo4j_genai.embedder import Embedder
Expand All @@ -29,7 +29,7 @@
class HybridRetriever(Retriever):
def __init__(
self,
driver: Driver,
driver: neo4j.Driver,
vector_index_name: str,
fulltext_index_name: str,
embedder: Optional[Embedder] = None,
Expand All @@ -46,7 +46,7 @@ def search(
query_text: str,
query_vector: Optional[list[float]] = None,
top_k: int = 5,
) -> list[Record]:
) -> list[neo4j.Record]:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
Both query_vector and query_text can be provided.
If query_vector is provided, then it will be preferred over the embedded query_text
Expand All @@ -63,7 +63,7 @@ def search(
ValueError: If validation of the input arguments fail.
ValueError: If no embedder is provided.
Returns:
list[Record]: The results of the search query
list[neo4j.Record]: The results of the search query
"""
try:
validated_data = HybridSearchModel(
Expand Down Expand Up @@ -96,7 +96,7 @@ def search(
class HybridCypherRetriever(Retriever):
def __init__(
self,
driver: Driver,
driver: neo4j.Driver,
vector_index_name: str,
fulltext_index_name: str,
retrieval_query: str,
Expand All @@ -114,7 +114,7 @@ def search(
query_vector: Optional[list[float]] = None,
top_k: int = 5,
query_params: Optional[dict[str, Any]] = None,
) -> list[Record]:
) -> list[neo4j.Record]:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
Both query_vector and query_text can be provided.
If query_vector is provided, then it will be preferred over the embedded query_text
Expand All @@ -132,7 +132,7 @@ def search(
ValueError: If validation of the input arguments fail.
ValueError: If no embedder is provided.
Returns:
list[Record]: The results of the search query
list[neo4j.Record]: The results of the search query
"""
try:
validated_data = HybridCypherSearchModel(
Expand Down
10 changes: 5 additions & 5 deletions src/neo4j_genai/retrievers/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
from typing import Optional, Any

from neo4j import Driver, Record
import neo4j
from neo4j_genai.retrievers.base import Retriever
from pydantic import ValidationError

Expand All @@ -39,7 +39,7 @@ class VectorRetriever(Retriever):

def __init__(
self,
driver: Driver,
driver: neo4j.Driver,
index_name: str,
embedder: Optional[Embedder] = None,
return_properties: Optional[list[str]] = None,
Expand Down Expand Up @@ -120,7 +120,7 @@ class VectorCypherRetriever(Retriever):

def __init__(
self,
driver: Driver,
driver: neo4j.Driver,
index_name: str,
retrieval_query: str,
embedder: Optional[Embedder] = None,
Expand All @@ -136,7 +136,7 @@ def search(
query_text: Optional[str] = None,
top_k: int = 5,
query_params: Optional[dict[str, Any]] = None,
) -> list[Record]:
) -> list[neo4j.Record]:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
See the following documentation for more details:
Expand All @@ -154,7 +154,7 @@ def search(
ValueError: If no embedder is provided.
Returns:
list[Record]: The results of the search query
list[neo4j.Record]: The results of the search query
"""
try:
validated_data = VectorCypherSearchModel(
Expand Down
4 changes: 2 additions & 2 deletions src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from enum import Enum
from typing import Any, Literal, Optional
from pydantic import BaseModel, PositiveInt, model_validator, field_validator
from neo4j import Driver
import neo4j


class VectorSearchRecord(BaseModel):
Expand All @@ -28,7 +28,7 @@ class IndexModel(BaseModel):

@field_validator("driver")
def check_driver_is_valid(cls, v):
if not isinstance(v, Driver):
if not isinstance(v, neo4j.Driver):
raise ValueError("driver must be an instance of neo4j.Driver")
return v

Expand Down
10 changes: 7 additions & 3 deletions tests/e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
import pytest
from neo4j import GraphDatabase
from neo4j_genai.embedder import Embedder
from neo4j_genai.indexes import drop_index, create_vector_index, create_fulltext_index
from neo4j_genai.indexes import (
drop_index_if_exists,
create_vector_index,
create_fulltext_index,
)


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -47,8 +51,8 @@ def setup_neo4j(driver):

# Delete data and drop indexes to prevent data leakage
driver.execute_query("MATCH (n) DETACH DELETE n")
drop_index(driver, vector_index_name)
drop_index(driver, fulltext_index_name)
drop_index_if_exists(driver, vector_index_name)
drop_index_if_exists(driver, fulltext_index_name)

# Create a vector index
create_vector_index(
Expand Down
12 changes: 6 additions & 6 deletions tests/e2e/test_hybrid_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import pytest

from neo4j import Record
import neo4j

from neo4j_genai import (
HybridRetriever,
Expand All @@ -36,7 +36,7 @@ def test_hybrid_retriever_search_text(driver, custom_embedder):
assert isinstance(results, list)
assert len(results) == 5
for result in results:
assert isinstance(result, Record)
assert isinstance(result, neo4j.Record)


@pytest.mark.usefixtures("setup_neo4j")
Expand All @@ -58,7 +58,7 @@ def test_hybrid_cypher_retriever_search_text(driver, custom_embedder):
assert isinstance(results, list)
assert len(results) == 5
for record in results:
assert isinstance(record, Record)
assert isinstance(record, neo4j.Record)
assert "author.name" in record.keys()


Expand All @@ -80,7 +80,7 @@ def test_hybrid_retriever_search_vector(driver):
assert isinstance(results, list)
assert len(results) == 5
for result in results:
assert isinstance(result, Record)
assert isinstance(result, neo4j.Record)


@pytest.mark.usefixtures("setup_neo4j")
Expand All @@ -105,7 +105,7 @@ def test_hybrid_cypher_retriever_search_vector(driver):
assert isinstance(results, list)
assert len(results) == 5
for record in results:
assert isinstance(record, Record)
assert isinstance(record, neo4j.Record)
assert "author.name" in record.keys()


Expand All @@ -129,4 +129,4 @@ def test_hybrid_retriever_return_properties(driver):
assert isinstance(results, list)
assert len(results) == 5
for result in results:
assert isinstance(result, Record)
assert isinstance(result, neo4j.Record)
4 changes: 2 additions & 2 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
# limitations under the License.

import pytest
import neo4j
from neo4j_genai import VectorRetriever, VectorCypherRetriever, HybridRetriever
from neo4j import Driver
from unittest.mock import MagicMock, patch


@pytest.fixture(scope="function")
def driver():
return MagicMock(spec=Driver)
return MagicMock(spec=neo4j.Driver)


@pytest.fixture(scope="function")
Expand Down
Loading

0 comments on commit b2b3571

Please sign in to comment.