Skip to content

Commit

Permalink
Custom Cypher GraphRAG class
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai committed Apr 18, 2024
1 parent 7e4ee29 commit 77b207d
Show file tree
Hide file tree
Showing 7 changed files with 249 additions and 24 deletions.
3 changes: 1 addition & 2 deletions examples/similarity_search_for_text.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import List
from neo4j import GraphDatabase
from neo4j_genai import VectorRetriever

Expand All @@ -18,7 +17,7 @@

# Create Embedder object
class CustomEmbedder(Embedder):
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
return [random() for _ in range(DIMENSION)]


Expand Down
4 changes: 2 additions & 2 deletions src/neo4j_genai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .retrievers import VectorRetriever
from .retrievers import VectorRetriever, VectorCypherRetriever


__all__ = ["VectorRetriever"]
__all__ = ["VectorRetriever", "VectorCypherRetriever"]
6 changes: 2 additions & 4 deletions src/neo4j_genai/indexes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List

from neo4j import Driver
from pydantic import ValidationError
from .types import VectorIndexModel, FulltextIndexModel
Expand Down Expand Up @@ -55,7 +53,7 @@ def create_vector_index(


def create_fulltext_index(
driver: Driver, name: str, label: str, node_properties: List[str]
driver: Driver, name: str, label: str, node_properties: list[str]
) -> None:
"""
This method constructs a Cypher query and executes it
Expand All @@ -67,7 +65,7 @@ def create_fulltext_index(
driver (Driver): Neo4j Python driver instance.
name (str): The unique name of the index.
label (str): The node label to be indexed.
node_properties (List[str]): The node properties to create the fulltext index on.
node_properties (list[str]): The node properties to create the fulltext index on.
Raises:
ValueError: If validation of the input arguments fail.
Expand Down
87 changes: 81 additions & 6 deletions src/neo4j_genai/retrievers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Optional, Any
from pydantic import ValidationError
from neo4j import Driver
from .embedder import Embedder
Expand Down Expand Up @@ -50,19 +50,18 @@ def _verify_version(self) -> None:

def search(
self,
query_vector: Optional[List[float]] = None,
query_vector: Optional[list[float]] = None,
query_text: Optional[str] = None,
top_k: int = 5,
) -> List[Neo4jRecord]:
) -> list[Neo4jRecord]:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
See the following documentation for more details:
- [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query)
- [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes)
Args:
name (str): Refers to the unique name of the vector index to query.
query_vector (Optional[List[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None.
top_k (int, optional): The number of neighbors to return. Defaults to 5.
Expand All @@ -71,7 +70,7 @@ def search(
ValueError: If no embedder is provided.
Returns:
List[Neo4jRecord]: The `top_k` neighbors found in vector search with their nodes and scores.
list[Neo4jRecord]: The `top_k` neighbors found in vector search with their nodes and scores.
"""
try:
validated_data = SimilaritySearchModel(
Expand Down Expand Up @@ -109,3 +108,79 @@ def search(
raise ValueError(
f"Validation failed while constructing output: {error_details}"
)


class VectorCypherRetriever(VectorRetriever):
    """
    Provides retrieval method using vector similarity and custom Cypher query.

    A vector index lookup is issued first; ``custom_retrieval_query`` is then
    appended verbatim to it, so that query can refer to the yielded ``node``
    and ``score`` variables.
    """

    def __init__(
        self,
        driver: Driver,
        index_name: str,
        custom_retrieval_query: str,
        custom_query_params: Optional[dict[str, Any]] = None,
        embedder: Optional[Embedder] = None,
    ) -> None:
        """
        Args:
            driver (Driver): Neo4j Python driver instance used to run the query.
            index_name (str): Name of the vector index to query.
            custom_retrieval_query (str): Cypher fragment appended after the vector index lookup.
            custom_query_params (Optional[dict[str, Any]], optional): Extra query parameters made
                available to the Cypher query. Defaults to None.
            embedder (Optional[Embedder], optional): Used to embed ``query_text`` in ``search``.
                Defaults to None.
        """
        # The driver is assigned before the version check runs.
        # NOTE(review): presumably _verify_version reads self.driver - confirm in the parent class.
        self.driver = driver
        self._verify_version()
        self.index_name = index_name
        self.custom_retrieval_query = custom_retrieval_query
        self.custom_query_params = custom_query_params
        self.embedder = embedder

    def search(
        self,
        query_vector: Optional[list[float]] = None,
        query_text: Optional[str] = None,
        top_k: int = 5,
    ) -> list[Neo4jRecord]:
        """Run the vector index lookup followed by the custom Cypher query.

        See the following documentation for more details:
        - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query)
        - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes)

        Args:
            query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
            query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None.
            top_k (int, optional): The number of neighbors to return. Defaults to 5.

        Raises:
            ValueError: If validation of the input arguments fails.
            ValueError: If query_text is given but no embedder is provided.

        Returns:
            list[Neo4jRecord]: The records produced by the combined query.
                NOTE(review): the records are returned exactly as received from
                ``driver.execute_query`` without conversion - confirm that the
                annotation matches what callers actually get.
        """
        try:
            validated_data = SimilaritySearchModel(
                index_name=self.index_name,
                top_k=top_k,
                query_vector=query_vector,
                query_text=query_text,
            )
        except ValidationError as e:
            # Chain the original error so the pydantic traceback is preserved.
            raise ValueError(f"Validation failed: {e.errors()}") from e

        parameters = validated_data.model_dump(exclude_none=True)

        if query_text:
            if not self.embedder:
                raise ValueError("Embedding method required for text query.")
            # The index lookup only takes a vector, so the text is swapped for
            # its embedding and removed from the parameter set.
            parameters["query_vector"] = self.embedder.embed_query(query_text)
            del parameters["query_text"]

        # User-supplied parameters never override the built-in ones
        # (index_name, top_k, query_vector).
        for key, value in (self.custom_query_params or {}).items():
            parameters.setdefault(key, value)

        query_prefix = (
            "\n"
            "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)\n"
            "YIELD node, score\n"
        )
        search_query = query_prefix + self.custom_retrieval_query
        records, _, _ = self.driver.execute_query(search_query, parameters)
        return records
8 changes: 4 additions & 4 deletions src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Any, Literal, Optional
from typing import Any, Literal, Optional
from pydantic import BaseModel, PositiveInt, model_validator, field_validator
from neo4j import Driver

Expand All @@ -9,7 +9,7 @@ class Neo4jRecord(BaseModel):


class EmbeddingVector(BaseModel):
vector: List[float]
vector: list[float]


class IndexModel(BaseModel):
Expand All @@ -33,7 +33,7 @@ class VectorIndexModel(IndexModel):
class FulltextIndexModel(IndexModel):
name: str
label: str
node_properties: List[str]
node_properties: list[str]

@field_validator("node_properties")
def check_node_properties_not_empty(cls, v):
Expand All @@ -45,7 +45,7 @@ def check_node_properties_not_empty(cls, v):
class SimilaritySearchModel(BaseModel):
index_name: str
top_k: PositiveInt = 5
query_vector: Optional[List[float]] = None
query_vector: Optional[list[float]] = None
query_text: Optional[str] = None

@model_validator(mode="before")
Expand Down
13 changes: 11 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from neo4j_genai import VectorRetriever
from neo4j_genai import VectorRetriever, VectorCypherRetriever
from neo4j import Driver
from unittest.mock import MagicMock, patch

Expand All @@ -11,5 +11,14 @@ def driver():

@pytest.fixture
@patch("neo4j_genai.VectorRetriever._verify_version")
def retriever(_verify_version_mock, driver):
def vector_retriever(_verify_version_mock, driver):
return VectorRetriever(driver, "my-index")


@pytest.fixture
@patch("neo4j_genai.VectorCypherRetriever._verify_version")
def vector_cypher_retriever(_verify_version_mock, driver):
    """Return a VectorCypherRetriever for "my-index" built on the `driver` fixture.

    `_verify_version` is patched while the retriever is constructed, so the
    constructor does not run the real version check.
    """
    retrieval_query = "\nRETURN node.id AS node_id, node.text AS text, score\n"
    return VectorCypherRetriever(
        driver,
        "my-index",
        retrieval_query,
    )
Loading

0 comments on commit 77b207d

Please sign in to comment.