Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds HybridRetriever #14

Merged
merged 2 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions examples/hybrid_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from neo4j import GraphDatabase

from random import random
from neo4j_genai.embedder import Embedder
from neo4j_genai.indexes import create_vector_index, create_fulltext_index
from neo4j_genai.retrievers import HybridRetriever

# Connection settings for the local Neo4j instance used by this example.
URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")

# Names of the two indexes queried by the HybridRetriever, and the size of
# the vectors stored in the vector index.
INDEX_NAME = "embedding-name"
FULLTEXT_INDEX_NAME = "fulltext-index-name"
DIMENSION = 1536


# Create Embedder object
class CustomEmbedder(Embedder):
    """Placeholder embedder that returns a random vector of DIMENSION floats.

    The input text is ignored; swap in a real embedding model to get
    meaningful search results.
    """

    def embed_query(self, text: str) -> list[float]:
        return [random() for _ in range(DIMENSION)]


embedder = CustomEmbedder()

# Connect to Neo4j database. Using the driver as a context manager makes sure
# its connections are released even if one of the steps below raises.
with GraphDatabase.driver(URI, auth=AUTH) as driver:
    # Creating the indexes
    create_vector_index(
        driver,
        INDEX_NAME,
        label="Document",
        property="propertyKey",
        dimensions=DIMENSION,
        similarity_fn="euclidean",
    )
    create_fulltext_index(
        driver, FULLTEXT_INDEX_NAME, label="Document", node_properties=["propertyKey"]
    )

    # Initialize the retriever
    retriever = HybridRetriever(driver, INDEX_NAME, FULLTEXT_INDEX_NAME, embedder)

    # Upsert a Document node carrying a random vector
    vector = [random() for _ in range(DIMENSION)]
    # Each fragment ends with a space so the concatenated clauses stay separated.
    insert_query = (
        "MERGE (n:Document {id: $id}) "
        "WITH n "
        "CALL db.create.setNodeVectorProperty(n, 'propertyKey', $vector) "
        "RETURN n"
    )
    parameters = {
        "id": 0,
        "vector": vector,
    }
    driver.execute_query(insert_query, parameters)

    # Perform the hybrid (vector + full-text) search for a text query
    query_text = "Who are the fremen?"
    print(retriever.search(query_text=query_text, top_k=5))
4 changes: 2 additions & 2 deletions src/neo4j_genai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .retrievers import VectorRetriever, VectorCypherRetriever
from .retrievers import VectorRetriever, VectorCypherRetriever, HybridRetriever


__all__ = ["VectorRetriever", "VectorCypherRetriever"]
__all__ = ["VectorRetriever", "VectorCypherRetriever", "HybridRetriever"]
2 changes: 1 addition & 1 deletion src/neo4j_genai/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def create_fulltext_index(
raise ValueError(f"Error for inputs to create_fulltext_index {str(e)}")

query = (
"CREATE FULLTEXT INDEX $name"
"CREATE FULLTEXT INDEX $name "
f"FOR (n:`{label}`) ON EACH "
f"[{', '.join(['n.`' + prop + '`' for prop in node_properties])}]"
)
Expand Down
87 changes: 87 additions & 0 deletions src/neo4j_genai/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
SimilaritySearchModel,
VectorSearchRecord,
VectorCypherSearchModel,
HybridModel,
)


