diff --git a/examples/hybrid_cypher_search.py b/examples/hybrid_cypher_search.py
new file mode 100644
index 000000000..a121b2eb4
--- /dev/null
+++ b/examples/hybrid_cypher_search.py
@@ -0,0 +1,62 @@
+from neo4j import GraphDatabase
+
+from random import random
+from neo4j_genai import HybridCypherRetriever
+from neo4j_genai.embedder import Embedder
+from neo4j_genai.indexes import create_vector_index, create_fulltext_index
+
+URI = "neo4j://localhost:7687"
+AUTH = ("neo4j", "password")
+
+INDEX_NAME = "embedding-name"
+FULLTEXT_INDEX_NAME = "fulltext-index-name"
+DIMENSION = 1536
+
+# Connect to Neo4j database
+driver = GraphDatabase.driver(URI, auth=AUTH)
+
+
+# Create Embedder object
+class CustomEmbedder(Embedder):
+    def embed_query(self, text: str) -> list[float]:
+        return [random() for _ in range(DIMENSION)]
+
+
+embedder = CustomEmbedder()
+
+# Creating the index
+create_vector_index(
+    driver,
+    INDEX_NAME,
+    label="Document",
+    property="propertyKey",
+    dimensions=DIMENSION,
+    similarity_fn="euclidean",
+)
+create_fulltext_index(
+    driver, FULLTEXT_INDEX_NAME, label="Document", node_properties=["propertyKey"]
+)
+
+# Initialize the retriever
+retrieval_query = "MATCH (node)-[:AUTHORED_BY]->(author:Author) RETURN author.name"
+retriever = HybridCypherRetriever(
+    driver, INDEX_NAME, FULLTEXT_INDEX_NAME, retrieval_query, embedder
+)
+
+# Upsert a document node with a random vector
+vector = [random() for _ in range(DIMENSION)]
+insert_query = (
+    "MERGE (n:Document {id: $id}) "
+    "WITH n "
+    "CALL db.create.setNodeVectorProperty(n, 'propertyKey', $vector) "
+    "RETURN n"
+)
+parameters = {
+    "id": 0,
+    "vector": vector,
+}
+driver.execute_query(insert_query, parameters)
+
+# Perform the similarity search for a text query
+query_text = "Who are the fremen?"
+print(retriever.search(query_text=query_text, top_k=5))
diff --git a/src/neo4j_genai/__init__.py b/src/neo4j_genai/__init__.py
index d9403c49a..bade3d75b 100644
--- a/src/neo4j_genai/__init__.py
+++ b/src/neo4j_genai/__init__.py
@@ -14,7 +14,11 @@
 # limitations under the License.
 
 from .retrievers.vector import VectorRetriever, VectorCypherRetriever
-from .retrievers.hybrid import HybridRetriever
+from .retrievers.hybrid import HybridRetriever, HybridCypherRetriever
 
-
-__all__ = ["VectorRetriever", "VectorCypherRetriever", "HybridRetriever"]
+__all__ = [
+    "VectorRetriever",
+    "VectorCypherRetriever",
+    "HybridRetriever",
+    "HybridCypherRetriever",
+]
diff --git a/src/neo4j_genai/queries.py b/src/neo4j_genai/queries.py
index 4505201d4..752677d5f 100644
--- a/src/neo4j_genai/queries.py
+++ b/src/neo4j_genai/queries.py
@@ -17,12 +17,23 @@
 from neo4j_genai.types import SearchType
 
 
