Skip to content

Commit

Permalink
Adds HybridSearchRetriever and creates abstract base class Retriever
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai committed Apr 25, 2024
1 parent 06bfe7e commit 77a9873
Show file tree
Hide file tree
Showing 9 changed files with 228 additions and 18 deletions.
61 changes: 61 additions & 0 deletions examples/hybrid_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from neo4j import GraphDatabase

from random import random
from neo4j_genai.embedder import Embedder
from neo4j_genai.indexes import create_vector_index, create_fulltext_index, drop_index
from neo4j_genai.retrievers import HybridSearchRetriever

URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")

INDEX_NAME = "embedding-name"
FULLTEXT_INDEX_NAME = "fulltext-index-name"
DIMENSION = 1536

# Connect to Neo4j database
driver = GraphDatabase.driver(URI, auth=AUTH)


# Create Embedder object
class CustomEmbedder(Embedder):
def embed_query(self, text: str) -> list[float]:
return [random() for _ in range(DIMENSION)]


embedder = CustomEmbedder()

# Creating the index
drop_index(driver, INDEX_NAME)
drop_index(driver, FULLTEXT_INDEX_NAME)
create_vector_index(
driver,
INDEX_NAME,
label="Document",
property="propertyKey",
dimensions=DIMENSION,
similarity_fn="euclidean",
)
create_fulltext_index(
driver, FULLTEXT_INDEX_NAME, label="Document", node_properties=["propertyKey"]
)

# Initialize the retriever
retriever = HybridSearchRetriever(driver, INDEX_NAME, FULLTEXT_INDEX_NAME, embedder)

# Upsert the query
vector = [random() for _ in range(DIMENSION)]
insert_query = (
"MERGE (n:Document {id: $id})"
"WITH n "
"CALL db.create.setNodeVectorProperty(n, 'propertyKey', $vector)"
"RETURN n"
)
parameters = {
"id": 0,
"vector": vector,
}
driver.execute_query(insert_query, parameters)

# Perform the similarity search for a text query
query_text = "Who are the fremen?"
print(retriever.search(query_text=query_text, top_k=5))
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

[tool.poetry]
name = "neo4j-genai"
version = "0.1.3"
version = "0.1.4"
description = "Python package to allow easy integration to Neo4j's GenAI features"
authors = ["Neo4j, Inc <[email protected]>"]
license = "Apache License, Version 2.0"
Expand Down
4 changes: 2 additions & 2 deletions src/neo4j_genai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .retrievers import VectorRetriever, VectorCypherRetriever
from .retrievers import VectorRetriever, VectorCypherRetriever, HybridSearchRetriever


__all__ = ["VectorRetriever", "VectorCypherRetriever"]
__all__ = ["VectorRetriever", "VectorCypherRetriever", "HybridSearchRetriever"]
2 changes: 1 addition & 1 deletion src/neo4j_genai/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def create_fulltext_index(
raise ValueError(f"Error for inputs to create_fulltext_index {str(e)}")

query = (
"CREATE FULLTEXT INDEX $name"
"CREATE FULLTEXT INDEX $name "
f"FOR (n:`{label}`) ON EACH "
f"[{', '.join(['n.`' + prop + '`' for prop in node_properties])}]"
)
Expand Down
72 changes: 72 additions & 0 deletions src/neo4j_genai/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
SimilaritySearchModel,
VectorSearchRecord,
VectorCypherSearchModel,
HybridSearchModel,
)


