-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add HybridCypherRetriever, refactors Retriever directory, refactor query construction (#22)
* Refactored retrievers into a directory with separate files * Adds HybridCypherRetriever, refactor query construction, add tests * Update README (#20) * Minor changes to the README --------- Co-authored-by: Oskar Hane <[email protected]>
- Loading branch information
Showing
11 changed files
with
519 additions
and
253 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from neo4j import GraphDatabase | ||
|
||
from random import random | ||
from neo4j_genai import HybridCypherRetriever | ||
from neo4j_genai.embedder import Embedder | ||
from neo4j_genai.indexes import create_vector_index, create_fulltext_index | ||
|
||
# Example: hybrid (vector + full-text) search followed by a custom Cypher
# retrieval query, using HybridCypherRetriever.

URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")

INDEX_NAME = "embedding-name"
FULLTEXT_INDEX_NAME = "fulltext-index-name"
DIMENSION = 1536

# Connect to Neo4j database
driver = GraphDatabase.driver(URI, auth=AUTH)


# Create Embedder object
class CustomEmbedder(Embedder):
    """Placeholder embedder that ignores the text and returns random values."""

    def embed_query(self, text: str) -> list[float]:
        # One random float per index dimension so the vector length matches
        # the DIMENSION the vector index is created with below.
        return [random() for _ in range(DIMENSION)]


embedder = CustomEmbedder()

# Creating the indexes (vector and full-text) on the same label/property
create_vector_index(
    driver,
    INDEX_NAME,
    label="Document",
    property="propertyKey",
    dimensions=DIMENSION,
    similarity_fn="euclidean",
)
create_fulltext_index(
    driver, FULLTEXT_INDEX_NAME, label="Document", node_properties=["propertyKey"]
)

# Initialize the retriever
# NOTE: adjacent string literals are concatenated verbatim, so every fragment
# except the last must end with a space to keep the Cypher clauses separated.
retrieval_query = (
    "MATCH (node)-[:AUTHORED_BY]->(author:Author) "
    "RETURN author.name"
)
retriever = HybridCypherRetriever(
    driver, INDEX_NAME, FULLTEXT_INDEX_NAME, retrieval_query, embedder
)

# Upsert a Document node carrying a random vector
vector = [random() for _ in range(DIMENSION)]
insert_query = (
    "MERGE (n:Document {id: $id}) "
    "WITH n "
    "CALL db.create.setNodeVectorProperty(n, 'propertyKey', $vector) "
    "RETURN n"
)
parameters = {
    "id": 0,
    "vector": vector,
}
driver.execute_query(insert_query, parameters)

# Perform the similarity search for a text query
query_text = "Who are the fremen?"
print(retriever.search(query_text=query_text, top_k=5))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# Copyright (c) "Neo4j" | ||
# Neo4j Sweden AB [https://neo4j.com] | ||
# # | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# # | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# # | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Optional | ||
|
||
from neo4j_genai.types import SearchType | ||
|
||
|
||
def get_search_query(
    search_type: SearchType,
    return_properties: Optional[list[str]] = None,
    retrieval_query: Optional[str] = None,
) -> str:
    """
    Build the Cypher query string for a vector or hybrid index search.

    The query consists of a base part selected by ``search_type`` followed by
    exactly one tail, chosen in this order of precedence:

    1. ``retrieval_query`` if given (appended verbatim),
    2. a map projection of ``return_properties`` if given,
    3. a plain ``RETURN node, score``.

    Args:
        search_type (SearchType): Which base query to use (VECTOR or HYBRID).
        return_properties (Optional[list[str]], optional): Node property names
            to project in the RETURN clause. Ignored when ``retrieval_query``
            is provided. Defaults to None.
        retrieval_query (Optional[str], optional): Cypher appended after the
            index search; it can refer to ``node`` and ``score``.
            Defaults to None.

    Raises:
        KeyError: If ``search_type`` has no entry in the query map.

    Returns:
        str: The assembled Cypher query. It expects the ``$``-parameters named
        in the chosen base query to be supplied at execution time.
    """
    query_map = {
        SearchType.VECTOR: (
            "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) "
        ),
        SearchType.HYBRID: (
            "CALL { "
            "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) "
            "YIELD node, score "
            "RETURN node, score UNION "
            "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) "
            "YIELD node, score "
            "WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
            "UNWIND nodes AS n "
            "RETURN n.node AS node, (n.score / max) AS score "
            "} "
            "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
        ),
    }

    base_query = query_map[search_type]

    # The VECTOR base query ends directly on the procedure call, and a CALL
    # that is followed by further clauses must name its results with YIELD.
    # The HYBRID base query already ends with a WITH that exposes node/score,
    # so a YIELD there would be a syntax error.
    yield_clause = "YIELD node, score " if search_type == SearchType.VECTOR else ""

    if retrieval_query:
        # Do not duplicate the clause if the caller already supplied its own.
        if retrieval_query.lstrip().upper().startswith("YIELD"):
            yield_clause = ""
        return base_query + yield_clause + retrieval_query

    if return_properties:
        # Produces a map projection such as `node {.title, .year}`.
        return_properties_cypher = ", ".join(f".{prop}" for prop in return_properties)
        return (
            base_query
            + yield_clause
            + f"RETURN node {{{return_properties_cypher}}} as node, score"
        )

    return base_query + yield_clause + "RETURN node, score"
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Copyright (c) "Neo4j" | ||
# Neo4j Sweden AB [https://neo4j.com] | ||
# # | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# # | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# # | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from abc import ABC, abstractmethod | ||
from typing import Any | ||
|
||
from neo4j import Driver | ||
|
||
|
||
class Retriever(ABC):
    """
    Abstract base class shared by the Neo4j retrievers.

    Construction stores the driver and immediately checks the server version;
    concrete subclasses provide ``search``.
    """

    def __init__(self, driver: Driver):
        self.driver = driver
        self._verify_version()

    def _verify_version(self) -> None:
        """
        Ensure the connected Neo4j server is recent enough for vector indexing.

        The first version string reported by ``CALL dbms.components()`` is
        parsed into a tuple of integers and compared with a minimum version:

        - If the string contains "aura", only the part before the first "-"
          is parsed, a trailing 0 is appended, and the minimum is (5, 18, 0).
        - Otherwise the whole string is split on "." and the minimum is
          (5, 18, 1).

        Raises:
            ValueError: If the parsed version is lower than the minimum.
        """
        result_records, _, _ = self.driver.execute_query("CALL dbms.components()")
        version = result_records[0]["versions"][0]

        if "aura" in version:
            numeric_part = version.split("-")[0]
            # Pad with a trailing 0 so the tuple lines up with the
            # three-component minimum below.
            version_tuple = tuple(int(piece) for piece in numeric_part.split(".")) + (0,)
            target_version = (5, 18, 0)
        else:
            version_tuple = tuple(int(piece) for piece in version.split("."))
            target_version = (5, 18, 1)

        if version_tuple >= target_version:
            return

        raise ValueError("This package only supports Neo4j version 5.18.1 or greater")

    @abstractmethod
    def search(self, *args, **kwargs) -> Any:
        pass
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
# Copyright (c) "Neo4j" | ||
# Neo4j Sweden AB [https://neo4j.com] | ||
# # | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# # | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# # | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Optional, Any | ||
|
||
from neo4j import Record, Driver | ||
from pydantic import ValidationError | ||
|
||
from neo4j_genai.embedder import Embedder | ||
from neo4j_genai.retrievers.base import Retriever | ||
from neo4j_genai.types import HybridSearchModel, SearchType, HybridCypherSearchModel | ||
from neo4j_genai.neo4j_queries import get_search_query | ||
|
||
|
||
class HybridRetriever(Retriever):
    """
    Retriever that runs the hybrid search query (vector index combined with
    full-text index) and returns the matching records.
    """

    def __init__(
        self,
        driver: Driver,
        vector_index_name: str,
        fulltext_index_name: str,
        embedder: Optional[Embedder] = None,
        return_properties: Optional[list[str]] = None,
    ) -> None:
        super().__init__(driver)
        self.embedder = embedder
        self.return_properties = return_properties
        self.vector_index_name = vector_index_name
        self.fulltext_index_name = fulltext_index_name

    def search(
        self,
        query_text: str,
        query_vector: Optional[list[float]] = None,
        top_k: int = 5,
    ) -> list[Record]:
        """Run a hybrid search for the given text and/or vector.

        ``query_text`` feeds the full-text part of the search. The vector part
        uses ``query_vector`` when it is supplied; otherwise the embedder is
        used to turn ``query_text`` into a vector.

        See the following documentation for more details:
        - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query)
        - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes)
        - [db.index.fulltext.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_fulltext_querynodes)

        Args:
            query_text (str): The text to search for.
            query_vector (Optional[list[float]], optional): Embedding to use for the vector search instead of embedding query_text. Defaults to None.
            top_k (int, optional): The number of results to return. Defaults to 5.

        Raises:
            ValueError: If the input arguments fail validation.
            ValueError: If query_text must be embedded but no embedder was provided.

        Returns:
            list[Record]: The records returned by the search query.
        """
        model_fields = {
            "vector_index_name": self.vector_index_name,
            "fulltext_index_name": self.fulltext_index_name,
            "top_k": top_k,
            "query_vector": query_vector,
            "query_text": query_text,
        }
        try:
            validated = HybridSearchModel(**model_fields)
        except ValidationError as e:
            raise ValueError(f"Validation failed: {e.errors()}")

        # Fields left as None are omitted from the query parameters.
        query_parameters = validated.model_dump(exclude_none=True)

        needs_embedding = bool(query_text) and not query_vector
        if needs_embedding:
            if not self.embedder:
                raise ValueError("Embedding method required for text query.")
            query_parameters["query_vector"] = self.embedder.embed_query(query_text)

        cypher = get_search_query(SearchType.HYBRID, self.return_properties)

        found, _, _ = self.driver.execute_query(cypher, query_parameters)
        return found
|
||
|
||
class HybridCypherRetriever(Retriever):
    """
    Retriever that runs the hybrid search query (vector index combined with
    full-text index) and then appends a caller-supplied Cypher retrieval
    query to shape the results.
    """

    def __init__(
        self,
        driver: Driver,
        vector_index_name: str,
        fulltext_index_name: str,
        retrieval_query: str,
        embedder: Optional[Embedder] = None,
    ) -> None:
        super().__init__(driver)
        self.vector_index_name = vector_index_name
        self.fulltext_index_name = fulltext_index_name
        self.retrieval_query = retrieval_query
        self.embedder = embedder

    def search(
        self,
        query_text: str,
        query_vector: Optional[list[float]] = None,
        top_k: int = 5,
        query_params: Optional[dict[str, Any]] = None,
    ) -> list[Record]:
        """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
        Both query_vector and query_text can be provided.
        If query_vector is provided, then it will be preferred over the embedded query_text
        for the vector search.

        See the following documentation for more details:
        - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query)
        - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes)
        - [db.index.fulltext.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_fulltext_querynodes)

        Args:
            query_text (str): The text to get the closest neighbors of.
            query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
            top_k (int, optional): The number of neighbors to return. Defaults to 5.
            query_params (Optional[dict[str, Any]], optional): Extra parameters for the Cypher retrieval query.
                Keys that collide with the parameters built by this method are ignored. Defaults to None.

        Raises:
            ValueError: If validation of the input arguments fail.
            ValueError: If no embedder is provided.

        Returns:
            list[Record]: The results of the search query
        """
        try:
            validated_data = HybridCypherSearchModel(
                vector_index_name=self.vector_index_name,
                fulltext_index_name=self.fulltext_index_name,
                top_k=top_k,
                query_vector=query_vector,
                query_text=query_text,
                query_params=query_params,
            )
        except ValidationError as e:
            raise ValueError(f"Validation failed: {e.errors()}")

        parameters = validated_data.model_dump(exclude_none=True)

        if query_text and not query_vector:
            if not self.embedder:
                raise ValueError("Embedding method required for text query.")
            query_vector = self.embedder.embed_query(query_text)
            parameters["query_vector"] = query_vector

        if query_params:
            # Flatten the user parameters into the top level; parameters built
            # above take precedence over user-supplied keys of the same name.
            for key, value in query_params.items():
                parameters.setdefault(key, value)

        # Always drop the nested entry: an empty dict is not None, so it
        # survives `exclude_none` and would otherwise be sent to the server
        # as a stray `$query_params` parameter.
        parameters.pop("query_params", None)

        search_query = get_search_query(
            SearchType.HYBRID, retrieval_query=self.retrieval_query
        )
        records, _, _ = self.driver.execute_query(search_query, parameters)
        return records
Oops, something went wrong.