From eaf27abd652f04050e0c3d9ee90fa2f71908b08d Mon Sep 17 00:00:00 2001
From: willtai
Date: Fri, 26 Apr 2024 12:35:54 +0100
Subject: [PATCH] Adds HybridSearchRetriever and creates abstract base class Retriever (#14)

---
Reviewer notes (not part of the commit message):
- HybridRetriever.search used to append "YIELD node, score " after the outer
  WITH clause whenever return_properties was set. YIELD is only valid directly
  after a procedure CALL, so that branch produced invalid Cypher. The line is
  replaced by a comment here and in test_hybrid_retriever_return_properties,
  which keeps every hunk's line count unchanged.
- Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch. If a context line does not match, apply with
  `git apply --ignore-whitespace`. The blob ids on the `index` lines predate
  the fix above.

 examples/hybrid_search.py     |  59 ++++++++++
 src/neo4j_genai/__init__.py   |   4 +-
 src/neo4j_genai/indexes.py    |   2 +-
 src/neo4j_genai/retrievers.py |  87 +++++++++++++++
 src/neo4j_genai/types.py      |   8 ++
 tests/conftest.py             |   8 +-
 tests/test_indexes.py         |   4 +-
 tests/test_retrievers.py      | 196 ++++++++++++++++++++++++++++++++--
 8 files changed, 351 insertions(+), 17 deletions(-)
 create mode 100644 examples/hybrid_search.py

diff --git a/examples/hybrid_search.py b/examples/hybrid_search.py
new file mode 100644
index 00000000..b0413b0a
--- /dev/null
+++ b/examples/hybrid_search.py
@@ -0,0 +1,59 @@
+from neo4j import GraphDatabase
+
+from random import random
+from neo4j_genai.embedder import Embedder
+from neo4j_genai.indexes import create_vector_index, create_fulltext_index
+from neo4j_genai.retrievers import HybridRetriever
+
+URI = "neo4j://localhost:7687"
+AUTH = ("neo4j", "password")
+
+INDEX_NAME = "embedding-name"
+FULLTEXT_INDEX_NAME = "fulltext-index-name"
+DIMENSION = 1536
+
+# Connect to Neo4j database
+driver = GraphDatabase.driver(URI, auth=AUTH)
+
+
+# Create Embedder object
+class CustomEmbedder(Embedder):
+    def embed_query(self, text: str) -> list[float]:
+        return [random() for _ in range(DIMENSION)]
+
+
+embedder = CustomEmbedder()
+
+# Creating the index
+create_vector_index(
+    driver,
+    INDEX_NAME,
+    label="Document",
+    property="propertyKey",
+    dimensions=DIMENSION,
+    similarity_fn="euclidean",
+)
+create_fulltext_index(
+    driver, FULLTEXT_INDEX_NAME, label="Document", node_properties=["propertyKey"]
+)
+
+# Initialize the retriever
+retriever = HybridRetriever(driver, INDEX_NAME, FULLTEXT_INDEX_NAME, embedder)
+
+# Upsert a Document node that carries a random vector in 'propertyKey'
+vector = [random() for _ in range(DIMENSION)]
+insert_query = (
+    "MERGE (n:Document {id: $id})"
+    "WITH n "
+    "CALL db.create.setNodeVectorProperty(n, 'propertyKey', $vector)"
+    "RETURN n"
+)
+parameters = {
+    "id": 0,
+    "vector": vector,
+}
+driver.execute_query(insert_query, parameters)
+
+# Perform the similarity search for a text query
+query_text = "Who are the fremen?"
+print(retriever.search(query_text=query_text, top_k=5))
diff --git a/src/neo4j_genai/__init__.py b/src/neo4j_genai/__init__.py
index 5676e9c4..89c3ad57 100644
--- a/src/neo4j_genai/__init__.py
+++ b/src/neo4j_genai/__init__.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .retrievers import VectorRetriever, VectorCypherRetriever
+from .retrievers import VectorRetriever, VectorCypherRetriever, HybridRetriever
 
 
-__all__ = ["VectorRetriever", "VectorCypherRetriever"]
+__all__ = ["VectorRetriever", "VectorCypherRetriever", "HybridRetriever"]
diff --git a/src/neo4j_genai/indexes.py b/src/neo4j_genai/indexes.py
index b457b301..b09901e0 100644
--- a/src/neo4j_genai/indexes.py
+++ b/src/neo4j_genai/indexes.py
@@ -98,7 +98,7 @@ def create_fulltext_index(
         raise ValueError(f"Error for inputs to create_fulltext_index {str(e)}")
 
     query = (
-        "CREATE FULLTEXT INDEX $name"
+        "CREATE FULLTEXT INDEX $name "
         f"FOR (n:`{label}`) ON EACH "
         f"[{', '.join(['n.`' + prop + '`' for prop in node_properties])}]"
     )
diff --git a/src/neo4j_genai/retrievers.py b/src/neo4j_genai/retrievers.py
index 9ce9d6dc..15fe31f1 100644
--- a/src/neo4j_genai/retrievers.py
+++ b/src/neo4j_genai/retrievers.py
@@ -21,6 +21,7 @@
     SimilaritySearchModel,
     VectorSearchRecord,
     VectorCypherSearchModel,
+    HybridModel,
 )
 
 
@@ -231,3 +232,89 @@ def search(
         search_query = query_prefix + self.retrieval_query
         records, _, _ = self.driver.execute_query(search_query, parameters)
         return records
+
+
+class HybridRetriever(Retriever):
+    def __init__(
+        self,
+        driver: Driver,
+        vector_index_name: str,
+        fulltext_index_name: str,
+        embedder: Optional[Embedder] = None,
+        return_properties: Optional[list[str]] = None,
+    ) -> None:
+        super().__init__(driver)
+        self._verify_version()
+        self.vector_index_name = vector_index_name
+        self.fulltext_index_name = fulltext_index_name
+        self.embedder = embedder
+        self.return_properties = return_properties
+
+    def search(
+        self,
+        query_text: str,
+        query_vector: Optional[list[float]] = None,
+        top_k: int = 5,
+    ) -> list[Record]:
+        """Run a combined vector-index and full-text-index search and return the top_k best-scoring nodes.
+        Both query_vector and query_text can be provided.
+        If query_vector is provided, then it will be preferred over the embedded query_text
+        for the vector search.
+        See the following documentation for more details:
+        - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query)
+        - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes)
+        - [db.index.fulltext.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_fulltext_querynodes)
+        Args:
+            query_text (str): The text to get the closest neighbors of.
+            query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
+            top_k (int, optional): The number of neighbors to return. Defaults to 5.
+        Raises:
+            ValueError: If validation of the input arguments fail.
+            ValueError: If query_vector is not given and no embedder is provided.
+        Returns:
+            list[Record]: The results of the search query
+        """
+        try:
+            validated_data = HybridModel(
+                vector_index_name=self.vector_index_name,
+                fulltext_index_name=self.fulltext_index_name,
+                top_k=top_k,
+                query_vector=query_vector,
+                query_text=query_text,
+            )
+        except ValidationError as e:
+            raise ValueError(f"Validation failed: {e.errors()}")
+
+        parameters = validated_data.model_dump(exclude_none=True)
+
+        if query_text and not query_vector:
+            if not self.embedder:
+                raise ValueError("Embedding method required for text query.")
+            query_vector = self.embedder.embed_query(query_text)
+            parameters["query_vector"] = query_vector
+
+        search_query = (
+            "CALL { "
+            "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) "
+            "YIELD node, score "
+            "RETURN node, score UNION "
+            "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) "
+            "YIELD node, score "
+            "WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
+            "UNWIND nodes AS n "
+            "RETURN n.node AS node, (n.score / max) AS score "
+            "} "
+            "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
+        )
+
+        if self.return_properties:
+            return_properties_cypher = ", ".join(
+                [f".{prop}" for prop in self.return_properties]
+            )
+            # No YIELD here: it is only valid right after CALL; WITH already binds both.
+            search_query += f"RETURN node {{{return_properties_cypher}}} as node, score"
+        else:
+            search_query += "RETURN node, score"
+
+        records, _, _ = self.driver.execute_query(search_query, parameters)
+        return records
diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py
index 6db895fd..5747aeea 100644
--- a/src/neo4j_genai/types.py
+++ b/src/neo4j_genai/types.py
@@ -74,3 +74,11 @@ def check_query(cls, values):
 
 class VectorCypherSearchModel(SimilaritySearchModel):
     query_params: Optional[dict[str, Any]] = None
+
+
+class HybridModel(BaseModel):
+    vector_index_name: str
+    fulltext_index_name: str
+    query_text: str
+    top_k: PositiveInt = 5
+    query_vector: Optional[list[float]] = None
diff --git a/tests/conftest.py b/tests/conftest.py
index 3c210a9f..b0359ec0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import pytest
-from neo4j_genai import VectorRetriever, VectorCypherRetriever
+from neo4j_genai import VectorRetriever, VectorCypherRetriever, HybridRetriever
 from neo4j import Driver
 from unittest.mock import MagicMock, patch
 
@@ -37,3 +37,9 @@ def vector_cypher_retriever(_verify_version_mock, driver):
     RETURN node.id AS node_id, node.text AS text, score
     """
     return VectorCypherRetriever(driver, "my-index", retrieval_query)
+
+
+@pytest.fixture
+@patch("neo4j_genai.HybridRetriever._verify_version")
+def hybrid_retriever(_verify_version_mock, driver):
+    return HybridRetriever(driver, "my-index", "my-fulltext-index")
diff --git a/tests/test_indexes.py b/tests/test_indexes.py
index ae2f98c3..c624607d 100644
--- a/tests/test_indexes.py
+++ b/tests/test_indexes.py
@@ -89,7 +89,7 @@ def test_create_fulltext_index_happy_path(driver):
     label = "node-label"
     text_node_properties = ["property-1", "property-2"]
     create_query = (
-        "CREATE FULLTEXT INDEX $name"
+        "CREATE FULLTEXT INDEX $name "
         f"FOR (n:`{label}`) ON EACH "
         f"[{', '.join(['n.`' + property + '`' for property in text_node_properties])}]"
     )
@@ -116,7 +116,7 @@ def test_create_fulltext_index_ensure_escaping(driver):
     label = "node-label"
     text_node_properties = ["property-1", "property-2"]
     create_query = (
-        "CREATE FULLTEXT INDEX $name"
+        "CREATE FULLTEXT INDEX $name "
         f"FOR (n:`{label}`) ON EACH "
         f"[{', '.join(['n.`' + property + '`' for property in text_node_properties])}]"
     )
diff --git a/tests/test_retrievers.py b/tests/test_retrievers.py
index 62900cb0..60cd6c78 100644
--- a/tests/test_retrievers.py
+++ b/tests/test_retrievers.py
@@ -19,7 +19,7 @@
 from neo4j.exceptions import CypherSyntaxError
 
 from neo4j_genai import VectorRetriever
-from neo4j_genai.retrievers import VectorCypherRetriever
+from neo4j_genai.retrievers import VectorCypherRetriever, HybridRetriever
 from neo4j_genai.types import VectorSearchRecord
 
 
@@ -55,14 +55,12 @@ def test_vector_retriever_no_supported_version(driver):
 
 @patch("neo4j_genai.VectorRetriever._verify_version")
 def test_similarity_search_vector_happy_path(_verify_version_mock, driver):
-    custom_embeddings = MagicMock()
-
     index_name = "my-index"
     dimensions = 1536
     query_vector = [1.0 for _ in range(dimensions)]
     top_k = 5
 
-    retriever = VectorRetriever(driver, index_name, custom_embeddings)
+    retriever = VectorRetriever(driver, index_name)
 
     retriever.driver.execute_query.return_value = [
         [{"node": "dummy-node", "score": 1.0}],
@@ -76,8 +74,6 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver):
 
     records = retriever.search(query_vector=query_vector, top_k=top_k)
 
-    custom_embeddings.embed_query.assert_not_called()
-
     retriever.driver.execute_query.assert_called_once_with(
         search_query,
         {
@@ -222,14 +218,12 @@ def test_vector_cypher_retriever_search_both_text_and_vector(vector_cypher_retri
 
 @patch("neo4j_genai.VectorRetriever._verify_version")
 def test_similarity_search_vector_bad_results(_verify_version_mock, driver):
-    custom_embeddings = MagicMock()
-
     index_name = "my-index"
     dimensions = 1536
     query_vector = [1.0 for _ in range(dimensions)]
     top_k = 5
 
-    retriever = VectorRetriever(driver, index_name, custom_embeddings)
+    retriever = VectorRetriever(driver, index_name)
 
     retriever.driver.execute_query.return_value = [
         [{"node": "dummy-node", "score": "adsa"}],
@@ -244,8 +238,6 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver):
     with pytest.raises(ValueError):
         retriever.search(query_vector=query_vector, top_k=top_k)
 
-    custom_embeddings.embed_query.assert_not_called()
-
     retriever.driver.execute_query.assert_called_once_with(
         search_query,
         {
@@ -369,3 +361,185 @@ def test_retrieval_query_cypher_error(_verify_version_mock, driver):
             query_text=query_text,
             top_k=top_k,
         )
+
+
+@patch("neo4j_genai.HybridRetriever._verify_version")
+def test_hybrid_search_text_happy_path(_verify_version_mock, driver):
+    embed_query_vector = [1.0 for _ in range(1536)]
+    custom_embeddings = MagicMock()
+    custom_embeddings.embed_query.return_value = embed_query_vector
+    vector_index_name = "my-index"
+    fulltext_index_name = "my-fulltext-index"
+    query_text = "may thy knife chip and shatter"
+    top_k = 5
+
+    retriever = HybridRetriever(
+        driver, vector_index_name, fulltext_index_name, custom_embeddings
+    )
+
+    retriever.driver.execute_query.return_value = [
+        [{"node": "dummy-node", "score": 1.0}],
+        None,
+        None,
+    ]
+    search_query = (
+        "CALL { "
+        "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) "
+        "YIELD node, score "
+        "RETURN node, score UNION "
+        "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) "
+        "YIELD node, score "
+        "WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
+        "UNWIND nodes AS n "
+        "RETURN n.node AS node, (n.score / max) AS score "
+        "} "
+        "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
+        "RETURN node, score"
+    )
+
+    records = retriever.search(query_text=query_text, top_k=top_k)
+
+    retriever.driver.execute_query.assert_called_once_with(
+        search_query,
+        {
+            "vector_index_name": vector_index_name,
+            "top_k": top_k,
+            "query_text": query_text,
+            "fulltext_index_name": fulltext_index_name,
+            "query_vector": embed_query_vector,
+        },
+    )
+    custom_embeddings.embed_query.assert_called_once_with(query_text)
+    assert records == [{"node": "dummy-node", "score": 1.0}]
+
+
+@patch("neo4j_genai.HybridRetriever._verify_version")
+def test_hybrid_search_favors_query_vector_over_embedding_vector(
+    _verify_version_mock, driver
+):
+    embed_query_vector = [1.0 for _ in range(1536)]
+    query_vector = [2.0 for _ in range(1536)]
+    custom_embeddings = MagicMock()
+    custom_embeddings.embed_query.return_value = embed_query_vector
+    vector_index_name = "my-index"
+    fulltext_index_name = "my-fulltext-index"
+    query_text = "may thy knife chip and shatter"
+    top_k = 5
+
+    retriever = HybridRetriever(
+        driver, vector_index_name, fulltext_index_name, custom_embeddings
+    )
+
+    retriever.driver.execute_query.return_value = [
+        [{"node": "dummy-node", "score": 1.0}],
+        None,
+        None,
+    ]
+    search_query = (
+        "CALL { "
+        "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) "
+        "YIELD node, score "
+        "RETURN node, score UNION "
+        "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) "
+        "YIELD node, score "
+        "WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
+        "UNWIND nodes AS n "
+        "RETURN n.node AS node, (n.score / max) AS score "
+        "} "
+        "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
+        "RETURN node, score"
+    )
+
+    retriever.search(query_text=query_text, query_vector=query_vector, top_k=top_k)
+
+    retriever.driver.execute_query.assert_called_once_with(
+        search_query,
+        {
+            "vector_index_name": vector_index_name,
+            "top_k": top_k,
+            "query_text": query_text,
+            "fulltext_index_name": fulltext_index_name,
+            "query_vector": query_vector,
+        },
+    )
+    custom_embeddings.embed_query.assert_not_called()
+
+
+def test_error_when_hybrid_search_only_text_no_embedder(hybrid_retriever):
+    query_text = "may thy knife chip and shatter"
+    top_k = 5
+
+    with pytest.raises(ValueError, match="Embedding method required for text query."):
+        hybrid_retriever.search(
+            query_text=query_text,
+            top_k=top_k,
+        )
+
+
+def test_hybrid_search_retriever_search_missing_embedder_for_text(
+    hybrid_retriever,
+):
+    query_text = "may thy knife chip and shatter"
+    top_k = 5
+
+    with pytest.raises(ValueError, match="Embedding method required for text query"):
+        hybrid_retriever.search(
+            query_text=query_text,
+            top_k=top_k,
+        )
+
+
+@patch("neo4j_genai.HybridRetriever._verify_version")
+def test_hybrid_retriever_return_properties(_verify_version_mock, driver):
+    embed_query_vector = [1.0 for _ in range(1536)]
+    custom_embeddings = MagicMock()
+    custom_embeddings.embed_query.return_value = embed_query_vector
+    vector_index_name = "my-index"
+    fulltext_index_name = "my-fulltext-index"
+    query_text = "may thy knife chip and shatter"
+    top_k = 5
+    return_properties = ["node-property-1", "node-property-2"]
+    retriever = HybridRetriever(
+        driver,
+        vector_index_name,
+        fulltext_index_name,
+        custom_embeddings,
+        return_properties,
+    )
+    driver.execute_query.return_value = [
+        [{"node": "dummy-node", "score": 1.0}],
+        None,
+        None,
+    ]
+    search_query = (
+        "CALL { "
+        "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) "
+        "YIELD node, score "
+        "RETURN node, score UNION "
+        "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) "
+        "YIELD node, score "
+        "WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
+        "UNWIND nodes AS n "
+        "RETURN n.node AS node, (n.score / max) AS score "
+        "} "
+        "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
+        # No YIELD here: RETURN follows the outer WITH directly.
+        "RETURN node {.node-property-1, .node-property-2} as node, score"
+    )
+
+    records = retriever.search(query_text=query_text, top_k=top_k)
+
+    custom_embeddings.embed_query.assert_called_once_with(query_text)
+
+    driver.execute_query.assert_called_once_with(
+        search_query.rstrip(),
+        {
+            "vector_index_name": vector_index_name,
+            "top_k": top_k,
+            "query_text": query_text,
+            "fulltext_index_name": fulltext_index_name,
+            "query_vector": embed_query_vector,
+        },
+    )
+
+    assert records == [{"node": "dummy-node", "score": 1.0}]