Expand Down Expand Up @@ -231,3 +232,74 @@ def search(
search_query = query_prefix + self.retrieval_query
records, _, _ = self.driver.execute_query(search_query, parameters)
return records


class HybridSearchRetriever(Retriever):
def __init__(
self,
driver: Driver,
index_name: str,
fulltext_index_name: str,
embedder: Optional[Embedder] = None,
) -> None:
super().__init__(driver)
self._verify_version()
self.index_name = index_name
self.fulltext_index_name = fulltext_index_name
self.embedder = embedder

def search(
self,
query_text: str,
query_vector: Optional[list[float]] = None,
top_k: int = 5,
) -> list[Record]:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
See the following documentation for more details:
- [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query)
- [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes)
Args:
query_text (str): The text to get the closest neighbors of.
query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
top_k (int, optional): The number of neighbors to return. Defaults to 5.
Raises:
ValueError: If validation of the input arguments fail.
ValueError: If no embedder is provided.
Returns:
list[Record]: The results of the search query
"""
try:
validated_data = HybridSearchModel(
index_name=self.index_name,
fulltext_index_name=self.fulltext_index_name,
top_k=top_k,
query_vector=query_vector,
query_text=query_text,
)
except ValidationError as e:
raise ValueError(f"Validation failed: {e.errors()}")

parameters = validated_data.model_dump(exclude_none=True)

if query_text and not query_vector:
if not self.embedder:
raise ValueError("Embedding method required for text query.")
query_vector = self.embedder.embed_query(query_text)
parameters["query_vector"] = query_vector

search_query = (
"CALL { "
"CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) "
"YIELD node, score "
"RETURN node, score UNION "
"CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) "
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / max) AS score "
"} "
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
"RETURN node, score"
)
records, _, _ = self.driver.execute_query(search_query, parameters)
return records
5 changes: 5 additions & 0 deletions src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,8 @@ def check_query(cls, values):

class VectorCypherSearchModel(SimilaritySearchModel):
query_params: Optional[dict[str, Any]] = None


class HybridSearchModel(SimilaritySearchModel):
fulltext_index_name: str
query_text: str
8 changes: 7 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

import pytest
from neo4j_genai import VectorRetriever, VectorCypherRetriever
from neo4j_genai import VectorRetriever, VectorCypherRetriever, HybridSearchRetriever
from neo4j import Driver
from unittest.mock import MagicMock, patch

Expand All @@ -37,3 +37,9 @@ def vector_cypher_retriever(_verify_version_mock, driver):
RETURN node.id AS node_id, node.text AS text, score
"""
return VectorCypherRetriever(driver, "my-index", retrieval_query)


@pytest.fixture
@patch("neo4j_genai.HybridSearchRetriever._verify_version")
def hybrid_search_retriever(_verify_version_mock, driver):
return HybridSearchRetriever(driver, "my-index", "my-fulltext-index")
4 changes: 2 additions & 2 deletions tests/test_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_create_fulltext_index_happy_path(driver):
label = "node-label"
text_node_properties = ["property-1", "property-2"]
create_query = (
"CREATE FULLTEXT INDEX $name"
"CREATE FULLTEXT INDEX $name "
f"FOR (n:`{label}`) ON EACH "
f"[{', '.join(['n.`' + property + '`' for property in text_node_properties])}]"
)
Expand All @@ -116,7 +116,7 @@ def test_create_fulltext_index_ensure_escaping(driver):
label = "node-label"
text_node_properties = ["property-1", "property-2"]
create_query = (
"CREATE FULLTEXT INDEX $name"
"CREATE FULLTEXT INDEX $name "
f"FOR (n:`{label}`) ON EACH "
f"[{', '.join(['n.`' + property + '`' for property in text_node_properties])}]"
)
Expand Down
88 changes: 77 additions & 11 deletions tests/test_retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from neo4j.exceptions import CypherSyntaxError

from neo4j_genai import VectorRetriever
from neo4j_genai.retrievers import VectorCypherRetriever
from neo4j_genai.retrievers import VectorCypherRetriever, HybridSearchRetriever
from neo4j_genai.types import VectorSearchRecord


Expand Down Expand Up @@ -55,14 +55,12 @@ def test_vector_retriever_no_supported_version(driver):

@patch("neo4j_genai.VectorRetriever._verify_version")
def test_similarity_search_vector_happy_path(_verify_version_mock, driver):
custom_embeddings = MagicMock()

index_name = "my-index"
dimensions = 1536
query_vector = [1.0 for _ in range(dimensions)]
top_k = 5

retriever = VectorRetriever(driver, index_name, custom_embeddings)
retriever = VectorRetriever(driver, index_name)

retriever.driver.execute_query.return_value = [
[{"node": "dummy-node", "score": 1.0}],
Expand All @@ -76,8 +74,6 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver):

records = retriever.search(query_vector=query_vector, top_k=top_k)

custom_embeddings.embed_query.assert_not_called()

retriever.driver.execute_query.assert_called_once_with(
search_query,
{
Expand Down Expand Up @@ -222,14 +218,12 @@ def test_vector_cypher_retriever_search_both_text_and_vector(vector_cypher_retri

@patch("neo4j_genai.VectorRetriever._verify_version")
def test_similarity_search_vector_bad_results(_verify_version_mock, driver):
custom_embeddings = MagicMock()

index_name = "my-index"
dimensions = 1536
query_vector = [1.0 for _ in range(dimensions)]
top_k = 5

retriever = VectorRetriever(driver, index_name, custom_embeddings)
retriever = VectorRetriever(driver, index_name)

retriever.driver.execute_query.return_value = [
[{"node": "dummy-node", "score": "adsa"}],
Expand All @@ -244,8 +238,6 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver):
with pytest.raises(ValueError):
retriever.search(query_vector=query_vector, top_k=top_k)

custom_embeddings.embed_query.assert_not_called()

retriever.driver.execute_query.assert_called_once_with(
search_query,
{
Expand Down Expand Up @@ -369,3 +361,77 @@ def test_retrieval_query_cypher_error(_verify_version_mock, driver):
query_text=query_text,
top_k=top_k,
)


@patch("neo4j_genai.HybridSearchRetriever._verify_version")
def test_hybrid_search_text_happy_path(_verify_version_mock, driver):
embed_query_vector = [1.0 for _ in range(1536)]
custom_embeddings = MagicMock()
custom_embeddings.embed_query.return_value = embed_query_vector
index_name = "my-index"
fulltext_index_name = "my-fulltext-index"
query_text = "may thy knife chip and shatter"
top_k = 5

retriever = HybridSearchRetriever(
driver, index_name, fulltext_index_name, custom_embeddings
)

retriever.driver.execute_query.return_value = [
[{"node": "dummy-node", "score": 1.0}],
None,
None,
]
search_query = (
"CALL { "
"CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) "
"YIELD node, score "
"RETURN node, score UNION "
"CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) "
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / max) AS score "
"} "
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
"RETURN node, score"
)

records = retriever.search(query_text=query_text, top_k=top_k)

retriever.driver.execute_query.assert_called_once_with(
search_query,
{
"index_name": index_name,
"top_k": top_k,
"query_text": query_text,
"fulltext_index_name": fulltext_index_name,
"query_vector": embed_query_vector,
},
)
custom_embeddings.embed_query.assert_called_once_with(query_text)
assert records == [{"node": "dummy-node", "score": 1.0}]


def test_error_when_hybrid_search_only_text_no_embedder(hybrid_search_retriever):
query_text = "may thy knife chip and shatter"
top_k = 5

with pytest.raises(ValueError, match="Embedding method required for text query."):
hybrid_search_retriever.search(
query_text=query_text,
top_k=top_k,
)


def test_hybrid_search_retriever_search_missing_embedder_for_text(
hybrid_search_retriever,
):
query_text = "may thy knife chip and shatter"
top_k = 5

with pytest.raises(ValueError, match="Embedding method required for text query"):
hybrid_search_retriever.search(
query_text=query_text,
top_k=top_k,
)

0 comments on commit 77a9873

Please sign in to comment.