From 4e840d03b4641134cf8aa5ae13655c140f8e2fa1 Mon Sep 17 00:00:00 2001 From: willtai Date: Tue, 30 Apr 2024 10:39:02 +0100 Subject: [PATCH] Add HybridCypherRetriever, refactors Retriever directory, refactor query construction (#22) * Refactored retrievers into a directory with separate files * Adds HybridCypherRetriever, refactor query construction, add tests * Update README (#20) * Minor changes to the README --------- Co-authored-by: Oskar Hane --- examples/hybrid_cypher_search.py | 62 +++++++ examples/hybrid_search.py | 2 +- src/neo4j_genai/__init__.py | 11 +- src/neo4j_genai/neo4j_queries.py | 56 ++++++ src/neo4j_genai/retrievers/__init__.py | 0 src/neo4j_genai/retrievers/base.py | 59 +++++++ src/neo4j_genai/retrievers/hybrid.py | 161 +++++++++++++++++ .../{retrievers.py => retrievers/vector.py} | 166 ++---------------- src/neo4j_genai/types.py | 19 +- tests/test_queries.py | 76 ++++++++ tests/test_retrievers.py | 160 +++++++---------- 11 files changed, 519 insertions(+), 253 deletions(-) create mode 100644 examples/hybrid_cypher_search.py create mode 100644 src/neo4j_genai/neo4j_queries.py create mode 100644 src/neo4j_genai/retrievers/__init__.py create mode 100644 src/neo4j_genai/retrievers/base.py create mode 100644 src/neo4j_genai/retrievers/hybrid.py rename src/neo4j_genai/{retrievers.py => retrievers/vector.py} (52%) create mode 100644 tests/test_queries.py diff --git a/examples/hybrid_cypher_search.py b/examples/hybrid_cypher_search.py new file mode 100644 index 00000000..a121b2eb --- /dev/null +++ b/examples/hybrid_cypher_search.py @@ -0,0 +1,62 @@ +from neo4j import GraphDatabase + +from random import random +from neo4j_genai import HybridCypherRetriever +from neo4j_genai.embedder import Embedder +from neo4j_genai.indexes import create_vector_index, create_fulltext_index + +URI = "neo4j://localhost:7687" +AUTH = ("neo4j", "password") + +INDEX_NAME = "embedding-name" +FULLTEXT_INDEX_NAME = "fulltext-index-name" +DIMENSION = 1536 + +# Connect to Neo4j database +driver = GraphDatabase.driver(URI, auth=AUTH) + + +# Create Embedder object +class CustomEmbedder(Embedder): + def embed_query(self, text: str) -> list[float]: + return [random() for _ in range(DIMENSION)] + + +embedder = CustomEmbedder() + +# Creating the index +create_vector_index( + driver, + INDEX_NAME, + label="Document", + property="propertyKey", + dimensions=DIMENSION, + similarity_fn="euclidean", +) +create_fulltext_index( + driver, FULLTEXT_INDEX_NAME, label="Document", node_properties=["propertyKey"] +) + +# Initialize the retriever +retrieval_query = "MATCH (node)-[:AUTHORED_BY]->(author:Author)" "RETURN author.name" +retriever = HybridCypherRetriever( + driver, INDEX_NAME, FULLTEXT_INDEX_NAME, retrieval_query, embedder +) + +# Upsert the query +vector = [random() for _ in range(DIMENSION)] +insert_query = ( + "MERGE (n:Document {id: $id})" + "WITH n " + "CALL db.create.setNodeVectorProperty(n, 'propertyKey', $vector)" + "RETURN n" +) +parameters = { + "id": 0, + "vector": vector, +} +driver.execute_query(insert_query, parameters) + +# Perform the similarity search for a text query +query_text = "Who are the fremen?" +print(retriever.search(query_text=query_text, top_k=5)) diff --git a/examples/hybrid_search.py b/examples/hybrid_search.py index b0413b0a..704a3841 100644 --- a/examples/hybrid_search.py +++ b/examples/hybrid_search.py @@ -1,9 +1,9 @@ from neo4j import GraphDatabase from random import random +from neo4j_genai import HybridRetriever from neo4j_genai.embedder import Embedder from neo4j_genai.indexes import create_vector_index, create_fulltext_index -from neo4j_genai.retrievers import HybridRetriever URI = "neo4j://localhost:7687" AUTH = ("neo4j", "password") diff --git a/src/neo4j_genai/__init__.py b/src/neo4j_genai/__init__.py index 89c3ad57..bade3d75 100644 --- a/src/neo4j_genai/__init__.py +++ b/src/neo4j_genai/__init__.py @@ -13,7 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .retrievers import VectorRetriever, VectorCypherRetriever, HybridRetriever +from .retrievers.vector import VectorRetriever, VectorCypherRetriever +from .retrievers.hybrid import HybridRetriever, HybridCypherRetriever - -__all__ = ["VectorRetriever", "VectorCypherRetriever", "HybridRetriever"] +__all__ = [ + "VectorRetriever", + "VectorCypherRetriever", + "HybridRetriever", + "HybridCypherRetriever", +] diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py new file mode 100644 index 00000000..752677d5 --- /dev/null +++ b/src/neo4j_genai/neo4j_queries.py @@ -0,0 +1,56 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from neo4j_genai.types import SearchType + + +def get_search_query( + search_type: SearchType, + return_properties: Optional[list[str]] = None, + retrieval_query: Optional[str] = None, +): + query_map = { + SearchType.VECTOR: ( + "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + ), + SearchType.HYBRID: ( + "CALL { " + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node, score UNION " + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " + "YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + "RETURN n.node AS node, (n.score / max) AS score " + "} " + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + ), + } + + base_query = query_map[search_type] + additional_query = "" + + if retrieval_query: + additional_query += retrieval_query + elif return_properties: + return_properties_cypher = ", ".join([f".{prop}" for prop in return_properties]) + additional_query += "YIELD node, score " + additional_query += f"RETURN node {{{return_properties_cypher}}} as node, score" + else: + additional_query += "RETURN node, score" + + return base_query + additional_query diff --git a/src/neo4j_genai/retrievers/__init__.py b/src/neo4j_genai/retrievers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/neo4j_genai/retrievers/base.py b/src/neo4j_genai/retrievers/base.py new file mode 100644 index 00000000..dc483eb6 --- /dev/null +++ b/src/neo4j_genai/retrievers/base.py @@ -0,0 +1,59 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +from typing import Any + +from neo4j import Driver + + +class Retriever(ABC): + """ + Abstract class for Neo4j retrievers + """ + + def __init__(self, driver: Driver): + self.driver = driver + self._verify_version() + + def _verify_version(self) -> None: + """ + Check if the connected Neo4j database version supports vector indexing. + + Queries the Neo4j database to retrieve its version and compares it + against a target version (5.18.1) that is known to support vector + indexing. Raises a ValueError if the connected Neo4j version is + not supported. + """ + records, _, _ = self.driver.execute_query("CALL dbms.components()") + version = records[0]["versions"][0] + + if "aura" in version: + version_tuple = ( + *tuple(map(int, version.split("-")[0].split("."))), + 0, + ) + target_version = (5, 18, 0) + else: + version_tuple = tuple(map(int, version.split("."))) + target_version = (5, 18, 1) + + if version_tuple < target_version: + raise ValueError( + "This package only supports Neo4j version 5.18.1 or greater" + ) + + @abstractmethod + def search(self, *args, **kwargs) -> Any: + pass diff --git a/src/neo4j_genai/retrievers/hybrid.py b/src/neo4j_genai/retrievers/hybrid.py new file mode 100644 index 00000000..a9311d06 --- /dev/null +++ b/src/neo4j_genai/retrievers/hybrid.py @@ -0,0 +1,161 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Any + +from neo4j import Record, Driver +from pydantic import ValidationError + +from neo4j_genai.embedder import Embedder +from neo4j_genai.retrievers.base import Retriever +from neo4j_genai.types import HybridSearchModel, SearchType, HybridCypherSearchModel +from neo4j_genai.neo4j_queries import get_search_query + + +class HybridRetriever(Retriever): + def __init__( + self, + driver: Driver, + vector_index_name: str, + fulltext_index_name: str, + embedder: Optional[Embedder] = None, + return_properties: Optional[list[str]] = None, + ) -> None: + super().__init__(driver) + self.vector_index_name = vector_index_name + self.fulltext_index_name = fulltext_index_name + self.embedder = embedder + self.return_properties = return_properties + + def search( + self, + query_text: str, + query_vector: Optional[list[float]] = None, + top_k: int = 5, + ) -> list[Record]: + """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. + Both query_vector and query_text can be provided. + If query_vector is provided, then it will be preferred over the embedded query_text + for the vector search. + See the following documentation for more details: + - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query) + - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes) + - [db.index.fulltext.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_fulltext_querynodes) + Args: + query_text (str): The text to get the closest neighbors of. + query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None. + top_k (int, optional): The number of neighbors to return. Defaults to 5. + Raises: + ValueError: If validation of the input arguments fail. + ValueError: If no embedder is provided. + Returns: + list[Record]: The results of the search query + """ + try: + validated_data = HybridSearchModel( + vector_index_name=self.vector_index_name, + fulltext_index_name=self.fulltext_index_name, + top_k=top_k, + query_vector=query_vector, + query_text=query_text, + ) + except ValidationError as e: + raise ValueError(f"Validation failed: {e.errors()}") + + parameters = validated_data.model_dump(exclude_none=True) + + if query_text and not query_vector: + if not self.embedder: + raise ValueError("Embedding method required for text query.") + query_vector = self.embedder.embed_query(query_text) + parameters["query_vector"] = query_vector + + search_query = get_search_query(SearchType.HYBRID, self.return_properties) + + records, _, _ = self.driver.execute_query(search_query, parameters) + return records + + +class HybridCypherRetriever(Retriever): + def __init__( + self, + driver: Driver, + vector_index_name: str, + fulltext_index_name: str, + retrieval_query: str, + embedder: Optional[Embedder] = None, + ) -> None: + super().__init__(driver) + self.vector_index_name = vector_index_name + self.fulltext_index_name = fulltext_index_name + self.retrieval_query = retrieval_query + self.embedder = embedder + + def search( + self, + query_text: str, + query_vector: Optional[list[float]] = None, + top_k: int = 5, + query_params: Optional[dict[str, Any]] = None, + ) -> list[Record]: + """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. + Both query_vector and query_text can be provided. + If query_vector is provided, then it will be preferred over the embedded query_text + for the vector search. + See the following documentation for more details: + - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query) + - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes) + - [db.index.fulltext.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_fulltext_querynodes) + Args: + query_text (str): The text to get the closest neighbors of. + query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None. + top_k (int, optional): The number of neighbors to return. Defaults to 5. + query_params (Optional[dict[str, Any]], optional): Parameters for the Cypher query. Defaults to None. + Raises: + ValueError: If validation of the input arguments fail. + ValueError: If no embedder is provided. + Returns: + list[Record]: The results of the search query + """ + try: + validated_data = HybridCypherSearchModel( + vector_index_name=self.vector_index_name, + fulltext_index_name=self.fulltext_index_name, + top_k=top_k, + query_vector=query_vector, + query_text=query_text, + query_params=query_params, + ) + except ValidationError as e: + raise ValueError(f"Validation failed: {e.errors()}") + + parameters = validated_data.model_dump(exclude_none=True) + + if query_text and not query_vector: + if not self.embedder: + raise ValueError("Embedding method required for text query.") + query_vector = self.embedder.embed_query(query_text) + parameters["query_vector"] = query_vector + + if query_params: + for key, value in query_params.items(): + if key not in parameters: + parameters[key] = value + del parameters["query_params"] + + search_query = get_search_query( + SearchType.HYBRID, retrieval_query=self.retrieval_query + ) + records, _, _ = self.driver.execute_query(search_query, parameters) + return records diff --git a/src/neo4j_genai/retrievers.py b/src/neo4j_genai/retrievers/vector.py similarity index 52% rename from src/neo4j_genai/retrievers.py rename to src/neo4j_genai/retrievers/vector.py index 15fe31f1..771ce352 100644 --- a/src/neo4j_genai/retrievers.py +++ b/src/neo4j_genai/retrievers/vector.py @@ -12,57 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from abc import abstractmethod, ABC from typing import Optional, Any -from pydantic import ValidationError + from neo4j import Driver, Record -from .embedder import Embedder -from .types import ( - SimilaritySearchModel, +from neo4j_genai.retrievers.base import Retriever +from pydantic import ValidationError + +from neo4j_genai.embedder import Embedder +from neo4j_genai.types import ( VectorSearchRecord, + VectorSearchModel, VectorCypherSearchModel, - HybridModel, + SearchType, ) - - -class Retriever(ABC): - """ - Abstract class for Neo4j retrievers - """ - - def __init__(self, driver: Driver): - self.driver = driver - - def _verify_version(self) -> None: - """ - Check if the connected Neo4j database version supports vector indexing. - - Queries the Neo4j database to retrieve its version and compares it - against a target version (5.18.1) that is known to support vector - indexing. Raises a ValueError if the connected Neo4j version is - not supported. - """ - records, _, _ = self.driver.execute_query("CALL dbms.components()") - version = records[0]["versions"][0] - - if "aura" in version: - version_tuple = ( - *tuple(map(int, version.split("-")[0].split("."))), - 0, - ) - target_version = (5, 18, 0) - else: - version_tuple = tuple(map(int, version.split("."))) - target_version = (5, 18, 1) - - if version_tuple < target_version: - raise ValueError( - "This package only supports Neo4j version 5.18.1 or greater" - ) - - @abstractmethod - def search(self, *args, **kwargs) -> Any: - pass +from neo4j_genai.neo4j_queries import get_search_query class VectorRetriever(Retriever): @@ -79,7 +42,6 @@ def __init__( return_properties: Optional[list[str]] = None, ) -> None: super().__init__(driver) - self._verify_version() self.index_name = index_name self.return_properties = return_properties self.embedder = embedder @@ -109,7 +71,7 @@ def search( list[VectorSearchRecord]: The `top_k` neighbors found in vector search with their nodes and scores. """ try: - validated_data = SimilaritySearchModel( + validated_data = VectorSearchModel( index_name=self.index_name, top_k=top_k, query_vector=query_vector, @@ -128,20 +90,9 @@ def search( parameters["query_vector"] = query_vector del parameters["query_text"] - db_query_string = """ - CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) - YIELD node, score - """ + search_query = get_search_query(SearchType.VECTOR, self.return_properties) - if self.return_properties: - return_properties_cypher = ", ".join( - [f".{prop}" for prop in self.return_properties] - ) - db_query_string += ( - f"RETURN node {{{return_properties_cypher}}} as node, score" - ) - - records, _, _ = self.driver.execute_query(db_query_string, parameters) + records, _, _ = self.driver.execute_query(search_query, parameters) try: return [ @@ -169,7 +120,6 @@ def __init__( embedder: Optional[Embedder] = None, ) -> None: super().__init__(driver) - self._verify_version() self.index_name = index_name self.retrieval_query = retrieval_query self.embedder = embedder @@ -225,96 +175,8 @@ def search( parameters[key] = value del parameters["query_params"] - query_prefix = """ - CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) - YIELD node, score - """ - search_query = query_prefix + self.retrieval_query - records, _, _ = self.driver.execute_query(search_query, parameters) - return records - - -class HybridRetriever(Retriever): - def __init__( - self, - driver: Driver, - vector_index_name: str, - fulltext_index_name: str, - embedder: Optional[Embedder] = None, - return_properties: Optional[list[str]] = None, - ) -> None: - super().__init__(driver) - self._verify_version() - self.vector_index_name = vector_index_name - self.fulltext_index_name = fulltext_index_name - self.embedder = embedder - self.return_properties = return_properties - - def search( - self, - query_text: str, - query_vector: Optional[list[float]] = None, - top_k: int = 5, - ) -> list[Record]: - """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. - Both query_vector and query_text can be provided. - If query_vector is provided, then it will be preferred over the embedded query_text - for the vector search. - See the following documentation for more details: - - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query) - - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes) - - [db.index.fulltext.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_fulltext_querynodes) - Args: - query_text (str): The text to get the closest neighbors of. - query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None. - top_k (int, optional): The number of neighbors to return. Defaults to 5. - Raises: - ValueError: If validation of the input arguments fail. - ValueError: If no embedder is provided. - Returns: - list[Record]: The results of the search query - """ - try: - validated_data = HybridModel( - vector_index_name=self.vector_index_name, - fulltext_index_name=self.fulltext_index_name, - top_k=top_k, - query_vector=query_vector, - query_text=query_text, - ) - except ValidationError as e: - raise ValueError(f"Validation failed: {e.errors()}") - - parameters = validated_data.model_dump(exclude_none=True) - - if query_text and not query_vector: - if not self.embedder: - raise ValueError("Embedding method required for text query.") - query_vector = self.embedder.embed_query(query_text) - parameters["query_vector"] = query_vector - - search_query = ( - "CALL { " - "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " - "YIELD node, score " - "RETURN node, score UNION " - "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " - "YIELD node, score " - "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " - "UNWIND nodes AS n " - "RETURN n.node AS node, (n.score / max) AS score " - "} " - "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + search_query = get_search_query( + SearchType.VECTOR, retrieval_query=self.retrieval_query ) - - if self.return_properties: - return_properties_cypher = ", ".join( - [f".{prop}" for prop in self.return_properties] - ) - search_query += "YIELD node, score " - search_query += f"RETURN node {{{return_properties_cypher}}} as node, score" - else: - search_query += "RETURN node, score" - records, _, _ = self.driver.execute_query(search_query, parameters) return records diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index 5747aeea..67a31175 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from enum import Enum from typing import Any, Literal, Optional from pydantic import BaseModel, PositiveInt, model_validator, field_validator from neo4j import Driver @@ -53,7 +53,7 @@ def check_node_properties_not_empty(cls, v): return v -class SimilaritySearchModel(BaseModel): +class VectorSearchModel(BaseModel): index_name: str top_k: PositiveInt = 5 query_vector: Optional[list[float]] = None @@ -72,13 +72,24 @@ def check_query(cls, values): return values -class VectorCypherSearchModel(SimilaritySearchModel): +class VectorCypherSearchModel(VectorSearchModel): query_params: Optional[dict[str, Any]] = None -class HybridModel(BaseModel): +class HybridSearchModel(BaseModel): vector_index_name: str fulltext_index_name: str query_text: str top_k: PositiveInt = 5 query_vector: Optional[list[float]] = None + + +class HybridCypherSearchModel(HybridSearchModel): + query_params: Optional[dict[str, Any]] = None + + +class SearchType(str, Enum): + """Enumerator of the search strategies.""" + + VECTOR = "vector" + HYBRID = "hybrid" diff --git a/tests/test_queries.py b/tests/test_queries.py new file mode 100644 index 00000000..555388ae --- /dev/null +++ b/tests/test_queries.py @@ -0,0 +1,76 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from neo4j_genai.neo4j_queries import get_search_query +from neo4j_genai.types import SearchType + + +def test_vector_search_basic(): + expected = ( + "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + "RETURN node, score" + ) + result = get_search_query(SearchType.VECTOR) + assert result == expected + + +def test_hybrid_search_basic(): + expected = ( + "CALL { " + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node, score UNION " + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " + "YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + "RETURN n.node AS node, (n.score / max) AS score " + "} " + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + "RETURN node, score" + ) + result = get_search_query(SearchType.HYBRID) + assert result == expected + + +def test_vector_search_with_properties(): + properties = ["name", "age"] + expected = ( + "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node {.name, .age} as node, score" + ) + result = get_search_query(SearchType.VECTOR, return_properties=properties) + assert result == expected + + +def test_hybrid_search_with_retrieval_query(): + retrieval_query = "MATCH (n) RETURN n LIMIT 10" + expected = ( + "CALL { " + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node, score UNION " + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " + "YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + "RETURN n.node AS node, (n.score / max) AS score " + "} " + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + + retrieval_query + ) + result = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) + assert result == expected diff --git a/tests/test_retrievers.py b/tests/test_retrievers.py index 60cd6c78..2f7a6113 100644 --- a/tests/test_retrievers.py +++ b/tests/test_retrievers.py @@ -18,9 +18,10 @@ from neo4j.exceptions import CypherSyntaxError -from neo4j_genai import VectorRetriever -from neo4j_genai.retrievers import VectorCypherRetriever, HybridRetriever -from neo4j_genai.types import VectorSearchRecord +from neo4j_genai import VectorRetriever, VectorCypherRetriever, HybridRetriever +from neo4j_genai.retrievers.hybrid import HybridCypherRetriever +from neo4j_genai.types import VectorSearchRecord, SearchType +from neo4j_genai.neo4j_queries import get_search_query def test_vector_retriever_supported_aura_version(driver): @@ -59,18 +60,13 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver): dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] top_k = 5 - retriever = VectorRetriever(driver, index_name) - retriever.driver.execute_query.return_value = [ [{"node": "dummy-node", "score": 1.0}], None, None, ] - search_query = """ - CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) - YIELD node, score - """ + search_query = get_search_query(SearchType.VECTOR) records = retriever.search(query_vector=query_vector, top_k=top_k) @@ -82,7 +78,6 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver): "query_vector": query_vector, }, ) - assert records == [VectorSearchRecord(node="dummy-node", score=1.0)] @@ -91,28 +86,20 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector - index_name = "my-index" query_text = "may thy knife chip and shatter" top_k = 5 - retriever = VectorRetriever(driver, index_name, custom_embeddings) - driver.execute_query.return_value = [ [{"node": "dummy-node", "score": 1.0}], None, None, ] - - search_query = """ - CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) - YIELD node, score - """ + search_query = get_search_query(SearchType.VECTOR) records = retriever.search(query_text=query_text, top_k=top_k) custom_embeddings.embed_query.assert_called_once_with(query_text) - driver.execute_query.assert_called_once_with( search_query, { @@ -130,7 +117,6 @@ def test_similarity_search_text_return_properties(_verify_version_mock, driver): embed_query_vector = [1.0 for _ in range(3)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector - index_name = "my-index" query_text = "may thy knife chip and shatter" top_k = 5 @@ -145,17 +131,11 @@ def test_similarity_search_text_return_properties(_verify_version_mock, driver): None, None, ] - - search_query = """ - CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) - YIELD node, score - RETURN node {.node-property-1, .node-property-2} as node, score - """ + search_query = get_search_query(SearchType.VECTOR, return_properties) records = retriever.search(query_text=query_text, top_k=top_k) custom_embeddings.embed_query.assert_called_once_with(query_text) - driver.execute_query.assert_called_once_with( search_query.rstrip(), { @@ -164,7 +144,6 @@ def test_similarity_search_text_return_properties(_verify_version_mock, driver): "query_vector": embed_query_vector, }, ) - assert records == [VectorSearchRecord(node="dummy-node", score=1.0)] @@ -222,18 +201,13 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver): dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] top_k = 5 - retriever = VectorRetriever(driver, index_name) - retriever.driver.execute_query.return_value = [ [{"node": "dummy-node", "score": "adsa"}], None, None, ] - search_query = """ - CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) - YIELD node, score - """ + search_query = get_search_query(SearchType.VECTOR) with pytest.raises(ValueError): retriever.search(query_vector=query_vector, top_k=top_k) @@ -267,10 +241,7 @@ def test_retrieval_query_happy_path(_verify_version_mock, driver): None, None, ] - search_query = """ - CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) - YIELD node, score - """ + search_query = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) records = retriever.search( query_text=query_text, @@ -279,7 +250,7 @@ def test_retrieval_query_happy_path(_verify_version_mock, driver): custom_embeddings.embed_query.assert_called_once_with(query_text) driver.execute_query.assert_called_once_with( - search_query + retrieval_query, + search_query, { "index_name": index_name, "top_k": top_k, @@ -314,10 +285,7 @@ def test_retrieval_query_with_params(_verify_version_mock, driver): None, None, ] - search_query = """ - CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) - YIELD node, score - """ + search_query = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) records = retriever.search( query_text=query_text, @@ -328,7 +296,7 @@ def test_retrieval_query_with_params(_verify_version_mock, driver): custom_embeddings.embed_query.assert_called_once_with(query_text) driver.execute_query.assert_called_once_with( - search_query + retrieval_query, + search_query, { "index_name": index_name, "top_k": top_k, @@ -372,30 +340,15 @@ def test_hybrid_search_text_happy_path(_verify_version_mock, driver): fulltext_index_name = "my-fulltext-index" query_text = "may thy knife chip and shatter" top_k = 5 - retriever = HybridRetriever( driver, vector_index_name, fulltext_index_name, custom_embeddings ) - retriever.driver.execute_query.return_value = [ [{"node": "dummy-node", "score": 1.0}], None, None, ] - search_query = ( - "CALL { " - "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " - "YIELD node, score " - "RETURN node, score UNION " - "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " - "YIELD node, score " - "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " - "UNWIND nodes AS n " - "RETURN n.node AS node, (n.score / max) AS score " - "} " - "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " - "RETURN node, score" - ) + search_query = get_search_query(SearchType.HYBRID) records = retriever.search(query_text=query_text, top_k=top_k) @@ -425,30 +378,15 @@ def test_hybrid_search_favors_query_vector_over_embedding_vector( fulltext_index_name = "my-fulltext-index" query_text = "may thy knife chip and shatter" top_k = 5 - retriever = HybridRetriever( driver, vector_index_name, fulltext_index_name, custom_embeddings ) - retriever.driver.execute_query.return_value = [ [{"node": "dummy-node", "score": 1.0}], None, None, ] - search_query = ( - "CALL { " - "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " - "YIELD node, score " - "RETURN node, score UNION " - "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " - "YIELD node, score " - "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " - "UNWIND nodes AS n " - "RETURN n.node AS node, (n.score / max) AS score " - "} " - "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " - "RETURN node, score" - ) + search_query = get_search_query(SearchType.HYBRID) retriever.search(query_text=query_text, query_vector=query_vector, top_k=top_k) @@ -511,35 +449,71 @@ def test_hybrid_retriever_return_properties(_verify_version_mock, driver): None, None, ] - search_query = ( - "CALL { " - "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " - "YIELD node, score " - "RETURN node, score UNION " - "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " - "YIELD node, score " - "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " - "UNWIND nodes AS n " - "RETURN n.node AS node, (n.score / max) AS score " - "} " - "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " - "YIELD node, score " - "RETURN node {.node-property-1, .node-property-2} as node, score" - ) + search_query = get_search_query(SearchType.HYBRID, return_properties) records = retriever.search(query_text=query_text, top_k=top_k) + custom_embeddings.embed_query.assert_called_once_with(query_text) + driver.execute_query.assert_called_once_with( + search_query, + { + "vector_index_name": vector_index_name, + "top_k": top_k, + "query_text": query_text, + "fulltext_index_name": fulltext_index_name, + "query_vector": embed_query_vector, + }, + ) + assert records == [{"node": "dummy-node", "score": 1.0}] + + +@patch("neo4j_genai.HybridCypherRetriever._verify_version") +def test_hybrid_cypher_retrieval_query_with_params(_verify_version_mock, driver): + embed_query_vector = [1.0 for _ in range(1536)] + custom_embeddings = MagicMock() + custom_embeddings.embed_query.return_value = embed_query_vector + vector_index_name = "my-index" + fulltext_index_name = "my-fulltext-index" + query_text = "may thy knife chip and shatter" + top_k = 5 + retrieval_query = """ + RETURN node.id AS node_id, node.text AS text, score, {test: $param} AS metadata + """ + query_params = { + "param": "dummy-param", + } + retriever = HybridCypherRetriever( + driver, + vector_index_name, + fulltext_index_name, + retrieval_query, + custom_embeddings, + ) + driver.execute_query.return_value = [ + [{"node_id": 123, "text": "dummy-text", "score": 1.0}], + None, + None, + ] + search_query = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) + + records = retriever.search( + query_text=query_text, + top_k=top_k, + query_params=query_params, + ) + custom_embeddings.embed_query.assert_called_once_with(query_text) driver.execute_query.assert_called_once_with( - search_query.rstrip(), + search_query, { "vector_index_name": vector_index_name, "top_k": top_k, "query_text": query_text, "fulltext_index_name": fulltext_index_name, "query_vector": embed_query_vector, + "param": "dummy-param", }, ) - assert records == [{"node": "dummy-node", "score": 1.0}] + assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}]