Skip to content

Commit

Permalink
Adds HybridCypherRetriever, refactor query construction, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai committed Apr 29, 2024
1 parent 175a78c commit 9cca2a3
Show file tree
Hide file tree
Showing 8 changed files with 321 additions and 40 deletions.
62 changes: 62 additions & 0 deletions examples/hybrid_cypher_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from neo4j import GraphDatabase

from random import random
from neo4j_genai import HybridCypherRetriever
from neo4j_genai.embedder import Embedder
from neo4j_genai.indexes import create_vector_index, create_fulltext_index

# Connection settings used to build the driver below.
URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")

# Names of the vector index and the fulltext index this example creates and queries.
INDEX_NAME = "embedding-name"
FULLTEXT_INDEX_NAME = "fulltext-index-name"
# Length of the vectors: passed to the vector index and used for the random vectors.
DIMENSION = 1536

# Connect to Neo4j database
driver = GraphDatabase.driver(URI, auth=AUTH)


# Create Embedder object
class CustomEmbedder(Embedder):
    """Placeholder embedder that ignores its input and emits random numbers."""

    def embed_query(self, text: str) -> list[float]:
        """Return ``DIMENSION`` values drawn from ``random()``; ``text`` is unused."""
        vector = []
        for _ in range(DIMENSION):
            vector.append(random())
        return vector


embedder = CustomEmbedder()

# Create the vector index and the fulltext index that the hybrid search queries.
create_vector_index(
    driver,
    INDEX_NAME,
    label="Document",
    property="propertyKey",
    dimensions=DIMENSION,
    similarity_fn="euclidean",
)
create_fulltext_index(
    driver, FULLTEXT_INDEX_NAME, label="Document", node_properties=["propertyKey"]
)

# Initialize the retriever.
# Adjacent string literals are concatenated verbatim, so each fragment ends with
# an explicit space to keep the Cypher clauses separated.
retrieval_query = (
    "MATCH (node)-[:AUTHORED_BY]->(author:Author) "
    "RETURN author.name"
)
retriever = HybridCypherRetriever(
    driver, INDEX_NAME, FULLTEXT_INDEX_NAME, retrieval_query, embedder
)

# Upsert a document node carrying a random vector.
vector = [random() for _ in range(DIMENSION)]
insert_query = (
    "MERGE (n:Document {id: $id}) "
    "WITH n "
    "CALL db.create.setNodeVectorProperty(n, 'propertyKey', $vector) "
    "RETURN n"
)
parameters = {
    "id": 0,
    "vector": vector,
}
driver.execute_query(insert_query, parameters)

# Perform the similarity search for a text query
query_text = "Who are the fremen?"
print(retriever.search(query_text=query_text, top_k=5))

# Release the driver's connections now that the example is done.
driver.close()
10 changes: 7 additions & 3 deletions src/neo4j_genai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
# limitations under the License.

from .retrievers.vector import VectorRetriever, VectorCypherRetriever
from .retrievers.hybrid import HybridRetriever
from .retrievers.hybrid import HybridRetriever, HybridCypherRetriever


__all__ = ["VectorRetriever", "VectorCypherRetriever", "HybridRetriever"]
__all__ = [
"VectorRetriever",
"VectorCypherRetriever",
"HybridRetriever",
"HybridCypherRetriever",
]
32 changes: 19 additions & 13 deletions src/neo4j_genai/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@
from neo4j_genai.types import SearchType


