Skip to content

Commit

Permalink
Custom Cypher GraphRAG class
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai committed Apr 11, 2024
1 parent 7e4ee29 commit 8aa12df
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 9 deletions.
4 changes: 2 additions & 2 deletions src/neo4j_genai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .retrievers import VectorRetriever
from .retrievers import VectorRetriever, GraphRetriever


__all__ = ["VectorRetriever"]
__all__ = ["VectorRetriever", "GraphRetriever"]
81 changes: 79 additions & 2 deletions src/neo4j_genai/retrievers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Optional, Dict, Any
from pydantic import ValidationError
from neo4j import Driver
from .embedder import Embedder
Expand Down Expand Up @@ -61,7 +61,6 @@ def search(
- [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes)
Args:
name (str): Refers to the unique name of the vector index to query.
query_vector (Optional[List[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None.
top_k (int, optional): The number of neighbors to return. Defaults to 5.
Expand Down Expand Up @@ -109,3 +108,81 @@ def search(
raise ValueError(
f"Validation failed while constructing output: {error_details}"
)


class GraphRetriever(VectorRetriever):
    """
    Provides retrieval method using vector similarity and custom Cypher query.

    A vector index lookup is run first; the caller-supplied Cypher fragment is
    appended to that lookup and can therefore refer to the yielded ``node``
    and ``score`` variables.
    """

    def __init__(
        self,
        driver: Driver,
        index_name: str,
        custom_retrieval_query: str,
        custom_query_params: Optional[Dict[str, Any]] = None,
        embedder: Optional[Embedder] = None,
    ) -> None:
        """Store the driver, index name, custom query and optional embedder.

        Args:
            driver (Driver): Neo4j driver used to execute the search query.
            index_name (str): Name of the vector index to query.
            custom_retrieval_query (str): Cypher appended after the vector
                index lookup.
            custom_query_params (Optional[Dict[str, Any]], optional): Extra
                query parameters made available to the custom query.
                Defaults to None.
            embedder (Optional[Embedder], optional): Used to embed
                ``query_text``; required only for text searches.
                Defaults to None.
        """
        self.driver = driver
        # Not defined in this class -- presumably inherited from
        # VectorRetriever. It needs self.driver, so it runs right after it.
        self._verify_version()
        self.index_name = index_name
        self.custom_retrieval_query = custom_retrieval_query
        self.custom_query_params = custom_query_params
        self.embedder = embedder

    def search(
        self,
        query_vector: Optional[List[float]] = None,
        query_text: Optional[str] = None,
        top_k: int = 5,
    ) -> List[Neo4jRecord]:
        """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.

        The nearest neighbors are then passed through the custom retrieval
        query given at construction time.

        See the following documentation for more details:
        - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query)
        - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes)

        Args:
            query_vector (Optional[List[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
            query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None.
            top_k (int, optional): The number of neighbors to return. Defaults to 5.

        Raises:
            ValueError: If validation of the input arguments fail.
            ValueError: If query_text is given but no embedder is provided.

        Returns:
            The records returned by ``driver.execute_query`` for the combined
            query, passed through unchanged (they are not converted here).
        """
        try:
            validated_data = SimilaritySearchModel(
                index_name=self.index_name,
                top_k=top_k,
                query_vector=query_vector,
                query_text=query_text,
            )
        except ValidationError as e:
            # Chain explicitly so the pydantic error is kept as the cause.
            raise ValueError(f"Validation failed: {e.errors()}") from e

        parameters = validated_data.model_dump(exclude_none=True)

        if query_text:
            if not self.embedder:
                raise ValueError("Embedding method required for text query.")
            # The query only takes $query_vector, so swap the text for its
            # embedding before sending the parameters.
            parameters["query_vector"] = self.embedder.embed_query(query_text)
            del parameters["query_text"]

        if self.custom_query_params:
            # setdefault keeps the validated core parameters (index_name,
            # top_k, query_vector) authoritative: a custom key with the same
            # name is ignored rather than overriding them.
            for key, value in self.custom_query_params.items():
                parameters.setdefault(key, value)

        # The tests in this change compare the query text verbatim, so the
        # runtime value of this literal must stay exactly as it is.
        query_prefix = (
            "\n"
            "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)\n"
            "YIELD node, score\n"
        )
        search_query = query_prefix + self.custom_retrieval_query
        records, _, _ = self.driver.execute_query(search_query, parameters)
        return records
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ def driver():

@pytest.fixture
@patch("neo4j_genai.VectorRetriever._verify_version")
def retriever(_verify_version_mock, driver):
def vector_retriever(_verify_version_mock, driver):
return VectorRetriever(driver, "my-index")
129 changes: 125 additions & 4 deletions tests/test_retrievers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import pytest
from unittest.mock import patch, MagicMock

from neo4j.exceptions import CypherSyntaxError

# from neo4j.exceptions import CypherSyntaxError

from neo4j_genai import VectorRetriever
from neo4j_genai.retrievers import GraphRetriever
from neo4j_genai.types import Neo4jRecord


Expand Down Expand Up @@ -110,23 +116,23 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver):
assert records == [Neo4jRecord(node="dummy-node", score=1.0)]


def test_similarity_search_missing_embedder_for_text(retriever):
def test_similarity_search_missing_embedder_for_text(vector_retriever):
query_text = "may thy knife chip and shatter"
top_k = 5

with pytest.raises(ValueError, match="Embedding method required for text query"):
retriever.search(query_text=query_text, top_k=top_k)
vector_retriever.search(query_text=query_text, top_k=top_k)


def test_similarity_search_both_text_and_vector(retriever):
def test_similarity_search_both_text_and_vector(vector_retriever):
query_text = "may thy knife chip and shatter"
query_vector = [1.1, 2.2, 3.3]
top_k = 5

with pytest.raises(
ValueError, match="You must provide exactly one of query_vector or query_text."
):
retriever.search(
vector_retriever.search(
query_text=query_text,
query_vector=query_vector,
top_k=top_k,
Expand Down Expand Up @@ -167,3 +173,118 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver):
"query_vector": query_vector,
},
)


@patch("neo4j_genai.GraphRetriever._verify_version")
def test_custom_retrieval_query_happy_path(_verify_version_mock, driver):
    """Text search embeds the query and runs the vector lookup plus custom Cypher."""
    # Inputs.
    index_name = "my-index"
    query_text = "may thy knife chip and shatter"
    top_k = 5
    custom_retrieval_query = (
        "\n"
        "RETURN node.id AS node_id, node.text AS text, score\n"
    )

    # Stubbed embedder: always yields the same 1536-element vector.
    embed_query_vector = [1.0] * 1536
    custom_embeddings = MagicMock()
    custom_embeddings.embed_query.return_value = embed_query_vector

    # Stubbed driver call: a three-element result whose first item is the
    # list of records.
    driver.execute_query.return_value = [
        [{"node_id": 123, "text": "dummy-text", "score": 1.0}],
        None,
        None,
    ]

    # Expected Cypher: the fixed vector-index prefix followed by the custom part.
    expected_query = (
        "\n"
        "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)\n"
        "YIELD node, score\n"
    ) + custom_retrieval_query
    expected_parameters = {
        "index_name": index_name,
        "top_k": top_k,
        "query_vector": embed_query_vector,
    }

    retriever = GraphRetriever(
        driver, index_name, custom_retrieval_query, embedder=custom_embeddings
    )
    records = retriever.search(query_text=query_text, top_k=top_k)

    custom_embeddings.embed_query.assert_called_once_with(query_text)
    driver.execute_query.assert_called_once_with(expected_query, expected_parameters)
    assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}]


@patch("neo4j_genai.GraphRetriever._verify_version")
def test_custom_retrieval_query_with_params(_verify_version_mock, driver):
    """Custom query parameters are forwarded alongside the core search parameters."""
    # Inputs.
    index_name = "my-index"
    query_text = "may thy knife chip and shatter"
    top_k = 5
    custom_retrieval_query = (
        "\n"
        "RETURN node.id AS node_id, node.text AS text, score, {test: $param} AS metadata\n"
    )
    custom_params = {"param": "dummy-param"}

    # Stubbed embedder: always yields the same 1536-element vector.
    embed_query_vector = [1.0] * 1536
    custom_embeddings = MagicMock()
    custom_embeddings.embed_query.return_value = embed_query_vector

    # Stubbed driver call: a three-element result whose first item is the
    # list of records.
    driver.execute_query.return_value = [
        [{"node_id": 123, "text": "dummy-text", "score": 1.0}],
        None,
        None,
    ]

    # Expected Cypher: the fixed vector-index prefix followed by the custom part.
    expected_query = (
        "\n"
        "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)\n"
        "YIELD node, score\n"
    ) + custom_retrieval_query
    # The custom "param" entry must show up next to the core parameters.
    expected_parameters = {
        "index_name": index_name,
        "top_k": top_k,
        "query_vector": embed_query_vector,
        "param": "dummy-param",
    }

    retriever = GraphRetriever(
        driver,
        index_name,
        custom_retrieval_query,
        custom_params,
        embedder=custom_embeddings,
    )
    records = retriever.search(query_text=query_text, top_k=top_k)

    custom_embeddings.embed_query.assert_called_once_with(query_text)
    driver.execute_query.assert_called_once_with(expected_query, expected_parameters)
    assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}]


@patch("neo4j_genai.GraphRetriever._verify_version")
def test_custom_retrieval_query_cypher_error(_verify_version_mock, driver):
    """A CypherSyntaxError raised by the driver propagates out of search()."""
    index_name = "my-index"
    query_text = "may thy knife chip and shatter"
    top_k = 5
    custom_retrieval_query = (
        "\n"
        "this is not a cypher query\n"
    )

    # Stubbed embedder: always yields the same 1536-element vector.
    custom_embeddings = MagicMock()
    custom_embeddings.embed_query.return_value = [1.0] * 1536

    # Make the stubbed driver call fail with a Cypher syntax error.
    driver.execute_query.side_effect = CypherSyntaxError

    retriever = GraphRetriever(
        driver, index_name, custom_retrieval_query, embedder=custom_embeddings
    )

    with pytest.raises(CypherSyntaxError):
        retriever.search(query_text=query_text, top_k=top_k)

0 comments on commit 8aa12df

Please sign in to comment.