From 8aa12df044b861f845d5e14707d85cbdc7855ed2 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Thu, 11 Apr 2024 17:43:51 +0100 Subject: [PATCH] Custom Cypher GraphRAG class --- src/neo4j_genai/__init__.py | 4 +- src/neo4j_genai/retrievers.py | 81 ++++++++++++++++++++- tests/conftest.py | 2 +- tests/test_retrievers.py | 129 ++++++++++++++++++++++++++++++++-- 4 files changed, 207 insertions(+), 9 deletions(-) diff --git a/src/neo4j_genai/__init__.py b/src/neo4j_genai/__init__.py index de6038a8e..5977ff3c9 100644 --- a/src/neo4j_genai/__init__.py +++ b/src/neo4j_genai/__init__.py @@ -1,4 +1,4 @@ -from .retrievers import VectorRetriever +from .retrievers import VectorRetriever, GraphRetriever -__all__ = ["VectorRetriever"] +__all__ = ["VectorRetriever", "GraphRetriever"] diff --git a/src/neo4j_genai/retrievers.py b/src/neo4j_genai/retrievers.py index ef1ab84db..28ace13b7 100644 --- a/src/neo4j_genai/retrievers.py +++ b/src/neo4j_genai/retrievers.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Dict, Any from pydantic import ValidationError from neo4j import Driver from .embedder import Embedder @@ -61,7 +61,6 @@ def search( - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes) Args: - name (str): Refers to the unique name of the vector index to query. query_vector (Optional[List[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None. query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None. top_k (int, optional): The number of neighbors to return. Defaults to 5. 
@@ -109,3 +108,81 @@ def search( raise ValueError( f"Validation failed while constructing output: {error_details}" ) + + +class GraphRetriever(VectorRetriever): + """ + Provides retrieval method using vector similarity and custom Cypher query + """ + + def __init__( + self, + driver: Driver, + index_name: str, + custom_retrieval_query: str, + custom_query_params: Optional[Dict[str, Any]] = None, + embedder: Optional[Embedder] = None, + ) -> None: + self.driver = driver + self._verify_version() + self.index_name = index_name + self.custom_retrieval_query = custom_retrieval_query + self.custom_query_params = custom_query_params + self.embedder = embedder + + def search( + self, + query_vector: Optional[List[float]] = None, + query_text: Optional[str] = None, + top_k: int = 5, + ) -> List[Neo4jRecord]: + """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. + See the following documentation for more details: + + - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query) + - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes) + + Args: + name (str): Refers to the unique name of the vector index to query. + query_vector (Optional[List[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None. + query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None. + top_k (int, optional): The number of neighbors to return. Defaults to 5. + custom_params (Optional[Dict[str, Any]], optional): Query parameters to provide for the custom query. Defaults to None. + + Raises: + ValueError: If validation of the input arguments fails. + ValueError: If no embedder is provided. 
+ + Returns: + Any: The results of the search query + """ + try: + validated_data = SimilaritySearchModel( + index_name=self.index_name, + top_k=top_k, + query_vector=query_vector, + query_text=query_text, + ) + except ValidationError as e: + raise ValueError(f"Validation failed: {e.errors()}") + + parameters = validated_data.model_dump(exclude_none=True) + + if query_text: + if not self.embedder: + raise ValueError("Embedding method required for text query.") + parameters["query_vector"] = self.embedder.embed_query(query_text) + del parameters["query_text"] + + if self.custom_query_params: + for key, value in self.custom_query_params.items(): + if key not in parameters: + parameters[key] = value + + query_prefix = """ + CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) + YIELD node, score + """ + search_query = query_prefix + self.custom_retrieval_query + records, _, _ = self.driver.execute_query(search_query, parameters) + return records diff --git a/tests/conftest.py b/tests/conftest.py index d05db4bde..a77c44484 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,5 +11,5 @@ def driver(): @pytest.fixture @patch("neo4j_genai.VectorRetriever._verify_version") -def retriever(_verify_version_mock, driver): +def vector_retriever(_verify_version_mock, driver): return VectorRetriever(driver, "my-index") diff --git a/tests/test_retrievers.py b/tests/test_retrievers.py index cb8014e7b..435779c6f 100644 --- a/tests/test_retrievers.py +++ b/tests/test_retrievers.py @@ -1,6 +1,12 @@ import pytest from unittest.mock import patch, MagicMock + +from neo4j.exceptions import CypherSyntaxError + +# from neo4j.exceptions import CypherSyntaxError + from neo4j_genai import VectorRetriever +from neo4j_genai.retrievers import GraphRetriever from neo4j_genai.types import Neo4jRecord @@ -110,15 +116,15 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): assert records == [Neo4jRecord(node="dummy-node", score=1.0)] -def 
test_similarity_search_missing_embedder_for_text(retriever): +def test_similarity_search_missing_embedder_for_text(vector_retriever): query_text = "may thy knife chip and shatter" top_k = 5 with pytest.raises(ValueError, match="Embedding method required for text query"): - retriever.search(query_text=query_text, top_k=top_k) + vector_retriever.search(query_text=query_text, top_k=top_k) -def test_similarity_search_both_text_and_vector(retriever): +def test_similarity_search_both_text_and_vector(vector_retriever): query_text = "may thy knife chip and shatter" query_vector = [1.1, 2.2, 3.3] top_k = 5 @@ -126,7 +132,7 @@ def test_similarity_search_both_text_and_vector(retriever): with pytest.raises( ValueError, match="You must provide exactly one of query_vector or query_text." ): - retriever.search( + vector_retriever.search( query_text=query_text, query_vector=query_vector, top_k=top_k, @@ -167,3 +173,118 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver): "query_vector": query_vector, }, ) + + +@patch("neo4j_genai.GraphRetriever._verify_version") +def test_custom_retrieval_query_happy_path(_verify_version_mock, driver): + embed_query_vector = [1.0 for _ in range(1536)] + custom_embeddings = MagicMock() + custom_embeddings.embed_query.return_value = embed_query_vector + index_name = "my-index" + custom_retrieval_query = """ + RETURN node.id AS node_id, node.text AS text, score + """ + retriever = GraphRetriever( + driver, index_name, custom_retrieval_query, embedder=custom_embeddings + ) + query_text = "may thy knife chip and shatter" + top_k = 5 + driver.execute_query.return_value = [ + [{"node_id": 123, "text": "dummy-text", "score": 1.0}], + None, + None, + ] + search_query = """ + CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) + YIELD node, score + """ + + records = retriever.search( + query_text=query_text, + top_k=top_k, + ) + + custom_embeddings.embed_query.assert_called_once_with(query_text) + 
driver.execute_query.assert_called_once_with( + search_query + custom_retrieval_query, + { + "index_name": index_name, + "top_k": top_k, + "query_vector": embed_query_vector, + }, + ) + assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}] + + +@patch("neo4j_genai.GraphRetriever._verify_version") +def test_custom_retrieval_query_with_params(_verify_version_mock, driver): + embed_query_vector = [1.0 for _ in range(1536)] + custom_embeddings = MagicMock() + custom_embeddings.embed_query.return_value = embed_query_vector + index_name = "my-index" + custom_retrieval_query = """ + RETURN node.id AS node_id, node.text AS text, score, {test: $param} AS metadata + """ + custom_params = { + "param": "dummy-param", + } + retriever = GraphRetriever( + driver, + index_name, + custom_retrieval_query, + custom_params, + embedder=custom_embeddings, + ) + query_text = "may thy knife chip and shatter" + top_k = 5 + driver.execute_query.return_value = [ + [{"node_id": 123, "text": "dummy-text", "score": 1.0}], + None, + None, + ] + search_query = """ + CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) + YIELD node, score + """ + + records = retriever.search( + query_text=query_text, + top_k=top_k, + ) + + custom_embeddings.embed_query.assert_called_once_with(query_text) + + driver.execute_query.assert_called_once_with( + search_query + custom_retrieval_query, + { + "index_name": index_name, + "top_k": top_k, + "query_vector": embed_query_vector, + "param": "dummy-param", + }, + ) + + assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}] + + +@patch("neo4j_genai.GraphRetriever._verify_version") +def test_custom_retrieval_query_cypher_error(_verify_version_mock, driver): + embed_query_vector = [1.0 for _ in range(1536)] + custom_embeddings = MagicMock() + custom_embeddings.embed_query.return_value = embed_query_vector + index_name = "my-index" + custom_retrieval_query = """ + this is not a cypher query + """ + retriever = 
GraphRetriever( + driver, index_name, custom_retrieval_query, embedder=custom_embeddings + ) + query_text = "may thy knife chip and shatter" + top_k = 5 + driver.execute_query.side_effect = CypherSyntaxError + + with pytest.raises(CypherSyntaxError): + retriever.search( + query_text=query_text, + top_k=top_k, + )