-def get_search_query(search_type: SearchType, return_properties: Optional[list[str]] = None,):
+def get_search_query(
+    search_type: SearchType,
+    return_properties: Optional[list[str]] = None,
+    retrieval_query: Optional[str] = None,
+) -> str:
+    """Build the Cypher query for the given search type.
+
+    Each base query leaves ``node`` and ``score`` in scope. The tail is
+    ``retrieval_query`` if given, else a map projection of ``return_properties``
+    if given, else a plain ``RETURN node, score``.
+    """
     query_map = {
-        SearchType.Vector: (
-            "CALL db.index.vector.queryNodes($index, $k, $embedding) YIELD node, score "
+        SearchType.VECTOR: (
+            "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) "
+            "YIELD node, score "
         ),
-        SearchType.Hybrid: (
+        SearchType.HYBRID: (
             "CALL { "
             "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) "
             "YIELD node, score "
@@ -37,14 +48,15 @@ def get_search_query(search_type: SearchType, return_properties: Optional[list[s
         ),
     }
 
-    search_query = query_map[search_type]
+    base_query = query_map[search_type]
+    additional_query = ""
 
-    if return_properties:
-        return_properties_cypher = ", ".join(
-            [f".{prop}" for prop in return_properties]
-        )
-        search_query += "YIELD node, score "
-        search_query += f"RETURN node {{{return_properties_cypher}}} as node, score"
+    if retrieval_query:
+        additional_query += retrieval_query
+    elif return_properties:
+        return_properties_cypher = ", ".join([f".{prop}" for prop in return_properties])
+        additional_query += f"RETURN node {{{return_properties_cypher}}} as node, score"
     else:
-        search_query += "RETURN node, score"
-    return search_query
+        additional_query += "RETURN node, score"
+
+    return base_query + additional_query
diff --git a/src/neo4j_genai/retrievers/hybrid.py b/src/neo4j_genai/retrievers/hybrid.py
index 5dc1f5732..fe3827b2a 100644
--- a/src/neo4j_genai/retrievers/hybrid.py
+++ b/src/neo4j_genai/retrievers/hybrid.py
@@ -12,14 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Optional, Any
 
 from neo4j import Record, Driver
 from pydantic import ValidationError
 
 from neo4j_genai.embedder import Embedder
 from neo4j_genai.retrievers.base import Retriever
-from neo4j_genai.types import HybridModel, SearchType
+from neo4j_genai.types import HybridSearchModel, SearchType, HybridCypherSearchModel
 from neo4j_genai.queries import get_search_query
 
 
@@ -64,7 +64,7 @@ def search(
             list[Record]: The results of the search query
         """
         try:
-            validated_data = HybridModel(
+            validated_data = HybridSearchModel(
                 vector_index_name=self.vector_index_name,
                 fulltext_index_name=self.fulltext_index_name,
                 top_k=top_k,
@@ -82,7 +82,89 @@ def search(
             query_vector = self.embedder.embed_query(query_text)
             parameters["query_vector"] = query_vector
 
-        search_query = get_search_query(SearchType.Hybrid, self.return_properties)
+        search_query = get_search_query(SearchType.HYBRID, self.return_properties)
 
         records, _, _ = self.driver.execute_query(search_query, parameters)
         return records
+
+
+class HybridCypherRetriever(Retriever):
+    """Hybrid (vector + fulltext) retriever with a custom Cypher retrieval query.
+
+    ``retrieval_query`` is appended after the hybrid index search, so it can use
+    the ``node`` and ``score`` variables and must supply the final ``RETURN``.
+    """
+
+    def __init__(
+        self,
+        driver: Driver,
+        vector_index_name: str,
+        fulltext_index_name: str,
+        retrieval_query: str,
+        embedder: Optional[Embedder] = None,
+    ) -> None:
+        super().__init__(driver)
+        self._verify_version()
+        self.vector_index_name = vector_index_name
+        self.fulltext_index_name = fulltext_index_name
+        self.retrieval_query = retrieval_query
+        self.embedder = embedder
+
+    def search(
+        self,
+        query_text: str,
+        query_vector: Optional[list[float]] = None,
+        top_k: int = 5,
+        query_params: Optional[dict[str, Any]] = None,
+    ) -> list[Record]:
+        """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
+        Both query_vector and query_text can be provided.
+        If query_vector is provided, then it will be preferred over the embedded query_text
+        for the vector search.
+        See the following documentation for more details:
+        - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query)
+        - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes)
+        - [db.index.fulltext.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_fulltext_querynodes)
+        Args:
+            query_text (str): The text to get the closest neighbors of.
+            query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
+            top_k (int, optional): The number of neighbors to return. Defaults to 5.
+            query_params (Optional[dict[str, Any]], optional): Parameters for the Cypher query. Defaults to None.
+        Raises:
+            ValueError: If validation of the input arguments fails.
+            ValueError: If no embedder is provided.
+        Returns:
+            list[Record]: The results of the search query
+        """
+        try:
+            validated_data = HybridCypherSearchModel(
+                vector_index_name=self.vector_index_name,
+                fulltext_index_name=self.fulltext_index_name,
+                top_k=top_k,
+                query_vector=query_vector,
+                query_text=query_text,
+                query_params=query_params,
+            )
+        except ValidationError as e:
+            raise ValueError(f"Validation failed: {e.errors()}") from e
+
+        parameters = validated_data.model_dump(exclude_none=True)
+
+        if query_text and not query_vector:
+            if not self.embedder:
+                raise ValueError("Embedding method required for text query.")
+            query_vector = self.embedder.embed_query(query_text)
+            parameters["query_vector"] = query_vector
+
+        if query_params:
+            for key, value in query_params.items():
+                if key not in parameters:
+                    parameters[key] = value
+        # Drop the nested dict even when empty so it is never sent as a query parameter.
+        parameters.pop("query_params", None)
+
+        search_query = get_search_query(
+            SearchType.HYBRID, retrieval_query=self.retrieval_query
+        )
+        records, _, _ = self.driver.execute_query(search_query, parameters)
+        return records
diff --git a/src/neo4j_genai/retrievers/vector.py b/src/neo4j_genai/retrievers/vector.py
index 3a05603ca..1b404643b 100644
--- a/src/neo4j_genai/retrievers/vector.py
+++ b/src/neo4j_genai/retrievers/vector.py
@@ -21,7 +21,7 @@
 from neo4j_genai.embedder import Embedder
 from neo4j_genai.types import (
     VectorSearchRecord,
-    SimilaritySearchModel,
+    VectorSearchModel,
     VectorCypherSearchModel,
     SearchType,
 )
@@ -72,7 +72,7 @@ def search(
             list[VectorSearchRecord]: The `top_k` neighbors found in vector search with their nodes and scores.
         """
         try:
-            validated_data = SimilaritySearchModel(
+            validated_data = VectorSearchModel(
                 index_name=self.index_name,
                 top_k=top_k,
                 query_vector=query_vector,
@@ -91,7 +91,7 @@ def search(
             parameters["query_vector"] = query_vector
             del parameters["query_text"]
 
-        search_query = get_search_query(SearchType.Vector, self.return_properties)
+        search_query = get_search_query(SearchType.VECTOR, self.return_properties)
 
         records, _, _ = self.driver.execute_query(search_query, parameters)
 
@@ -177,6 +177,8 @@ def search(
                     parameters[key] = value
             del parameters["query_params"]
 
-        search_query = get_search_query(SearchType.Vector) + self.retrieval_query
+        search_query = get_search_query(
+            SearchType.VECTOR, retrieval_query=self.retrieval_query
+        )
         records, _, _ = self.driver.execute_query(search_query, parameters)
         return records
diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py
index fb3aa5d83..67a311752 100644
--- a/src/neo4j_genai/types.py
+++ b/src/neo4j_genai/types.py
@@ -53,7 +53,7 @@ def check_node_properties_not_empty(cls, v):
         return v
 
 
-class SimilaritySearchModel(BaseModel):
+class VectorSearchModel(BaseModel):
     index_name: str
     top_k: PositiveInt = 5
     query_vector: Optional[list[float]] = None
@@ -72,11 +72,11 @@ def check_query(cls, values):
         return values
 
 
-class VectorCypherSearchModel(SimilaritySearchModel):
+class VectorCypherSearchModel(VectorSearchModel):
     query_params: Optional[dict[str, Any]] = None
 
 
-class HybridModel(BaseModel):
+class HybridSearchModel(BaseModel):
     vector_index_name: str
     fulltext_index_name: str
     query_text: str
@@ -84,8 +84,12 @@ class HybridModel(BaseModel):
     query_vector: Optional[list[float]] = None
 
 
+class HybridCypherSearchModel(HybridSearchModel):
+    query_params: Optional[dict[str, Any]] = None
+
+
 class SearchType(str, Enum):
     """Enumerator of the search strategies."""
 
-    Vector = "vector"
-    Hybrid = "hybrid"
+    VECTOR = "vector"
+    HYBRID = "hybrid"
diff --git a/tests/test_queries.py b/tests/test_queries.py
new file mode 100644
index 000000000..37e5d4260
--- /dev/null
+++ b/tests/test_queries.py
@@ -0,0 +1,97 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+# https://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from neo4j_genai.queries import get_search_query
+from neo4j_genai.types import SearchType
+
+
+def test_vector_search_basic():
+    expected = (
+        "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) "
+        "YIELD node, score "
+        "RETURN node, score"
+    )
+    result = get_search_query(SearchType.VECTOR)
+    assert result == expected
+
+
+def test_hybrid_search_basic():
+    expected = (
+        "CALL { "
+        "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) "
+        "YIELD node, score "
+        "RETURN node, score UNION "
+        "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) "
+        "YIELD node, score "
+        "WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
+        "UNWIND nodes AS n "
+        "RETURN n.node AS node, (n.score / max) AS score "
+        "} "
+        "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
+        "RETURN node, score"
+    )
+    result = get_search_query(SearchType.HYBRID)
+    assert result == expected
+
+
+def test_vector_search_with_properties():
+    properties = ["name", "age"]
+    expected = (
+        "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) "
+        "YIELD node, score "
+        "RETURN node {.name, .age} as node, score"
+    )
+    result = get_search_query(SearchType.VECTOR, return_properties=properties)
+    assert result == expected
+
+
+def test_hybrid_search_with_properties():
+    properties = ["name", "age"]
+    expected = (
+        "CALL { "
+        "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) "
+        "YIELD node, score "
+        "RETURN node, score UNION "
+        "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) "
+        "YIELD node, score "
+        "WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
+        "UNWIND nodes AS n "
+        "RETURN n.node AS node, (n.score / max) AS score "
+        "} "
+        "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
+        "RETURN node {.name, .age} as node, score"
+    )
+    result = get_search_query(SearchType.HYBRID, return_properties=properties)
+    assert result == expected
+
+
+def test_hybrid_search_with_retrieval_query():
+    retrieval_query = "MATCH (n) RETURN n LIMIT 10"
+    expected = (
+        "CALL { "
+        "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) "
+        "YIELD node, score "
+        "RETURN node, score UNION "
+        "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) "
+        "YIELD node, score "
+        "WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
+        "UNWIND nodes AS n "
+        "RETURN n.node AS node, (n.score / max) AS score "
+        "} "
+        "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
+        + retrieval_query
+    )
+    result = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query)
+    assert result == expected
diff --git a/tests/test_retrievers.py b/tests/test_retrievers.py
index 7c1a07abb..e6346334f 100644
--- a/tests/test_retrievers.py
+++ b/tests/test_retrievers.py
@@ -19,6 +19,7 @@
 from neo4j.exceptions import CypherSyntaxError
 
 from neo4j_genai import VectorRetriever, VectorCypherRetriever, HybridRetriever
+from neo4j_genai.retrievers.hybrid import HybridCypherRetriever
 from neo4j_genai.types import VectorSearchRecord, SearchType
 from neo4j_genai.queries import get_search_query
 
@@ -65,7 +66,7 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver):
         None,
         None,
     ]
-    search_query = get_search_query(SearchType.Vector)
+    search_query = get_search_query(SearchType.VECTOR)
 
     records = retriever.search(query_vector=query_vector, top_k=top_k)
 
@@ -94,7 +95,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver):
         None,
         None,
     ]
-    search_query = get_search_query(SearchType.Vector)
+    search_query = get_search_query(SearchType.VECTOR)
 
     records = retriever.search(query_text=query_text, top_k=top_k)
 
@@ -130,7 +131,7 @@ def test_similarity_search_text_return_properties(_verify_version_mock, driver):
         None,
         None,
     ]
-    search_query = get_search_query(SearchType.Vector, return_properties)
+    search_query = get_search_query(SearchType.VECTOR, return_properties)
 
     records = retriever.search(query_text=query_text, top_k=top_k)
 
@@ -206,7 +207,7 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver):
         None,
         None,
     ]
