Skip to content

Commit

Permalink
Adds HybridSearchRetriever and creates abstract base class Retriever
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai committed Apr 19, 2024
1 parent de74a0f commit 6529ca1
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 28 deletions.
62 changes: 62 additions & 0 deletions examples/hybrid_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from neo4j import GraphDatabase

from random import random
from neo4j_genai.embedder import Embedder
from neo4j_genai.indexes import create_vector_index, drop_index, create_fulltext_index
from neo4j_genai.retrievers import HybridSearchRetriever

URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")

INDEX_NAME = "embedding-name"
FULLTEXT_INDEX_NAME = "fulltext-index-name"
DIMENSION = 1536

# Connect to Neo4j database
driver = GraphDatabase.driver(URI, auth=AUTH)


# Create Embedder object
class CustomEmbedder(Embedder):
def embed_query(self, text: str) -> list[float]:
return [random() for _ in range(DIMENSION)]


embedder = CustomEmbedder()

# Creating the index
drop_index(driver, INDEX_NAME)
drop_index(driver, FULLTEXT_INDEX_NAME)
create_vector_index(
driver,
INDEX_NAME,
label="Document",
property="propertyKey",
dimensions=DIMENSION,
similarity_fn="euclidean",
)
create_fulltext_index(
driver, FULLTEXT_INDEX_NAME, label="Document", node_properties=["propertyKey"]
)

# Initialize the retriever
retriever = HybridSearchRetriever(driver, INDEX_NAME, FULLTEXT_INDEX_NAME, embedder)

# Upsert the query
vector = [random() for _ in range(DIMENSION)]
insert_query = (
"MERGE (n:Document {id: $id})"
"WITH n "
"CALL db.create.setNodeVectorProperty(n, 'propertyKey', $vector)"
"RETURN n"
)
parameters = {
"id": 0,
"vector": vector,
}
driver.execute_query(insert_query, parameters)