Expand Down Expand Up @@ -231,3 +232,89 @@ def search(
search_query = query_prefix + self.retrieval_query
records, _, _ = self.driver.execute_query(search_query, parameters)
return records


class HybridRetriever(Retriever):
    """Retriever that combines a vector index search with a full-text index search.

    Both indexes are queried in a single Cypher statement. Full-text scores are
    divided by the highest full-text score, the two result sets are merged, and
    for each node the highest score is kept before the ``top_k`` limit is applied.

    Args:
        driver (Driver): The Neo4j driver used to run the search query.
        vector_index_name (str): Name of the vector index to query.
        fulltext_index_name (str): Name of the full-text index to query.
        embedder (Optional[Embedder], optional): Embedder used to turn the query text
            into a vector when no query vector is supplied. Defaults to None.
        return_properties (Optional[list[str]], optional): Node properties to return
            instead of the whole node. Defaults to None.
    """

    def __init__(
        self,
        driver: Driver,
        vector_index_name: str,
        fulltext_index_name: str,
        embedder: Optional[Embedder] = None,
        return_properties: Optional[list[str]] = None,
    ) -> None:
        super().__init__(driver)
        self._verify_version()
        self.vector_index_name = vector_index_name
        self.fulltext_index_name = fulltext_index_name
        self.embedder = embedder
        self.return_properties = return_properties

    def search(
        self,
        query_text: str,
        query_vector: Optional[list[float]] = None,
        top_k: int = 5,
    ) -> list[Record]:
        """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
        Both query_vector and query_text can be provided.
        If query_vector is provided, then it will be preferred over the embedded query_text
        for the vector search.

        See the following documentation for more details:

        - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query)
        - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes)
        - [db.index.fulltext.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_fulltext_querynodes)

        Args:
            query_text (str): The text to get the closest neighbors of.
            query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
            top_k (int, optional): The number of neighbors to return. Defaults to 5.

        Raises:
            ValueError: If validation of the input arguments fail.
            ValueError: If no embedder is provided.

        Returns:
            list[Record]: The results of the search query
        """
        try:
            validated_data = HybridModel(
                vector_index_name=self.vector_index_name,
                fulltext_index_name=self.fulltext_index_name,
                top_k=top_k,
                query_vector=query_vector,
                query_text=query_text,
            )
        except ValidationError as e:
            raise ValueError(f"Validation failed: {e.errors()}") from e

        # None values are dropped, so "query_vector" is only present here when
        # the caller supplied one.
        parameters = validated_data.model_dump(exclude_none=True)

        # No vector supplied: embed the query text for the vector index lookup.
        if query_text and not query_vector:
            if not self.embedder:
                raise ValueError("Embedding method required for text query.")
            query_vector = self.embedder.embed_query(query_text)
            parameters["query_vector"] = query_vector

        # The sub-query unions the vector hits with the full-text hits (whose
        # scores are divided by the best full-text score); the outer WITH keeps
        # the highest score per node and applies the top_k limit.
        search_query = (
            "CALL { "
            "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) "
            "YIELD node, score "
            "RETURN node, score UNION "
            "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) "
            "YIELD node, score "
            "WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
            "UNWIND nodes AS n "
            "RETURN n.node AS node, (n.score / max) AS score "
            "} "
            "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
        )

        if self.return_properties:
            # Map projection restricting the returned node to the requested
            # properties. No YIELD here: YIELD is only valid directly after a
            # procedure CALL, and this point follows a WITH clause.
            return_properties_cypher = ", ".join(
                [f".{prop}" for prop in self.return_properties]
            )
            search_query += f"RETURN node {{{return_properties_cypher}}} as node, score"
        else:
            search_query += "RETURN node, score"

        records, _, _ = self.driver.execute_query(search_query, parameters)
        return records
8 changes: 8 additions & 0 deletions src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,11 @@ def check_query(cls, values):

class VectorCypherSearchModel(SimilaritySearchModel):
query_params: Optional[dict[str, Any]] = None


class HybridModel(BaseModel):
    """Validation model for the arguments of a hybrid (vector + full-text) search."""

    # Name of the vector index to query.
    vector_index_name: str
    # Name of the full-text index to query.
    fulltext_index_name: str
    # Text of the search query; required.
    query_text: str
    # Number of results to return; must be a positive integer.
    top_k: PositiveInt = 5
    # Optional query embedding; when None the field is simply left unset.
    query_vector: Optional[list[float]] = None
8 changes: 7 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

import pytest
from neo4j_genai import VectorRetriever, VectorCypherRetriever
from neo4j_genai import VectorRetriever, VectorCypherRetriever, HybridRetriever
from neo4j import Driver
from unittest.mock import MagicMock, patch

Expand All @@ -37,3 +37,9 @@ def vector_cypher_retriever(_verify_version_mock, driver):
RETURN node.id AS node_id, node.text AS text, score
"""
return VectorCypherRetriever(driver, "my-index", retrieval_query)


@pytest.fixture
@patch("neo4j_genai.HybridRetriever._verify_version")
def hybrid_retriever(_version_check_mock, driver):
    """Build a HybridRetriever on the test driver with the version check patched out."""
    vector_index = "my-index"
    fulltext_index = "my-fulltext-index"
    return HybridRetriever(
        driver,
        vector_index_name=vector_index,
        fulltext_index_name=fulltext_index,
    )
4 changes: 2 additions & 2 deletions tests/test_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_create_fulltext_index_happy_path(driver):
label = "node-label"
text_node_properties = ["property-1", "property-2"]
create_query = (
"CREATE FULLTEXT INDEX $name"
"CREATE FULLTEXT INDEX $name "
f"FOR (n:`{label}`) ON EACH "
f"[{', '.join(['n.`' + property + '`' for property in text_node_properties])}]"
)
Expand All @@ -116,7 +116,7 @@ def test_create_fulltext_index_ensure_escaping(driver):
label = "node-label"
text_node_properties = ["property-1", "property-2"]
create_query = (
"CREATE FULLTEXT INDEX $name"
"CREATE FULLTEXT INDEX $name "
f"FOR (n:`{label}`) ON EACH "
f"[{', '.join(['n.`' + property + '`' for property in text_node_properties])}]"
)
Expand Down
Loading