-    search_query = get_search_query(SearchType.Vector)
+    search_query = get_search_query(SearchType.VECTOR)
 
     with pytest.raises(ValueError):
         retriever.search(query_vector=query_vector, top_k=top_k)
@@ -240,7 +241,7 @@ def test_retrieval_query_happy_path(_verify_version_mock, driver):
         None,
         None,
     ]
-    search_query = get_search_query(SearchType.Vector)
+    search_query = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query)
 
     records = retriever.search(
         query_text=query_text,
@@ -249,7 +250,7 @@ def test_retrieval_query_happy_path(_verify_version_mock, driver):
     custom_embeddings.embed_query.assert_called_once_with(query_text)
 
     driver.execute_query.assert_called_once_with(
-        search_query + retrieval_query,
+        search_query,
         {
             "index_name": index_name,
             "top_k": top_k,
@@ -284,7 +285,7 @@ def test_retrieval_query_with_params(_verify_version_mock, driver):
         None,
         None,
     ]
-    search_query = get_search_query(SearchType.Vector)
+    search_query = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query)
 
     records = retriever.search(
         query_text=query_text,
@@ -295,7 +296,7 @@ def test_retrieval_query_with_params(_verify_version_mock, driver):
     custom_embeddings.embed_query.assert_called_once_with(query_text)
 
     driver.execute_query.assert_called_once_with(
-        search_query + retrieval_query,
+        search_query,
         {
             "index_name": index_name,
             "top_k": top_k,
@@ -347,7 +348,7 @@ def test_hybrid_search_text_happy_path(_verify_version_mock, driver):
         None,
         None,
     ]
-    search_query = get_search_query(SearchType.Hybrid)
+    search_query = get_search_query(SearchType.HYBRID)
 
     records = retriever.search(query_text=query_text, top_k=top_k)
 
@@ -385,7 +386,7 @@ def test_hybrid_search_favors_query_vector_over_embedding_vector(
         None,
         None,
     ]
-    search_query = get_search_query(SearchType.Hybrid)
+    search_query = get_search_query(SearchType.HYBRID)
 
     retriever.search(query_text=query_text, query_vector=query_vector, top_k=top_k)
 
@@ -448,7 +449,7 @@ def test_hybrid_retriever_return_properties(_verify_version_mock, driver):
         None,
         None,
     ]
-    search_query = get_search_query(SearchType.Hybrid, return_properties)
+    search_query = get_search_query(SearchType.HYBRID, return_properties)
 
     records = retriever.search(query_text=query_text, top_k=top_k)
 
@@ -464,3 +465,55 @@ def test_hybrid_retriever_return_properties(_verify_version_mock, driver):
         },
     )
     assert records == [{"node": "dummy-node", "score": 1.0}]