def get_search_query(search_type: SearchType, return_properties: Optional[list[str]] = None,):
def get_search_query(
search_type: SearchType,
return_properties: Optional[list[str]] = None,
retrieval_query: Optional[str] = None,
):
query_map = {
SearchType.Vector: (
"CALL db.index.vector.queryNodes($index, $k, $embedding) YIELD node, score "
SearchType.VECTOR: (
"CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) "
),
SearchType.Hybrid: (
SearchType.HYBRID: (
"CALL { "
"CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) "
"YIELD node, score "
Expand All @@ -37,14 +41,16 @@ def get_search_query(search_type: SearchType, return_properties: Optional[list[s
),
}

search_query = query_map[search_type]
base_query = query_map[search_type]
additional_query = ""

if return_properties:
return_properties_cypher = ", ".join(
[f".{prop}" for prop in return_properties]
)
search_query += "YIELD node, score "
search_query += f"RETURN node {{{return_properties_cypher}}} as node, score"
if retrieval_query:
additional_query += retrieval_query
elif return_properties:
return_properties_cypher = ", ".join([f".{prop}" for prop in return_properties])
additional_query += "YIELD node, score "
additional_query += f"RETURN node {{{return_properties_cypher}}} as node, score"
else:
search_query += "RETURN node, score"
return search_query
additional_query += "RETURN node, score"

return base_query + additional_query
83 changes: 79 additions & 4 deletions src/neo4j_genai/retrievers/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Optional, Any

from neo4j import Record, Driver
from pydantic import ValidationError

from neo4j_genai.embedder import Embedder
from neo4j_genai.retrievers.base import Retriever
from neo4j_genai.types import HybridModel, SearchType
from neo4j_genai.types import HybridSearchModel, SearchType, HybridCypherSearchModel
from neo4j_genai.queries import get_search_query


Expand Down Expand Up @@ -64,7 +64,7 @@ def search(
list[Record]: The results of the search query
"""
try:
validated_data = HybridModel(
validated_data = HybridSearchModel(
vector_index_name=self.vector_index_name,
fulltext_index_name=self.fulltext_index_name,
top_k=top_k,
Expand All @@ -82,7 +82,82 @@ def search(
query_vector = self.embedder.embed_query(query_text)
parameters["query_vector"] = query_vector

search_query = get_search_query(SearchType.Hybrid, self.return_properties)
search_query = get_search_query(SearchType.HYBRID, self.return_properties)

records, _, _ = self.driver.execute_query(search_query, parameters)
return records


class HybridCypherRetriever(Retriever):
    """Retriever that runs a hybrid (vector + fulltext) index search and then
    appends a caller-supplied Cypher ``retrieval_query`` to the search query.

    Args:
        driver (Driver): The Neo4j Python driver.
        vector_index_name (str): Name of the vector index to query.
        fulltext_index_name (str): Name of the fulltext index to query.
        retrieval_query (str): Cypher appended after the hybrid search query.
        embedder (Optional[Embedder], optional): Used to embed ``query_text``
            when no ``query_vector`` is passed to ``search``. Defaults to None.
    """

    def __init__(
        self,
        driver: Driver,
        vector_index_name: str,
        fulltext_index_name: str,
        retrieval_query: str,
        embedder: Optional[Embedder] = None,
    ) -> None:
        super().__init__(driver)
        self._verify_version()
        self.vector_index_name = vector_index_name
        self.fulltext_index_name = fulltext_index_name
        self.retrieval_query = retrieval_query
        self.embedder = embedder

    def search(
        self,
        query_text: str,
        query_vector: Optional[list[float]] = None,
        top_k: int = 5,
        query_params: Optional[dict[str, Any]] = None,
    ) -> list[Record]:
        """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
        Both query_vector and query_text can be provided.
        If query_vector is provided, then it will be preferred over the embedded query_text
        for the vector search.

        See the following documentation for more details:

        - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query)
        - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes)
        - [db.index.fulltext.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_fulltext_querynodes)

        Args:
            query_text (str): The text to get the closest neighbors of.
            query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
            top_k (int, optional): The number of neighbors to return. Defaults to 5.
            query_params (Optional[dict[str, Any]], optional): Parameters for the Cypher query. Defaults to None.

        Raises:
            ValueError: If validation of the input arguments fails.
            ValueError: If query_text must be embedded but no embedder was provided.

        Returns:
            list[Record]: The results of the search query
        """
        try:
            validated_data = HybridCypherSearchModel(
                vector_index_name=self.vector_index_name,
                fulltext_index_name=self.fulltext_index_name,
                top_k=top_k,
                query_vector=query_vector,
                query_text=query_text,
                query_params=query_params,
            )
        except ValidationError as e:
            # Chain the original error so the pydantic traceback is preserved.
            raise ValueError(f"Validation failed: {e.errors()}") from e

        parameters = validated_data.model_dump(exclude_none=True)

        if query_text and not query_vector:
            if not self.embedder:
                raise ValueError("Embedding method required for text query.")
            query_vector = self.embedder.embed_query(query_text)
            parameters["query_vector"] = query_vector

        if query_params:
            # User-supplied parameters never override the search parameters.
            for key, value in query_params.items():
                parameters.setdefault(key, value)
        # The container itself is not a Cypher parameter. Remove it even when
        # an empty dict was passed, which exclude_none does not filter out.
        parameters.pop("query_params", None)

        search_query = get_search_query(
            SearchType.HYBRID, retrieval_query=self.retrieval_query
        )
        records, _, _ = self.driver.execute_query(search_query, parameters)
        return records
10 changes: 6 additions & 4 deletions src/neo4j_genai/retrievers/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from neo4j_genai.embedder import Embedder
from neo4j_genai.types import (
VectorSearchRecord,
SimilaritySearchModel,
VectorSearchModel,
VectorCypherSearchModel,
SearchType,
)
Expand Down Expand Up @@ -72,7 +72,7 @@ def search(
list[VectorSearchRecord]: The `top_k` neighbors found in vector search with their nodes and scores.
"""
try:
validated_data = SimilaritySearchModel(
validated_data = VectorSearchModel(
index_name=self.index_name,
top_k=top_k,
query_vector=query_vector,
Expand All @@ -91,7 +91,7 @@ def search(
parameters["query_vector"] = query_vector
del parameters["query_text"]

search_query = get_search_query(SearchType.Vector, self.return_properties)
search_query = get_search_query(SearchType.VECTOR, self.return_properties)

records, _, _ = self.driver.execute_query(search_query, parameters)

Expand Down Expand Up @@ -177,6 +177,8 @@ def search(
parameters[key] = value
del parameters["query_params"]

search_query = get_search_query(SearchType.Vector) + self.retrieval_query
search_query = get_search_query(
SearchType.VECTOR, retrieval_query=self.retrieval_query
)
records, _, _ = self.driver.execute_query(search_query, parameters)
return records
14 changes: 9 additions & 5 deletions src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def check_node_properties_not_empty(cls, v):
return v


class SimilaritySearchModel(BaseModel):
class VectorSearchModel(BaseModel):
index_name: str
top_k: PositiveInt = 5
query_vector: Optional[list[float]] = None
Expand All @@ -72,20 +72,24 @@ def check_query(cls, values):
return values


class VectorCypherSearchModel(SimilaritySearchModel):
class VectorCypherSearchModel(VectorSearchModel):
query_params: Optional[dict[str, Any]] = None


class HybridModel(BaseModel):
class HybridSearchModel(BaseModel):
vector_index_name: str
fulltext_index_name: str
query_text: str
top_k: PositiveInt = 5
query_vector: Optional[list[float]] = None


class HybridCypherSearchModel(HybridSearchModel):
    """Hybrid search inputs extended with optional caller-supplied query parameters."""

    # Extra parameters for the Cypher query; None when the caller passes none.
    query_params: Optional[dict[str, Any]] = None


class SearchType(str, Enum):
"""Enumerator of the search strategies."""

Vector = "vector"
Hybrid = "hybrid"
VECTOR = "vector"
HYBRID = "hybrid"
75 changes: 75 additions & 0 deletions tests/test_queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from neo4j_genai.queries import get_search_query
from neo4j_genai.types import SearchType


def test_vector_search_basic():
    """A plain vector search is the index call followed by ``RETURN node, score``."""
    # Placeholders must match the names the refactored query builder emits
    # ($index_name, $top_k, $query_vector), not the pre-refactor ones.
    expected = (
        "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) "
        "RETURN node, score"
    )
    result = get_search_query(SearchType.VECTOR)
    assert result == expected


def test_hybrid_search_basic():
    """A plain hybrid search is the combined index query followed by ``RETURN node, score``."""
    fragments = [
        "CALL { ",
        "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) ",
        "YIELD node, score ",
        "RETURN node, score UNION ",
        "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) ",
        "YIELD node, score ",
        "WITH collect({node:node, score:score}) AS nodes, max(score) AS max ",
        "UNWIND nodes AS n ",
        "RETURN n.node AS node, (n.score / max) AS score ",
        "} ",
        "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k ",
        "RETURN node, score",
    ]
    expected_query = "".join(fragments)
    assert get_search_query(SearchType.HYBRID) == expected_query


def test_vector_search_with_properties():
    """Requested return properties are projected onto the node in the RETURN clause."""
    properties = ["name", "age"]
    # Placeholders must match the names the refactored query builder emits
    # ($index_name, $top_k, $query_vector), not the pre-refactor ones.
    expected = (
        "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) "
        "YIELD node, score "
        "RETURN node {.name, .age} as node, score"
    )
    result = get_search_query(SearchType.VECTOR, return_properties=properties)
    assert result == expected


def test_hybrid_search_with_retrieval_query():
    """A retrieval query is appended verbatim after the hybrid index query."""
    retrieval_query = "MATCH (n) RETURN n LIMIT 10"
    fragments = [
        "CALL { ",
        "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) ",
        "YIELD node, score ",
        "RETURN node, score UNION ",
        "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) ",
        "YIELD node, score ",
        "WITH collect({node:node, score:score}) AS nodes, max(score) AS max ",
        "UNWIND nodes AS n ",
        "RETURN n.node AS node, (n.score / max) AS score ",
        "} ",
        "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k ",
        retrieval_query,
    ]
    expected_query = "".join(fragments)
    actual_query = get_search_query(
        SearchType.HYBRID, retrieval_query=retrieval_query
    )
    assert actual_query == expected_query
Loading

0 comments on commit 9cca2a3

Please sign in to comment.