diff --git a/examples/hybrid_search.py b/examples/hybrid_search.py new file mode 100644 index 000000000..7fffd1e6c --- /dev/null +++ b/examples/hybrid_search.py @@ -0,0 +1,62 @@ +from neo4j import GraphDatabase + +from random import random +from neo4j_genai.embedder import Embedder +from neo4j_genai.indexes import create_vector_index, drop_index, create_fulltext_index +from neo4j_genai.retrievers import HybridSearchRetriever + +URI = "neo4j://localhost:7687" +AUTH = ("neo4j", "password") + +INDEX_NAME = "embedding-name" +FULLTEXT_INDEX_NAME = "fulltext-index-name" +DIMENSION = 1536 + +# Connect to Neo4j database +driver = GraphDatabase.driver(URI, auth=AUTH) + + +# Create Embedder object +class CustomEmbedder(Embedder): + def embed_query(self, text: str) -> list[float]: + return [random() for _ in range(DIMENSION)] + + +embedder = CustomEmbedder() + +# Creating the index +drop_index(driver, INDEX_NAME) +drop_index(driver, FULLTEXT_INDEX_NAME) +create_vector_index( + driver, + INDEX_NAME, + label="Document", + property="propertyKey", + dimensions=DIMENSION, + similarity_fn="euclidean", +) +create_fulltext_index( + driver, FULLTEXT_INDEX_NAME, label="Document", node_properties=["propertyKey"] +) + +# Initialize the retriever +retriever = HybridSearchRetriever(driver, INDEX_NAME, FULLTEXT_INDEX_NAME, embedder) + +# Upsert the query +vector = [random() for _ in range(DIMENSION)] +insert_query = ( + "MERGE (n:Document {id: $id})" + "WITH n " + "CALL db.create.setNodeVectorProperty(n, 'propertyKey', $vector)" + "RETURN n" +) +parameters = { + "id": 0, + "vector": vector, +} +driver.execute_query(insert_query, parameters) + +# Perform the similarity search for a text query +query_text = "hello world" +fulltext_query = "fremen" +print(retriever.search(query_text=query_text, fulltext_query=fulltext_query, top_k=5)) diff --git a/src/neo4j_genai/indexes.py b/src/neo4j_genai/indexes.py index b457b3014..b09901e08 100644 --- a/src/neo4j_genai/indexes.py +++ b/src/neo4j_genai/indexes.py @@ -98,7 +98,7 @@ def create_fulltext_index( raise ValueError(f"Error for inputs to create_fulltext_index {str(e)}") query = ( - "CREATE FULLTEXT INDEX $name" + "CREATE FULLTEXT INDEX $name " f"FOR (n:`{label}`) ON EACH " f"[{', '.join(['n.`' + prop + '`' for prop in node_properties])}]" ) diff --git a/src/neo4j_genai/retrievers.py b/src/neo4j_genai/retrievers.py index 2653ba66e..5ffa30fc8 100644 --- a/src/neo4j_genai/retrievers.py +++ b/src/neo4j_genai/retrievers.py @@ -12,29 +12,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from abc import abstractmethod, ABC from typing import Optional, Any from pydantic import ValidationError -from neo4j import Driver +from neo4j import Driver, Record from .embedder import Embedder -from .types import SimilaritySearchModel, Neo4jRecord, VectorCypherSearchModel +from .types import ( + SimilaritySearchModel, + VectorSearchRecord, + VectorCypherSearchModel, + HybridSearchModel, +) -class VectorRetriever: +class Retriever(ABC): """ - Provides retrieval methods using vector search over embeddings + Abstract class for Neo4j retrievers """ - def __init__( - self, - driver: Driver, - index_name: str, - embedder: Optional[Embedder] = None, - ) -> None: + def __init__(self, driver: Driver): self.driver = driver - self._verify_version() - self.index_name = index_name - self.embedder = embedder def _verify_version(self) -> None: """ @@ -63,12 +60,33 @@ def _verify_version(self) -> None: "This package only supports Neo4j version 5.18.1 or greater" ) + @abstractmethod + def search(self, *args, **kwargs) -> Any: + pass + + +class VectorRetriever(Retriever): + """ + Provides retrieval method using vector search over embeddings + """ + + def __init__( + self, + driver: Driver, + index_name: str, + embedder: Optional[Embedder] = None, + ) -> None: + super().__init__(driver) + self._verify_version() + self.index_name = index_name + self.embedder = embedder + def search( self, query_vector: Optional[list[float]] = None, query_text: Optional[str] = None, top_k: int = 5, - ) -> list[Neo4jRecord]: + ) -> list[VectorSearchRecord]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -85,7 +103,7 @@ def search( ValueError: If no embedder is provided. Returns: - list[Neo4jRecord]: The `top_k` neighbors found in vector search with their nodes and scores. + list[VectorSearchRecord]: The `top_k` neighbors found in vector search with their nodes and scores. """ try: validated_data = SimilaritySearchModel( @@ -115,7 +133,7 @@ def search( try: return [ - Neo4jRecord(node=record["node"], score=record["score"]) + VectorSearchRecord(node=record["node"], score=record["score"]) for record in records ] except ValidationError as e: @@ -125,7 +143,7 @@ def search( ) -class VectorCypherRetriever(VectorRetriever): +class VectorCypherRetriever(Retriever): """ Provides retrieval method using vector similarity and custom Cypher query """ @@ -137,7 +155,7 @@ def __init__( retrieval_query: str, embedder: Optional[Embedder] = None, ) -> None: - self.driver = driver + super().__init__(driver) self._verify_version() self.index_name = index_name self.retrieval_query = retrieval_query @@ -149,7 +167,7 @@ def search( query_text: Optional[str] = None, top_k: int = 5, query_params: Optional[dict[str, Any]] = None, - ) -> list[Neo4jRecord]: + ) -> list[Record]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -167,7 +185,7 @@ def search( ValueError: If no embedder is provided. Returns: - Any: The results of the search query + list[Record]: The results of the search query """ try: validated_data = VectorCypherSearchModel( @@ -201,3 +219,81 @@ def search( search_query = query_prefix + self.retrieval_query records, _, _ = self.driver.execute_query(search_query, parameters) return records + + +class HybridSearchRetriever(Retriever): + def __init__( + self, + driver: Driver, + index_name: str, + fulltext_index_name: str, + embedder: Optional[Embedder] = None, + ) -> None: + super().__init__(driver) + self._verify_version() + self.index_name = index_name + self.fulltext_index_name = fulltext_index_name + self.embedder = embedder + + def search( + self, + fulltext_query: str, + query_vector: Optional[list[float]] = None, + query_text: Optional[str] = None, + top_k: int = 5, + ) -> list[Record]: + """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. + See the following documentation for more details: + + - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query) + - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes) + + Args: + query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None. + query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None. + top_k (int, optional): The number of neighbors to return. Defaults to 5. + query_params (Optional[dict[str, Any]], optional): Parameters for the Cypher query. Defaults to None. + + Raises: + ValueError: If validation of the input arguments fail. + ValueError: If no embedder is provided. + + Returns: + list[Record]: The results of the search query + """ + try: + validated_data = HybridSearchModel( + index_name=self.index_name, + fulltext_index_name=self.fulltext_index_name, + fulltext_query=fulltext_query, + top_k=top_k, + query_vector=query_vector, + query_text=query_text, + ) + except ValidationError as e: + raise ValueError(f"Validation failed: {e.errors()}") + + parameters = validated_data.model_dump(exclude_none=True) + + if query_text: + if not self.embedder: + raise ValueError("Embedding method required for text query.") + parameters["query_vector"] = self.embedder.embed_query(query_text) + del parameters["query_text"] + + search_query = ( + "CALL { " + "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node, score UNION " + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $fulltext_query, {limit: $top_k}) " + "YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + "RETURN n.node AS node, (n.score / max) AS score " + "} " + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + "RETURN node, score" + ) + records, _, _ = self.driver.execute_query(search_query, parameters) + return records diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index dc65a481d..cde5b1f70 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -18,7 +18,7 @@ from neo4j import Driver -class Neo4jRecord(BaseModel): +class VectorSearchRecord(BaseModel): node: Any score: float @@ -78,3 +78,8 @@ def check_query(cls, values): class VectorCypherSearchModel(SimilaritySearchModel): query_params: Optional[dict[str, Any]] = None + + +class HybridSearchModel(VectorCypherSearchModel): + fulltext_index_name: str + fulltext_query: str diff --git a/tests/test_indexes.py b/tests/test_indexes.py index ae2f98c32..c624607d5 100644 --- a/tests/test_indexes.py +++ b/tests/test_indexes.py @@ -89,7 +89,7 @@ def test_create_fulltext_index_happy_path(driver): label = "node-label" text_node_properties = ["property-1", "property-2"] create_query = ( - "CREATE FULLTEXT INDEX $name" + "CREATE FULLTEXT INDEX $name " f"FOR (n:`{label}`) ON EACH " f"[{', '.join(['n.`' + property + '`' for property in text_node_properties])}]" ) @@ -116,7 +116,7 @@ def test_create_fulltext_index_ensure_escaping(driver): label = "node-label" text_node_properties = ["property-1", "property-2"] create_query = ( - "CREATE FULLTEXT INDEX $name" + "CREATE FULLTEXT INDEX $name " f"FOR (n:`{label}`) ON EACH " f"[{', '.join(['n.`' + property + '`' for property in text_node_properties])}]" ) diff --git a/tests/test_retrievers.py b/tests/test_retrievers.py index 77ce30e8a..aad9cbcfc 100644 --- a/tests/test_retrievers.py +++ b/tests/test_retrievers.py @@ -20,7 +20,7 @@ from neo4j_genai import VectorRetriever from neo4j_genai.retrievers import VectorCypherRetriever -from neo4j_genai.types import Neo4jRecord +from neo4j_genai.types import VectorSearchRecord def test_vector_retriever_supported_aura_version(driver): @@ -87,7 +87,7 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver): }, ) - assert records == [Neo4jRecord(node="dummy-node", score=1.0)] + assert records == [VectorSearchRecord(node="dummy-node", score=1.0)] @patch("neo4j_genai.VectorRetriever._verify_version") @@ -126,7 +126,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): }, ) - assert records == [Neo4jRecord(node="dummy-node", score=1.0)] + assert records == [VectorSearchRecord(node="dummy-node", score=1.0)] def test_vector_retriever_search_missing_embedder_for_text(vector_retriever):