# Perform the similarity search for a text query
query_text = "hello world"
fulltext_query = "fremen"
print(retriever.search(query_text=query_text, fulltext_query=fulltext_query, top_k=5))
2 changes: 1 addition & 1 deletion src/neo4j_genai/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def create_fulltext_index(
raise ValueError(f"Error for inputs to create_fulltext_index {str(e)}")

query = (
"CREATE FULLTEXT INDEX $name"
"CREATE FULLTEXT INDEX $name "
f"FOR (n:`{label}`) ON EACH "
f"[{', '.join(['n.`' + prop + '`' for prop in node_properties])}]"
)
Expand Down
138 changes: 117 additions & 21 deletions src/neo4j_genai/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod, ABC
from typing import Optional, Any
from pydantic import ValidationError
from neo4j import Driver
from neo4j import Driver, Record
from .embedder import Embedder
from .types import SimilaritySearchModel, Neo4jRecord, VectorCypherSearchModel
from .types import (
SimilaritySearchModel,
VectorSearchRecord,
VectorCypherSearchModel,
HybridSearchModel,
)


class VectorRetriever:
class Retriever(ABC):
"""
Provides retrieval methods using vector search over embeddings
Abstract class for Neo4j retrievers
"""

def __init__(
self,
driver: Driver,
index_name: str,
embedder: Optional[Embedder] = None,
) -> None:
def __init__(self, driver: Driver):
self.driver = driver
self._verify_version()
self.index_name = index_name
self.embedder = embedder

def _verify_version(self) -> None:
"""
Expand Down Expand Up @@ -63,12 +60,33 @@ def _verify_version(self) -> None:
"This package only supports Neo4j version 5.18.1 or greater"
)

@abstractmethod
def search(self, *args, **kwargs) -> Any:
pass


class VectorRetriever(Retriever):
"""
Provides retrieval method using vector search over embeddings
"""

def __init__(
self,
driver: Driver,
index_name: str,
embedder: Optional[Embedder] = None,
) -> None:
super().__init__(driver)
self._verify_version()
self.index_name = index_name
self.embedder = embedder

def search(
self,
query_vector: Optional[list[float]] = None,
query_text: Optional[str] = None,
top_k: int = 5,
) -> list[Neo4jRecord]:
) -> list[VectorSearchRecord]:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
See the following documentation for more details:
Expand All @@ -85,7 +103,7 @@ def search(
ValueError: If no embedder is provided.
Returns:
list[Neo4jRecord]: The `top_k` neighbors found in vector search with their nodes and scores.
list[VectorSearchRecord]: The `top_k` neighbors found in vector search with their nodes and scores.
"""
try:
validated_data = SimilaritySearchModel(
Expand Down Expand Up @@ -115,7 +133,7 @@ def search(

try:
return [
Neo4jRecord(node=record["node"], score=record["score"])
VectorSearchRecord(node=record["node"], score=record["score"])
for record in records
]
except ValidationError as e:
Expand All @@ -125,7 +143,7 @@ def search(
)


class VectorCypherRetriever(VectorRetriever):
class VectorCypherRetriever(Retriever):
"""
Provides retrieval method using vector similarity and custom Cypher query
"""
Expand All @@ -137,7 +155,7 @@ def __init__(
retrieval_query: str,
embedder: Optional[Embedder] = None,
) -> None:
self.driver = driver
super().__init__(driver)
self._verify_version()
self.index_name = index_name
self.retrieval_query = retrieval_query
Expand All @@ -149,7 +167,7 @@ def search(
query_text: Optional[str] = None,
top_k: int = 5,
query_params: Optional[dict[str, Any]] = None,
) -> list[Neo4jRecord]:
) -> list[Record]:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
See the following documentation for more details:
Expand All @@ -167,7 +185,7 @@ def search(
ValueError: If no embedder is provided.
Returns:
Any: The results of the search query
list[Record]: The results of the search query
"""
try:
validated_data = VectorCypherSearchModel(
Expand Down Expand Up @@ -201,3 +219,81 @@ def search(
search_query = query_prefix + self.retrieval_query
records, _, _ = self.driver.execute_query(search_query, parameters)
return records


class HybridSearchRetriever(Retriever):
def __init__(
self,
driver: Driver,
index_name: str,
fulltext_index_name: str,
embedder: Optional[Embedder] = None,
) -> None:
super().__init__(driver)
self._verify_version()
self.index_name = index_name
self.fulltext_index_name = fulltext_index_name
self.embedder = embedder

def search(
self,
fulltext_query: str,
query_vector: Optional[list[float]] = None,
query_text: Optional[str] = None,
top_k: int = 5,
) -> list[Record]:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
See the following documentation for more details:
- [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query)
- [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes)
Args:
query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None.
top_k (int, optional): The number of neighbors to return. Defaults to 5.
query_params (Optional[dict[str, Any]], optional): Parameters for the Cypher query. Defaults to None.
Raises:
ValueError: If validation of the input arguments fail.
ValueError: If no embedder is provided.
Returns:
list[Record]: The results of the search query
"""
try:
validated_data = HybridSearchModel(
index_name=self.index_name,
fulltext_index_name=self.fulltext_index_name,
fulltext_query=fulltext_query,
top_k=top_k,
query_vector=query_vector,
query_text=query_text,
)
except ValidationError as e:
raise ValueError(f"Validation failed: {e.errors()}")

parameters = validated_data.model_dump(exclude_none=True)

if query_text:
if not self.embedder:
raise ValueError("Embedding method required for text query.")
parameters["query_vector"] = self.embedder.embed_query(query_text)
del parameters["query_text"]

search_query = (
"CALL { "
"CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) "
"YIELD node, score "
"RETURN node, score UNION "
"CALL db.index.fulltext.queryNodes($fulltext_index_name, $fulltext_query, {limit: $top_k}) "
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / max) AS score "
"} "
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
"RETURN node, score"
)
records, _, _ = self.driver.execute_query(search_query, parameters)
return records
7 changes: 6 additions & 1 deletion src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from neo4j import Driver


class Neo4jRecord(BaseModel):
class VectorSearchRecord(BaseModel):
node: Any
score: float

Expand Down Expand Up @@ -78,3 +78,8 @@ def check_query(cls, values):

class VectorCypherSearchModel(SimilaritySearchModel):
query_params: Optional[dict[str, Any]] = None


class HybridSearchModel(VectorCypherSearchModel):
fulltext_index_name: str
fulltext_query: str
4 changes: 2 additions & 2 deletions tests/test_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_create_fulltext_index_happy_path(driver):
label = "node-label"
text_node_properties = ["property-1", "property-2"]
create_query = (
"CREATE FULLTEXT INDEX $name"
"CREATE FULLTEXT INDEX $name "
f"FOR (n:`{label}`) ON EACH "
f"[{', '.join(['n.`' + property + '`' for property in text_node_properties])}]"
)
Expand All @@ -116,7 +116,7 @@ def test_create_fulltext_index_ensure_escaping(driver):
label = "node-label"
text_node_properties = ["property-1", "property-2"]
create_query = (
"CREATE FULLTEXT INDEX $name"
"CREATE FULLTEXT INDEX $name "
f"FOR (n:`{label}`) ON EACH "
f"[{', '.join(['n.`' + property + '`' for property in text_node_properties])}]"
)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from neo4j_genai import VectorRetriever
from neo4j_genai.retrievers import VectorCypherRetriever
from neo4j_genai.types import Neo4jRecord
from neo4j_genai.types import VectorSearchRecord


def test_vector_retriever_supported_aura_version(driver):
Expand Down Expand Up @@ -87,7 +87,7 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver):
},
)

assert records == [Neo4jRecord(node="dummy-node", score=1.0)]
assert records == [VectorSearchRecord(node="dummy-node", score=1.0)]


@patch("neo4j_genai.VectorRetriever._verify_version")
Expand Down Expand Up @@ -126,7 +126,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver):
},
)

assert records == [Neo4jRecord(node="dummy-node", score=1.0)]
assert records == [VectorSearchRecord(node="dummy-node", score=1.0)]


def test_vector_retriever_search_missing_embedder_for_text(vector_retriever):
Expand Down

0 comments on commit 6529ca1

Please sign in to comment.