+
+
+@patch("neo4j_genai.HybridCypherRetriever._verify_version")
+def test_hybrid_cypher_retrieval_query_with_params(_verify_version_mock, driver):
+    embed_query_vector = [1.0 for _ in range(1536)]
+    custom_embeddings = MagicMock()
+    custom_embeddings.embed_query.return_value = embed_query_vector
+    vector_index_name = "my-index"
+    fulltext_index_name = "my-fulltext-index"
+    query_text = "may thy knife chip and shatter"
+    top_k = 5
+    retrieval_query = """
+        RETURN node.id AS node_id, node.text AS text, score, {test: $param} AS metadata
+        """
+    query_params = {
+        "param": "dummy-param",
+    }
+    retriever = HybridCypherRetriever(
+        driver,
+        vector_index_name,
+        fulltext_index_name,
+        retrieval_query,
+        custom_embeddings,
+    )
+    driver.execute_query.return_value = [
+        [{"node_id": 123, "text": "dummy-text", "score": 1.0}],
+        None,
+        None,
+    ]
+    search_query = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query)
+
+    records = retriever.search(
+        query_text=query_text,
+        top_k=top_k,
+        query_params=query_params,
+    )
+
+    custom_embeddings.embed_query.assert_called_once_with(query_text)
+
+    driver.execute_query.assert_called_once_with(
+        search_query,
+        {
+            "vector_index_name": vector_index_name,
+            "top_k": top_k,
+            "query_text": query_text,
+            "fulltext_index_name": fulltext_index_name,
+            "query_vector": embed_query_vector,
+            "param": "dummy-param",
+        },
+    )
+
+    assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}]