From 77b207d1f0b24acf35bb5321b560f14a8b228497 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Thu, 11 Apr 2024 17:43:51 +0100 Subject: [PATCH] Custom Cypher GraphRAG class --- examples/similarity_search_for_text.py | 3 +- src/neo4j_genai/__init__.py | 4 +- src/neo4j_genai/indexes.py | 6 +- src/neo4j_genai/retrievers.py | 87 +++++++++++++- src/neo4j_genai/types.py | 8 +- tests/conftest.py | 13 ++- tests/test_retrievers.py | 152 ++++++++++++++++++++++++- 7 files changed, 249 insertions(+), 24 deletions(-) diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index ad28ad0e8..d28aee115 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -1,4 +1,3 @@ -from typing import List from neo4j import GraphDatabase from neo4j_genai import VectorRetriever @@ -18,7 +17,7 @@ # Create Embedder object class CustomEmbedder(Embedder): - def embed_query(self, text: str) -> List[float]: + def embed_query(self, text: str) -> list[float]: return [random() for _ in range(DIMENSION)] diff --git a/src/neo4j_genai/__init__.py b/src/neo4j_genai/__init__.py index de6038a8e..cd184d237 100644 --- a/src/neo4j_genai/__init__.py +++ b/src/neo4j_genai/__init__.py @@ -1,4 +1,4 @@ -from .retrievers import VectorRetriever +from .retrievers import VectorRetriever, VectorCypherRetriever -__all__ = ["VectorRetriever"] +__all__ = ["VectorRetriever", "VectorCypherRetriever"] diff --git a/src/neo4j_genai/indexes.py b/src/neo4j_genai/indexes.py index 29301d61b..5e2ede808 100644 --- a/src/neo4j_genai/indexes.py +++ b/src/neo4j_genai/indexes.py @@ -1,5 +1,3 @@ -from typing import List - from neo4j import Driver from pydantic import ValidationError from .types import VectorIndexModel, FulltextIndexModel @@ -55,7 +53,7 @@ def create_vector_index( def create_fulltext_index( - driver: Driver, name: str, label: str, node_properties: List[str] + driver: Driver, name: str, label: str, node_properties: list[str] ) -> None: """ This method constructs a Cypher query and executes it @@ -67,7 +65,7 @@ def create_fulltext_index( driver (Driver): Neo4j Python driver instance. name (str): The unique name of the index. label (str): The node label to be indexed. - node_properties (List[str]): The node properties to create the fulltext index on. + node_properties (list[str]): The node properties to create the fulltext index on. Raises: ValueError: If validation of the input arguments fail. diff --git a/src/neo4j_genai/retrievers.py b/src/neo4j_genai/retrievers.py index ef1ab84db..3b494b115 100644 --- a/src/neo4j_genai/retrievers.py +++ b/src/neo4j_genai/retrievers.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional, Any from pydantic import ValidationError from neo4j import Driver from .embedder import Embedder @@ -50,10 +50,10 @@ def _verify_version(self) -> None: def search( self, - query_vector: Optional[List[float]] = None, + query_vector: Optional[list[float]] = None, query_text: Optional[str] = None, top_k: int = 5, - ) -> List[Neo4jRecord]: + ) -> list[Neo4jRecord]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -61,8 +61,7 @@ def search( - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes) Args: - name (str): Refers to the unique name of the vector index to query. - query_vector (Optional[List[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None. + query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None. query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None. top_k (int, optional): The number of neighbors to return. Defaults to 5. @@ -71,7 +70,7 @@ def search( ValueError: If no embedder is provided. Returns: - List[Neo4jRecord]: The `top_k` neighbors found in vector search with their nodes and scores. + list[Neo4jRecord]: The `top_k` neighbors found in vector search with their nodes and scores. """ try: validated_data = SimilaritySearchModel( @@ -109,3 +108,79 @@ def search( raise ValueError( f"Validation failed while constructing output: {error_details}" ) + + +class VectorCypherRetriever(VectorRetriever): + """ + Provides retrieval method using vector similarity and custom Cypher query + """ + + def __init__( + self, + driver: Driver, + index_name: str, + custom_retrieval_query: str, + custom_query_params: Optional[dict[str, Any]] = None, + embedder: Optional[Embedder] = None, + ) -> None: + self.driver = driver + self._verify_version() + self.index_name = index_name + self.custom_retrieval_query = custom_retrieval_query + self.custom_query_params = custom_query_params + self.embedder = embedder + + def search( + self, + query_vector: Optional[list[float]] = None, + query_text: Optional[str] = None, + top_k: int = 5, + ) -> list[Neo4jRecord]: + """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. + See the following documentation for more details: + + - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query) + - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes) + + Args: + query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None. + query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None. + top_k (int, optional): The number of neighbors to return. Defaults to 5. + + Raises: + ValueError: If validation of the input arguments fail. + ValueError: If no embedder is provided. + + Returns: + Any: The results of the search query + """ + try: + validated_data = SimilaritySearchModel( + index_name=self.index_name, + top_k=top_k, + query_vector=query_vector, + query_text=query_text, + ) + except ValidationError as e: + raise ValueError(f"Validation failed: {e.errors()}") + + parameters = validated_data.model_dump(exclude_none=True) + + if query_text: + if not self.embedder: + raise ValueError("Embedding method required for text query.") + parameters["query_vector"] = self.embedder.embed_query(query_text) + del parameters["query_text"] + + if self.custom_query_params: + for key, value in self.custom_query_params.items(): + if key not in parameters: + parameters[key] = value + + query_prefix = """ + CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) + YIELD node, score + """ + search_query = query_prefix + self.custom_retrieval_query + records, _, _ = self.driver.execute_query(search_query, parameters) + return records diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index 91db6db74..60ac63553 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -1,4 +1,4 @@ -from typing import List, Any, Literal, Optional +from typing import Any, Literal, Optional from pydantic import BaseModel, PositiveInt, model_validator, field_validator from neo4j import Driver @@ -9,7 +9,7 @@ class Neo4jRecord(BaseModel): class EmbeddingVector(BaseModel): - vector: List[float] + vector: list[float] class IndexModel(BaseModel): @@ -33,7 +33,7 @@ class VectorIndexModel(IndexModel): class FulltextIndexModel(IndexModel): name: str label: str - node_properties: List[str] + node_properties: list[str] @field_validator("node_properties") def check_node_properties_not_empty(cls, v): @@ -45,7 +45,7 @@ def check_node_properties_not_empty(cls, v): class SimilaritySearchModel(BaseModel): index_name: str top_k: PositiveInt = 5 - query_vector: Optional[List[float]] = None + query_vector: Optional[list[float]] = None query_text: Optional[str] = None @model_validator(mode="before") diff --git a/tests/conftest.py b/tests/conftest.py index d05db4bde..b54b44210 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ import pytest -from neo4j_genai import VectorRetriever +from neo4j_genai import VectorRetriever, VectorCypherRetriever from neo4j import Driver from unittest.mock import MagicMock, patch @@ -11,5 +11,14 @@ def driver(): @pytest.fixture @patch("neo4j_genai.VectorRetriever._verify_version") -def retriever(_verify_version_mock, driver): +def vector_retriever(_verify_version_mock, driver): return VectorRetriever(driver, "my-index") + + +@pytest.fixture +@patch("neo4j_genai.VectorCypherRetriever._verify_version") +def vector_cypher_retriever(_verify_version_mock, driver): + custom_retrieval_query = """ + RETURN node.id AS node_id, node.text AS text, score + """ + return VectorCypherRetriever(driver, "my-index", custom_retrieval_query) diff --git a/tests/test_retrievers.py b/tests/test_retrievers.py index cb8014e7b..06cd16aef 100644 --- a/tests/test_retrievers.py +++ b/tests/test_retrievers.py @@ -1,6 +1,10 @@ import pytest from unittest.mock import patch, MagicMock + +from neo4j.exceptions import CypherSyntaxError + from neo4j_genai import VectorRetriever +from neo4j_genai.retrievers import VectorCypherRetriever from neo4j_genai.types import Neo4jRecord @@ -110,15 +114,15 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): assert records == [Neo4jRecord(node="dummy-node", score=1.0)] -def test_similarity_search_missing_embedder_for_text(retriever): +def test_vector_retriever_search_missing_embedder_for_text(vector_retriever): query_text = "may thy knife chip and shatter" top_k = 5 with pytest.raises(ValueError, match="Embedding method required for text query"): - retriever.search(query_text=query_text, top_k=top_k) + vector_retriever.search(query_text=query_text, top_k=top_k) -def test_similarity_search_both_text_and_vector(retriever): +def test_vector_retriever_search_both_text_and_vector(vector_retriever): query_text = "may thy knife chip and shatter" query_vector = [1.1, 2.2, 3.3] top_k = 5 @@ -126,7 +130,32 @@ def test_similarity_search_both_text_and_vector(retriever): with pytest.raises( ValueError, match="You must provide exactly one of query_vector or query_text." ): - retriever.search( + vector_retriever.search( + query_text=query_text, + query_vector=query_vector, + top_k=top_k, + ) + + +def test_vector_cypher_retriever_search_missing_embedder_for_text( + vector_cypher_retriever, +): + query_text = "may thy knife chip and shatter" + top_k = 5 + + with pytest.raises(ValueError, match="Embedding method required for text query"): + vector_cypher_retriever.search(query_text=query_text, top_k=top_k) + + +def test_vector_cypher_retriever_search_both_text_and_vector(vector_cypher_retriever): + query_text = "may thy knife chip and shatter" + query_vector = [1.1, 2.2, 3.3] + top_k = 5 + + with pytest.raises( + ValueError, match="You must provide exactly one of query_vector or query_text." + ): + vector_cypher_retriever.search( query_text=query_text, query_vector=query_vector, top_k=top_k, @@ -167,3 +196,118 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver): "query_vector": query_vector, }, ) + + +@patch("neo4j_genai.VectorCypherRetriever._verify_version") +def test_custom_retrieval_query_happy_path(_verify_version_mock, driver): + embed_query_vector = [1.0 for _ in range(1536)] + custom_embeddings = MagicMock() + custom_embeddings.embed_query.return_value = embed_query_vector + index_name = "my-index" + custom_retrieval_query = """ + RETURN node.id AS node_id, node.text AS text, score + """ + retriever = VectorCypherRetriever( + driver, index_name, custom_retrieval_query, embedder=custom_embeddings + ) + query_text = "may thy knife chip and shatter" + top_k = 5 + driver.execute_query.return_value = [ + [{"node_id": 123, "text": "dummy-text", "score": 1.0}], + None, + None, + ] + search_query = """ + CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) + YIELD node, score + """ + + records = retriever.search( + query_text=query_text, + top_k=top_k, + ) + + custom_embeddings.embed_query.assert_called_once_with(query_text) + driver.execute_query.assert_called_once_with( + search_query + custom_retrieval_query, + { + "index_name": index_name, + "top_k": top_k, + "query_vector": embed_query_vector, + }, + ) + assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}] + + +@patch("neo4j_genai.VectorCypherRetriever._verify_version") +def test_custom_retrieval_query_with_params(_verify_version_mock, driver): + embed_query_vector = [1.0 for _ in range(1536)] + custom_embeddings = MagicMock() + custom_embeddings.embed_query.return_value = embed_query_vector + index_name = "my-index" + custom_retrieval_query = """ + RETURN node.id AS node_id, node.text AS text, score, {test: $param} AS metadata + """ + custom_params = { + "param": "dummy-param", + } + retriever = VectorCypherRetriever( + driver, + index_name, + custom_retrieval_query, + custom_params, + embedder=custom_embeddings, + ) + query_text = "may thy knife chip and shatter" + top_k = 5 + driver.execute_query.return_value = [ + [{"node_id": 123, "text": "dummy-text", "score": 1.0}], + None, + None, + ] + search_query = """ + CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) + YIELD node, score + """ + + records = retriever.search( + query_text=query_text, + top_k=top_k, + ) + + custom_embeddings.embed_query.assert_called_once_with(query_text) + + driver.execute_query.assert_called_once_with( + search_query + custom_retrieval_query, + { + "index_name": index_name, + "top_k": top_k, + "query_vector": embed_query_vector, + "param": "dummy-param", + }, + ) + + assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}] + + +@patch("neo4j_genai.VectorCypherRetriever._verify_version") +def test_custom_retrieval_query_cypher_error(_verify_version_mock, driver): + embed_query_vector = [1.0 for _ in range(1536)] + custom_embeddings = MagicMock() + custom_embeddings.embed_query.return_value = embed_query_vector + index_name = "my-index" + custom_retrieval_query = """ + this is not a cypher query + """ + retriever = VectorCypherRetriever( + driver, index_name, custom_retrieval_query, embedder=custom_embeddings + ) + query_text = "may thy knife chip and shatter" + top_k = 5 + driver.execute_query.side_effect = CypherSyntaxError + + with pytest.raises(CypherSyntaxError): + retriever.search( + query_text=query_text, + top_k=top_k, + )