diff --git a/.github/workflows/cla-check.yaml b/.github/workflows/cla-check.yaml new file mode 100644 index 000000000..9784e8ab8 --- /dev/null +++ b/.github/workflows/cla-check.yaml @@ -0,0 +1,32 @@ +name: "CLA Check" + +on: + pull_request_target: + branches: + - main + +jobs: + cla-check: + if: github.event.pull_request.user.login != 'renovate[bot]' + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4 + with: + repository: neo-technology/whitelist-check + token: ${{ secrets.CLA_CHECK_TOKEN }} + - uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5 + with: + python-version: 3 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - run: | + owner=$(echo "$GITHUB_REPOSITORY" | cut -d/ -f1) + repository=$(echo "$GITHUB_REPOSITORY" | cut -d/ -f2) + + ./bin/examine-pull-request "$owner" "$repository" "${{ secrets.CLA_CHECK_TOKEN }}" "$PULL_REQUEST_NUMBER" cla-database.csv + env: + PULL_REQUEST_NUMBER: ${{ github.event.number }} diff --git a/.snyk b/.snyk new file mode 100644 index 000000000..51293af92 --- /dev/null +++ b/.snyk @@ -0,0 +1,5 @@ +# Snyk (https://snyk.io) policy file + +exclude: + code: + - tests/** diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..a3047deaa --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,69 @@ +# Contributing to the Neo4j Ecosystem + +At [Neo4j](https://neo4j.com/), we develop our software in the open at GitHub. +This provides transparency for you, our users, and allows you to fork the software to make your own additions and enhancements. +We also provide areas specifically for community contributions, in particular the [neo4j-contrib](https://github.com/neo4j-contrib) space. + +There's an active [Neo4j Online Community](https://community.neo4j.com/) where we work directly with the community. +If you're not already a member, sign up! + +We love our community and wouldn't be where we are without you. + + +## Need to raise an issue? + +Where you raise an issue depends largely on the nature of the problem. + +Firstly, if you are an Enterprise customer, you might want to head over to our [Customer Support Portal](https://support.neo4j.com/). + +There are plenty of public channels available too, though. +If you simply want to get started or have a question on how to use a particular feature, ask a question in [Neo4j Online Community](https://community.neo4j.com/). +If you think you might have hit a bug in our software (it happens occasionally!) or you have specific feature request then use the issue feature on the relevant GitHub repository. +Check first though as someone else may have already raised something similar. + +[StackOverflow](https://stackoverflow.com/questions/tagged/neo4j) also hosts a ton of questions and might already have a discussion around your problem. +Make sure you have a look there too. + +Include as much information as you can in any request you make: + +- Which versions of our products are you using? +- Which language (and which version of that language) are you developing with? +- What operating system are you on? +- Are you working with a cluster or on a single machine? +- What code are you running? +- What errors are you seeing? +- What solutions have you tried already? + + +## Want to contribute? + +If you want to contribute a pull request, we have a little bit of process you'll need to follow: + +- Do all your work in a personal fork of the original repository +- [Rebase](https://github.com/edx/edx-platform/wiki/How-to-Rebase-a-Pull-Request), don't merge (we prefer to keep our history clean) +- Create a branch (with a useful name) for your contribution +- Make sure you're familiar with the appropriate coding style (this varies by language so ask if you're in doubt) +- Include unit tests if appropriate (obviously not necessary for documentation changes) +- Take a moment to read and sign our [CLA](https://neo4j.com/developer/cla) + +We can't guarantee that we'll accept pull requests and may ask you to make some changes before they go in. +Occasionally, we might also have logistical, commercial, or legal reasons why we can't accept your work but we'll try to find an alternative way for you to contribute in that case. +Remember that many community members have become regular contributors and some are now even Neo employees! + + +## Specifically for this project +Setting up the development environment: + +1. Install Python 3.9.1+ +2. Install poetry (see https://python-poetry.org/docs/#installation) +3. Install dependencies: + +```shell +poetry install +``` + +4. Install the pre-commit hook, that will do some code-format-checking everytime you commit. + +```shell +pre-commit install +``` diff --git a/LICENSE.PYTHON.txt b/LICENSE.PYTHON.txt index 89b3f73c7..7cb919efd 100644 --- a/LICENSE.PYTHON.txt +++ b/LICENSE.PYTHON.txt @@ -20,7 +20,7 @@ analyze, test, perform and/or display publicly, prepare derivative works, distribute, and otherwise use Python alone or in any derivative version, provided, however, that PSF's License Agreement and PSF's notice of copyright, i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, -2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022 Python Software Foundation; +2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024 Python Software Foundation; All Rights Reserved" are retained in Python alone or in any derivative version prepared by Licensee. diff --git a/README.md b/README.md index d021e195b..8272e13ca 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,170 @@ # Neo4j GenAI package for Python This repository contains the official Neo4j GenAI features for Python. + +The purpose of this package is to provide a first party package to developers, +where Neo4j can guarantee long term commitment and maintenance as well as being +fast to ship new features and high performing patterns and methods. + +Docs are coming soon! + +# Usage + +## Installation + +This package requires Python (>=3.8.1). + +To install the latest stable version, use: + +```shell +pip install neo4j-genai +``` + +## Examples + +While the library has more retrievers than shown here, the following examples should be able to get you started. + +### Performing a similarity search + +Assumption: Neo4j running with populated vector index in place. + +```python +from neo4j import GraphDatabase +from neo4j_genai import VectorRetriever + +URI = "neo4j://localhost:7687" +AUTH = ("neo4j", "password") + +INDEX_NAME = "embedding-name" + +# Connect to Neo4j database +driver = GraphDatabase.driver(URI, auth=AUTH) + +# Initialize the retriever +retriever = VectorRetriever(driver, INDEX_NAME) + +# Run the similarity search +query_text = "How do I do similarity search in Neo4j?" +response = retriever.search(query_text=query_text, top_k=5) +``` + +### Creating a vector index + +When creating a vector index, make sure you match the number of dimensions in the index with the number of dimensions the embeddings have. + +Assumption: Neo4j running + +```python +from neo4j import GraphDatabase +from neo4j_genai.indexes import create_vector_index + +URI = "neo4j://localhost:7687" +AUTH = ("neo4j", "password") + +INDEX_NAME = "chunk-index" + +# Connect to Neo4j database +driver = GraphDatabase.driver(URI, auth=AUTH) + +# Creating the index +create_vector_index( + driver, + INDEX_NAME, + label="Document", + property="textProperty", + dimensions=1536, + similarity_fn="euclidean", +) + +``` + +### Populating the Neo4j Vector Index + +This library does not write to the database, that is up to you. +See below for how to write using Cypher via the Neo4j driver. + +Assumption: Neo4j running with a defined vector index + +```python +from neo4j import GraphDatabase +from random import random + +URI = "neo4j://localhost:7687" +AUTH = ("neo4j", "password") + +# Connect to Neo4j database +driver = GraphDatabase.driver(URI, auth=AUTH) + +# Upsert the vector +vector = [random() for _ in range(DIMENSION)] +insert_query = ( + "MERGE (n:Document {id: $id})" + "WITH n " + "CALL db.create.setNodeVectorProperty(n, 'textProperty', $vector)" + "RETURN n" +) +parameters = { + "id": 0, + "vector": vector, +} +driver.execute_query(insert_query, parameters) +``` + +# Development + +## Install dependencies + +```bash +poetry install +``` + +## Getting started + +### Issues + +If you have a bug to report or feature to request, first +[search to see if an issue already exists](https://docs.github.com/en/github/searching-for-information-on-github/searching-on-github/searching-issues-and-pull-requests#search-by-the-title-body-or-comments). +If a related issue doesn't exist, please raise a new issue using the relevant +[issue form](https://github.com/neo4j/neo4j-genai-python/issues/new/choose). + +If you're a Neo4j Enterprise customer, you can also reach out to [Customer Support](http://support.neo4j.com/). + +If you don't have a bug to report or feature request, but you need a hand with +the library; community support is available via [Neo4j Online Community](https://community.neo4j.com/) +and/or [Discord](https://discord.gg/neo4j). + +### Make changes + +1. Fork the repository. +2. Install Python and Poetry. For more information, see [the development guide](./docs/contributing/DEVELOPING.md). +3. Create a working branch from `main` and start with your changes! + +### Pull request + +When you're finished with your changes, create a pull request, also known as a PR. + +- Ensure that you have [signed the CLA](https://neo4j.com/developer/contributing-code/#sign-cla). +- Ensure that the base of your PR is set to `main`. +- Don't forget to [link your PR to an issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue) + if you are solving one. +- Enable the checkbox to [allow maintainer edits](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/allowing-changes-to-a-pull-request-branch-created-from-a-fork) + so that maintainers can make any necessary tweaks and update your branch for merge. +- Reviewers may ask for changes to be made before a PR can be merged, either using + [suggested changes](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/reviewing-changes-in-pull-requests/incorporating-feedback-in-your-pull-request) + or normal pull request comments. You can apply suggested changes directly through + the UI, and any other changes can be made in your fork and committed to the PR branch. +- As you update your PR and apply changes, mark each conversation as [resolved](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/commenting-on-a-pull-request#resolving-conversations). + +## Run tests + +Open a new virtual environment and then run the tests. + +```bash +poetry shell +pytest +``` + +## Further information + +- [The official Neo4j Python driver](https://github.com/neo4j/neo4j-python-driver) +- [Neo4j GenAI integrations](https://neo4j.com/docs/cypher-manual/current/genai-integrations/) diff --git a/examples/hybrid_search.py b/examples/hybrid_search.py new file mode 100644 index 000000000..b0413b0a5 --- /dev/null +++ b/examples/hybrid_search.py @@ -0,0 +1,59 @@ +from neo4j import GraphDatabase + +from random import random +from neo4j_genai.embedder import Embedder +from neo4j_genai.indexes import create_vector_index, create_fulltext_index +from neo4j_genai.retrievers import HybridRetriever + +URI = "neo4j://localhost:7687" +AUTH = ("neo4j", "password") + +INDEX_NAME = "embedding-name" +FULLTEXT_INDEX_NAME = "fulltext-index-name" +DIMENSION = 1536 + +# Connect to Neo4j database +driver = GraphDatabase.driver(URI, auth=AUTH) + + +# Create Embedder object +class CustomEmbedder(Embedder): + def embed_query(self, text: str) -> list[float]: + return [random() for _ in range(DIMENSION)] + + +embedder = CustomEmbedder() + +# Creating the index +create_vector_index( + driver, + INDEX_NAME, + label="Document", + property="propertyKey", + dimensions=DIMENSION, + similarity_fn="euclidean", +) +create_fulltext_index( + driver, FULLTEXT_INDEX_NAME, label="Document", node_properties=["propertyKey"] +) + +# Initialize the retriever +retriever = HybridRetriever(driver, INDEX_NAME, FULLTEXT_INDEX_NAME, embedder) + +# Upsert the query +vector = [random() for _ in range(DIMENSION)] +insert_query = ( + "MERGE (n:Document {id: $id})" + "WITH n " + "CALL db.create.setNodeVectorProperty(n, 'propertyKey', $vector)" + "RETURN n" +) +parameters = { + "id": 0, + "vector": vector, +} +driver.execute_query(insert_query, parameters) + +# Perform the similarity search for a text query +query_text = "Who are the fremen?" +print(retriever.search(query_text=query_text, top_k=5)) diff --git a/examples/openai_search.py b/examples/openai_search.py index f42f76962..87864bdd9 100644 --- a/examples/openai_search.py +++ b/examples/openai_search.py @@ -34,13 +34,15 @@ # Upsert the query vector = [random() for _ in range(DIMENSION)] + insert_query = ( - "MERGE (n:Document)" + "MERGE (n:Document {id: $id})" "WITH n " "CALL db.create.setNodeVectorProperty(n, 'propertyKey', $vector)" "RETURN n" ) parameters = { + "id": 0, "vector": vector, } driver.execute_query(insert_query, parameters) diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index 3104bbc6e..e282b7d57 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -1,4 +1,3 @@ -from typing import List from neo4j import GraphDatabase from neo4j_genai import VectorRetriever @@ -16,9 +15,9 @@ driver = GraphDatabase.driver(URI, auth=AUTH) -# Create Embedder object +# Create CustomEmbedder object with the required Embedder type class CustomEmbedder(Embedder): - def embed_query(self, text: str) -> List[float]: + def embed_query(self, text: str) -> list[float]: return [random() for _ in range(DIMENSION)] @@ -40,12 +39,13 @@ def embed_query(self, text: str) -> List[float]: # Upsert the query vector = [random() for _ in range(DIMENSION)] insert_query = ( - "MERGE (n:Document)" + "MERGE (n:Document {id: $id})" "WITH n " "CALL db.create.setNodeVectorProperty(n, 'propertyKey', $vector)" "RETURN n" ) parameters = { + "id": 0, "vector": vector, } driver.execute_query(insert_query, parameters) diff --git a/examples/similarity_search_for_vector.py b/examples/similarity_search_for_vector.py index 310456760..4fc4c7ac8 100644 --- a/examples/similarity_search_for_vector.py +++ b/examples/similarity_search_for_vector.py @@ -30,15 +30,15 @@ # Upsert the vector vector = [random() for _ in range(DIMENSION)] insert_query = ( - "MERGE (n:Document)" + "MERGE (n:Document {id: $id})" "WITH n " "CALL db.create.setNodeVectorProperty(n, 'propertyKey', $vector)" "RETURN n" ) parameters = { + "id": 0, "vector": vector, } - driver.execute_query(insert_query, parameters) # Perform the similarity search for a vector query diff --git a/examples/vector_cypher_retrieval.py b/examples/vector_cypher_retrieval.py new file mode 100644 index 000000000..8ca829e2c --- /dev/null +++ b/examples/vector_cypher_retrieval.py @@ -0,0 +1,67 @@ +from neo4j import GraphDatabase +from neo4j_genai import VectorCypherRetriever + +import random +import string +from neo4j_genai.embedder import Embedder +from neo4j_genai.indexes import create_vector_index + + +URI = "neo4j://localhost:7687" +AUTH = ("neo4j", "password") + +INDEX_NAME = "embedding-name" +DIMENSION = 1536 + +# Connect to Neo4j database +driver = GraphDatabase.driver(URI, auth=AUTH) + + +# Create Embedder object +class CustomEmbedder(Embedder): + def embed_query(self, text: str) -> list[float]: + return [random.random() for _ in range(DIMENSION)] + + +# Generate random strings +def random_str(n: int) -> str: + return "".join([random.choice(string.ascii_letters) for _ in range(n)]) + + +embedder = CustomEmbedder() + +# Creating the index +create_vector_index( + driver, + INDEX_NAME, + label="Document", + property="propertyKey", + dimensions=DIMENSION, + similarity_fn="euclidean", +) + +# Initialize the retriever +retrieval_query = "MATCH (node)-[:AUTHORED_BY]->(author:Author)" "RETURN author.name" +retriever = VectorCypherRetriever(driver, INDEX_NAME, retrieval_query, embedder) + +# Upsert the query +vector = [random.random() for _ in range(DIMENSION)] +insert_query = ( + "MERGE (doc:Document {id: $id})" + "WITH doc " + "CALL db.create.setNodeVectorProperty(doc, 'propertyKey', $vector)" + "WITH doc " + "MERGE (author:Author {name: $authorName})" + "MERGE (doc)-[:AUTHORED_BY]->(author)" + "RETURN doc, author" +) +parameters = { + "id": random.randint(0, 10000), + "vector": vector, + "authorName": random_str(10), +} +driver.execute_query(insert_query, parameters) + +# Perform the search +query_text = "Find me the closest text" +print(retriever.search(query_text=query_text, top_k=1)) diff --git a/pyproject.toml b/pyproject.toml index 0a5fb4ec7..cd4de5905 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,21 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + [tool.poetry] name = "neo4j-genai" -version = "0.1.2" +version = "0.1.3" description = "Python package to allow easy integration to Neo4j's GenAI features" authors = ["Neo4j, Inc "] license = "Apache License, Version 2.0" diff --git a/src/neo4j_genai/__init__.py b/src/neo4j_genai/__init__.py index de6038a8e..89c3ad57d 100644 --- a/src/neo4j_genai/__init__.py +++ b/src/neo4j_genai/__init__.py @@ -1,4 +1,19 @@ -from .retrievers import VectorRetriever +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .retrievers import VectorRetriever, VectorCypherRetriever, HybridRetriever -__all__ = ["VectorRetriever"] + +__all__ = ["VectorRetriever", "VectorCypherRetriever", "HybridRetriever"] diff --git a/src/neo4j_genai/embedder.py b/src/neo4j_genai/embedder.py index a4c7fb23e..4cd695295 100644 --- a/src/neo4j_genai/embedder.py +++ b/src/neo4j_genai/embedder.py @@ -1,17 +1,31 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from abc import ABC, abstractmethod -from .types import EmbeddingVector class Embedder(ABC): """Interface for embedding models.""" @abstractmethod - def embed_query(self, text: str) -> EmbeddingVector: + def embed_query(self, text: str) -> list[float]: """Embed query text. Args: text (str): Text to convert to vector embedding Returns: - EmbeddingVector: A vector embedding. + list[float]: A vector embedding. """ diff --git a/src/neo4j_genai/indexes.py b/src/neo4j_genai/indexes.py index 29301d61b..b09901e08 100644 --- a/src/neo4j_genai/indexes.py +++ b/src/neo4j_genai/indexes.py @@ -1,4 +1,17 @@ -from typing import List +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from neo4j import Driver from pydantic import ValidationError @@ -55,7 +68,7 @@ def create_vector_index( def create_fulltext_index( - driver: Driver, name: str, label: str, node_properties: List[str] + driver: Driver, name: str, label: str, node_properties: list[str] ) -> None: """ This method constructs a Cypher query and executes it @@ -67,7 +80,7 @@ def create_fulltext_index( driver (Driver): Neo4j Python driver instance. name (str): The unique name of the index. label (str): The node label to be indexed. - node_properties (List[str]): The node properties to create the fulltext index on. + node_properties (list[str]): The node properties to create the fulltext index on. Raises: ValueError: If validation of the input arguments fail. @@ -85,7 +98,7 @@ def create_fulltext_index( raise ValueError(f"Error for inputs to create_fulltext_index {str(e)}") query = ( - "CREATE FULLTEXT INDEX $name" + "CREATE FULLTEXT INDEX $name " f"FOR (n:`{label}`) ON EACH " f"[{', '.join(['n.`' + prop + '`' for prop in node_properties])}]" ) diff --git a/src/neo4j_genai/retrievers.py b/src/neo4j_genai/retrievers.py index ef1ab84db..15fe31f12 100644 --- a/src/neo4j_genai/retrievers.py +++ b/src/neo4j_genai/retrievers.py @@ -1,25 +1,37 @@ -from typing import List, Optional +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import abstractmethod, ABC +from typing import Optional, Any from pydantic import ValidationError -from neo4j import Driver +from neo4j import Driver, Record from .embedder import Embedder -from .types import SimilaritySearchModel, Neo4jRecord +from .types import ( + SimilaritySearchModel, + VectorSearchRecord, + VectorCypherSearchModel, + HybridModel, +) -class VectorRetriever: +class Retriever(ABC): """ - Provides retrieval methods using vector search over embeddings + Abstract class for Neo4j retrievers """ - def __init__( - self, - driver: Driver, - index_name: str, - embedder: Optional[Embedder] = None, - ) -> None: + def __init__(self, driver: Driver): self.driver = driver - self._verify_version() - self.index_name = index_name - self.embedder = embedder def _verify_version(self) -> None: """ @@ -48,12 +60,36 @@ def _verify_version(self) -> None: "This package only supports Neo4j version 5.18.1 or greater" ) + @abstractmethod + def search(self, *args, **kwargs) -> Any: + pass + + +class VectorRetriever(Retriever): + """ + Provides retrieval method using vector search over embeddings. + If an embedder is provided, it needs to have the required Embedder type. + """ + + def __init__( + self, + driver: Driver, + index_name: str, + embedder: Optional[Embedder] = None, + return_properties: Optional[list[str]] = None, + ) -> None: + super().__init__(driver) + self._verify_version() + self.index_name = index_name + self.return_properties = return_properties + self.embedder = embedder + def search( self, - query_vector: Optional[List[float]] = None, + query_vector: Optional[list[float]] = None, query_text: Optional[str] = None, top_k: int = 5, - ) -> List[Neo4jRecord]: + ) -> list[VectorSearchRecord]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -61,8 +97,7 @@ def search( - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes) Args: - name (str): Refers to the unique name of the vector index to query. - query_vector (Optional[List[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None. + query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None. query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None. top_k (int, optional): The number of neighbors to return. Defaults to 5. @@ -71,7 +106,7 @@ def search( ValueError: If no embedder is provided. Returns: - List[Neo4jRecord]: The `top_k` neighbors found in vector search with their nodes and scores. + list[VectorSearchRecord]: The `top_k` neighbors found in vector search with their nodes and scores. """ try: validated_data = SimilaritySearchModel( @@ -97,11 +132,20 @@ def search( CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) YIELD node, score """ + + if self.return_properties: + return_properties_cypher = ", ".join( + [f".{prop}" for prop in self.return_properties] + ) + db_query_string += ( + f"RETURN node {{{return_properties_cypher}}} as node, score" + ) + records, _, _ = self.driver.execute_query(db_query_string, parameters) try: return [ - Neo4jRecord(node=record["node"], score=record["score"]) + VectorSearchRecord(node=record["node"], score=record["score"]) for record in records ] except ValidationError as e: @@ -109,3 +153,168 @@ def search( raise ValueError( f"Validation failed while constructing output: {error_details}" ) + + +class VectorCypherRetriever(Retriever): + """ + Provides retrieval method using vector similarity and custom Cypher query. + If an embedder is provided, it needs to have the required Embedder type. + """ + + def __init__( + self, + driver: Driver, + index_name: str, + retrieval_query: str, + embedder: Optional[Embedder] = None, + ) -> None: + super().__init__(driver) + self._verify_version() + self.index_name = index_name + self.retrieval_query = retrieval_query + self.embedder = embedder + + def search( + self, + query_vector: Optional[list[float]] = None, + query_text: Optional[str] = None, + top_k: int = 5, + query_params: Optional[dict[str, Any]] = None, + ) -> list[Record]: + """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. + See the following documentation for more details: + + - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query) + - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes) + + Args: + query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None. + query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None. + top_k (int, optional): The number of neighbors to return. Defaults to 5. + query_params (Optional[dict[str, Any]], optional): Parameters for the Cypher query. Defaults to None. + + Raises: + ValueError: If validation of the input arguments fail. + ValueError: If no embedder is provided. + + Returns: + list[Record]: The results of the search query + """ + try: + validated_data = VectorCypherSearchModel( + index_name=self.index_name, + top_k=top_k, + query_vector=query_vector, + query_text=query_text, + query_params=query_params, + ) + except ValidationError as e: + raise ValueError(f"Validation failed: {e.errors()}") + + parameters = validated_data.model_dump(exclude_none=True) + + if query_text: + if not self.embedder: + raise ValueError("Embedding method required for text query.") + parameters["query_vector"] = self.embedder.embed_query(query_text) + del parameters["query_text"] + + if query_params: + for key, value in query_params.items(): + if key not in parameters: + parameters[key] = value + del parameters["query_params"] + + query_prefix = """ + CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) + YIELD node, score + """ + search_query = query_prefix + self.retrieval_query + records, _, _ = self.driver.execute_query(search_query, parameters) + return records + + +class HybridRetriever(Retriever): + def __init__( + self, + driver: Driver, + vector_index_name: str, + fulltext_index_name: str, + embedder: Optional[Embedder] = None, + return_properties: Optional[list[str]] = None, + ) -> None: + super().__init__(driver) + self._verify_version() + self.vector_index_name = vector_index_name + self.fulltext_index_name = fulltext_index_name + self.embedder = embedder + self.return_properties = return_properties + + def search( + self, + query_text: str, + query_vector: Optional[list[float]] = None, + top_k: int = 5, + ) -> list[Record]: + """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. + Both query_vector and query_text can be provided. + If query_vector is provided, then it will be preferred over the embedded query_text + for the vector search. + See the following documentation for more details: + - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query) + - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes) + - [db.index.fulltext.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_fulltext_querynodes) + Args: + query_text (str): The text to get the closest neighbors of. + query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None. + top_k (int, optional): The number of neighbors to return. Defaults to 5. + Raises: + ValueError: If validation of the input arguments fail. + ValueError: If no embedder is provided. + Returns: + list[Record]: The results of the search query + """ + try: + validated_data = HybridModel( + vector_index_name=self.vector_index_name, + fulltext_index_name=self.fulltext_index_name, + top_k=top_k, + query_vector=query_vector, + query_text=query_text, + ) + except ValidationError as e: + raise ValueError(f"Validation failed: {e.errors()}") + + parameters = validated_data.model_dump(exclude_none=True) + + if query_text and not query_vector: + if not self.embedder: + raise ValueError("Embedding method required for text query.") + query_vector = self.embedder.embed_query(query_text) + parameters["query_vector"] = query_vector + + search_query = ( + "CALL { " + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node, score UNION " + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " + "YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + "RETURN n.node AS node, (n.score / max) AS score " + "} " + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + ) + + if self.return_properties: + return_properties_cypher = ", ".join( + [f".{prop}" for prop in self.return_properties] + ) + search_query += "YIELD node, score " + search_query += f"RETURN node {{{return_properties_cypher}}} as node, score" + else: + search_query += "RETURN node, score" + + records, _, _ = self.driver.execute_query(search_query, parameters) + return records diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index 91db6db74..5747aeeaf 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -1,17 +1,28 @@ -from typing import List, Any, Literal, Optional +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Literal, Optional from pydantic import BaseModel, PositiveInt, model_validator, field_validator from neo4j import Driver -class Neo4jRecord(BaseModel): +class VectorSearchRecord(BaseModel): node: Any score: float -class EmbeddingVector(BaseModel): - vector: List[float] - - class IndexModel(BaseModel): driver: Any @@ -33,7 +44,7 @@ class VectorIndexModel(IndexModel): class FulltextIndexModel(IndexModel): name: str label: str - node_properties: List[str] + node_properties: list[str] @field_validator("node_properties") def check_node_properties_not_empty(cls, v): @@ -45,7 +56,7 @@ def check_node_properties_not_empty(cls, v): class SimilaritySearchModel(BaseModel): index_name: str top_k: PositiveInt = 5 - query_vector: Optional[List[float]] = None + query_vector: Optional[list[float]] = None query_text: Optional[str] = None @model_validator(mode="before") @@ -59,3 +70,15 @@ def check_query(cls, values): "You must provide exactly one of query_vector or query_text." ) return values + + +class VectorCypherSearchModel(SimilaritySearchModel): + query_params: Optional[dict[str, Any]] = None + + +class HybridModel(BaseModel): + vector_index_name: str + fulltext_index_name: str + query_text: str + top_k: PositiveInt = 5 + query_vector: Optional[list[float]] = None diff --git a/tests/__init__.py b/tests/__init__.py index e69de29bb..c0199c144 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/conftest.py b/tests/conftest.py index d05db4bde..b0359ec0b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,20 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import pytest -from neo4j_genai import VectorRetriever +from neo4j_genai import VectorRetriever, VectorCypherRetriever, HybridRetriever from neo4j import Driver from unittest.mock import MagicMock, patch @@ -11,5 +26,20 @@ def driver(): @pytest.fixture @patch("neo4j_genai.VectorRetriever._verify_version") -def retriever(_verify_version_mock, driver): +def vector_retriever(_verify_version_mock, driver): return VectorRetriever(driver, "my-index") + + +@pytest.fixture +@patch("neo4j_genai.VectorCypherRetriever._verify_version") +def vector_cypher_retriever(_verify_version_mock, driver): + retrieval_query = """ + RETURN node.id AS node_id, node.text AS text, score + """ + return VectorCypherRetriever(driver, "my-index", retrieval_query) + + +@pytest.fixture +@patch("neo4j_genai.HybridRetriever._verify_version") +def hybrid_retriever(_verify_version_mock, driver): + return HybridRetriever(driver, "my-index", "my-fulltext-index") diff --git a/tests/test_indexes.py b/tests/test_indexes.py index 2aa87d1fa..c624607d5 100644 --- a/tests/test_indexes.py +++ b/tests/test_indexes.py @@ -1,3 +1,18 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import pytest from neo4j_genai.indexes import ( @@ -74,7 +89,7 @@ def test_create_fulltext_index_happy_path(driver): label = "node-label" text_node_properties = ["property-1", "property-2"] create_query = ( - "CREATE FULLTEXT INDEX $name" + "CREATE FULLTEXT INDEX $name " f"FOR (n:`{label}`) ON EACH " f"[{', '.join(['n.`' + property + '`' for property in text_node_properties])}]" ) @@ -101,7 +116,7 @@ def test_create_fulltext_index_ensure_escaping(driver): label = "node-label" text_node_properties = ["property-1", "property-2"] create_query = ( - "CREATE FULLTEXT INDEX $name" + "CREATE FULLTEXT INDEX $name " f"FOR (n:`{label}`) ON EACH " f"[{', '.join(['n.`' + property + '`' for property in text_node_properties])}]" ) diff --git a/tests/test_retrievers.py b/tests/test_retrievers.py index cb8014e7b..60cd6c789 100644 --- a/tests/test_retrievers.py +++ b/tests/test_retrievers.py @@ -1,7 +1,26 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import pytest from unittest.mock import patch, MagicMock + +from neo4j.exceptions import CypherSyntaxError + from neo4j_genai import VectorRetriever -from neo4j_genai.types import Neo4jRecord +from neo4j_genai.retrievers import VectorCypherRetriever, HybridRetriever +from neo4j_genai.types import VectorSearchRecord def test_vector_retriever_supported_aura_version(driver): @@ -36,14 +55,12 @@ def test_vector_retriever_no_supported_version(driver): @patch("neo4j_genai.VectorRetriever._verify_version") def test_similarity_search_vector_happy_path(_verify_version_mock, driver): - custom_embeddings = MagicMock() - index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] top_k = 5 - retriever = VectorRetriever(driver, index_name, custom_embeddings) + retriever = VectorRetriever(driver, index_name) retriever.driver.execute_query.return_value = [ [{"node": "dummy-node", "score": 1.0}], @@ -57,8 +74,6 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver): records = retriever.search(query_vector=query_vector, top_k=top_k) - custom_embeddings.embed_query.assert_not_called() - retriever.driver.execute_query.assert_called_once_with( search_query, { @@ -68,7 +83,7 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver): }, ) - assert records == [Neo4jRecord(node="dummy-node", score=1.0)] + assert records == [VectorSearchRecord(node="dummy-node", score=1.0)] @patch("neo4j_genai.VectorRetriever._verify_version") @@ -107,18 +122,61 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): }, ) - assert records == [Neo4jRecord(node="dummy-node", score=1.0)] + assert records == [VectorSearchRecord(node="dummy-node", score=1.0)] + + +@patch("neo4j_genai.VectorRetriever._verify_version") +def test_similarity_search_text_return_properties(_verify_version_mock, driver): + embed_query_vector = [1.0 for _ in range(3)] + custom_embeddings = MagicMock() + custom_embeddings.embed_query.return_value = embed_query_vector + + index_name = "my-index" + query_text = "may thy knife chip and shatter" + top_k = 5 + return_properties = ["node-property-1", "node-property-2"] + + retriever = VectorRetriever( + driver, index_name, custom_embeddings, return_properties=return_properties + ) + + driver.execute_query.return_value = [ + [{"node": "dummy-node", "score": 1.0}], + None, + None, + ] + + search_query = """ + CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) + YIELD node, score + RETURN node {.node-property-1, .node-property-2} as node, score + """ + + records = retriever.search(query_text=query_text, top_k=top_k) + + custom_embeddings.embed_query.assert_called_once_with(query_text) + + driver.execute_query.assert_called_once_with( + search_query.rstrip(), + { + "index_name": index_name, + "top_k": top_k, + "query_vector": embed_query_vector, + }, + ) + + assert records == [VectorSearchRecord(node="dummy-node", score=1.0)] -def test_similarity_search_missing_embedder_for_text(retriever): +def test_vector_retriever_search_missing_embedder_for_text(vector_retriever): query_text = "may thy knife chip and shatter" top_k = 5 with pytest.raises(ValueError, match="Embedding method required for text query"): - retriever.search(query_text=query_text, top_k=top_k) + vector_retriever.search(query_text=query_text, top_k=top_k) -def test_similarity_search_both_text_and_vector(retriever): +def test_vector_retriever_search_both_text_and_vector(vector_retriever): query_text = "may thy knife chip and shatter" query_vector = [1.1, 2.2, 3.3] top_k = 5 @@ -126,7 +184,32 @@ def test_similarity_search_both_text_and_vector(retriever): with pytest.raises( ValueError, match="You must provide exactly one of query_vector or query_text." ): - retriever.search( + vector_retriever.search( + query_text=query_text, + query_vector=query_vector, + top_k=top_k, + ) + + +def test_vector_cypher_retriever_search_missing_embedder_for_text( + vector_cypher_retriever, +): + query_text = "may thy knife chip and shatter" + top_k = 5 + + with pytest.raises(ValueError, match="Embedding method required for text query"): + vector_cypher_retriever.search(query_text=query_text, top_k=top_k) + + +def test_vector_cypher_retriever_search_both_text_and_vector(vector_cypher_retriever): + query_text = "may thy knife chip and shatter" + query_vector = [1.1, 2.2, 3.3] + top_k = 5 + + with pytest.raises( + ValueError, match="You must provide exactly one of query_vector or query_text." + ): + vector_cypher_retriever.search( query_text=query_text, query_vector=query_vector, top_k=top_k, @@ -135,14 +218,12 @@ def test_similarity_search_both_text_and_vector(retriever): @patch("neo4j_genai.VectorRetriever._verify_version") def test_similarity_search_vector_bad_results(_verify_version_mock, driver): - custom_embeddings = MagicMock() - index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] top_k = 5 - retriever = VectorRetriever(driver, index_name, custom_embeddings) + retriever = VectorRetriever(driver, index_name) retriever.driver.execute_query.return_value = [ [{"node": "dummy-node", "score": "adsa"}], @@ -157,8 +238,6 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver): with pytest.raises(ValueError): retriever.search(query_vector=query_vector, top_k=top_k) - custom_embeddings.embed_query.assert_not_called() - retriever.driver.execute_query.assert_called_once_with( search_query, { @@ -167,3 +246,300 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver): "query_vector": query_vector, }, ) + + +@patch("neo4j_genai.VectorCypherRetriever._verify_version") +def test_retrieval_query_happy_path(_verify_version_mock, driver): + embed_query_vector = [1.0 for _ in range(1536)] + custom_embeddings = MagicMock() + custom_embeddings.embed_query.return_value = embed_query_vector + index_name = "my-index" + retrieval_query = """ + RETURN node.id AS node_id, node.text AS text, score + """ + retriever = VectorCypherRetriever( + driver, index_name, retrieval_query, embedder=custom_embeddings + ) + query_text = "may thy knife chip and shatter" + top_k = 5 + driver.execute_query.return_value = [ + [{"node_id": 123, "text": "dummy-text", "score": 1.0}], + None, + None, + ] + search_query = """ + CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) + YIELD node, score + """ + + records = retriever.search( + query_text=query_text, + top_k=top_k, + ) + + custom_embeddings.embed_query.assert_called_once_with(query_text) + driver.execute_query.assert_called_once_with( + search_query + retrieval_query, + { + "index_name": index_name, + "top_k": top_k, + "query_vector": embed_query_vector, + }, + ) + assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}] + + +@patch("neo4j_genai.VectorCypherRetriever._verify_version") +def test_retrieval_query_with_params(_verify_version_mock, driver): + embed_query_vector = [1.0 for _ in range(1536)] + custom_embeddings = MagicMock() + custom_embeddings.embed_query.return_value = embed_query_vector + index_name = "my-index" + retrieval_query = """ + RETURN node.id AS node_id, node.text AS text, score, {test: $param} AS metadata + """ + query_params = { + "param": "dummy-param", + } + retriever = VectorCypherRetriever( + driver, + index_name, + retrieval_query, + embedder=custom_embeddings, + ) + query_text = "may thy knife chip and shatter" + top_k = 5 + driver.execute_query.return_value = [ + [{"node_id": 123, "text": "dummy-text", "score": 1.0}], + None, + None, + ] + search_query = """ + CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) + YIELD node, score + """ + + records = retriever.search( + query_text=query_text, + top_k=top_k, + query_params=query_params, + ) + + custom_embeddings.embed_query.assert_called_once_with(query_text) + + driver.execute_query.assert_called_once_with( + search_query + retrieval_query, + { + "index_name": index_name, + "top_k": top_k, + "query_vector": embed_query_vector, + "param": "dummy-param", + }, + ) + + assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}] + + +@patch("neo4j_genai.VectorCypherRetriever._verify_version") +def test_retrieval_query_cypher_error(_verify_version_mock, driver): + embed_query_vector = [1.0 for _ in range(1536)] + custom_embeddings = MagicMock() + custom_embeddings.embed_query.return_value = embed_query_vector + index_name = "my-index" + retrieval_query = """ + this is not a cypher query + """ + retriever = VectorCypherRetriever( + driver, index_name, retrieval_query, embedder=custom_embeddings + ) + query_text = "may thy knife chip and shatter" + top_k = 5 + driver.execute_query.side_effect = CypherSyntaxError + + with pytest.raises(CypherSyntaxError): + retriever.search( + query_text=query_text, + top_k=top_k, + ) + + +@patch("neo4j_genai.HybridRetriever._verify_version") +def test_hybrid_search_text_happy_path(_verify_version_mock, driver): + embed_query_vector = [1.0 for _ in range(1536)] + custom_embeddings = MagicMock() + custom_embeddings.embed_query.return_value = embed_query_vector + vector_index_name = "my-index" + fulltext_index_name = "my-fulltext-index" + query_text = "may thy knife chip and shatter" + top_k = 5 + + retriever = HybridRetriever( + driver, vector_index_name, fulltext_index_name, custom_embeddings + ) + + retriever.driver.execute_query.return_value = [ + [{"node": "dummy-node", "score": 1.0}], + None, + None, + ] + search_query = ( + "CALL { " + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node, score UNION " + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " + "YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + "RETURN n.node AS node, (n.score / max) AS score " + "} " + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + "RETURN node, score" + ) + + records = retriever.search(query_text=query_text, top_k=top_k) + + retriever.driver.execute_query.assert_called_once_with( + search_query, + { + "vector_index_name": vector_index_name, + "top_k": top_k, + "query_text": query_text, + "fulltext_index_name": fulltext_index_name, + "query_vector": embed_query_vector, + }, + ) + custom_embeddings.embed_query.assert_called_once_with(query_text) + assert records == [{"node": "dummy-node", "score": 1.0}] + + +@patch("neo4j_genai.HybridRetriever._verify_version") +def test_hybrid_search_favors_query_vector_over_embedding_vector( + _verify_version_mock, driver +): + embed_query_vector = [1.0 for _ in range(1536)] + query_vector = [2.0 for _ in range(1536)] + custom_embeddings = MagicMock() + custom_embeddings.embed_query.return_value = embed_query_vector + vector_index_name = "my-index" + fulltext_index_name = "my-fulltext-index" + query_text = "may thy knife chip and shatter" + top_k = 5 + + retriever = HybridRetriever( + driver, vector_index_name, fulltext_index_name, custom_embeddings + ) + + retriever.driver.execute_query.return_value = [ + [{"node": "dummy-node", "score": 1.0}], + None, + None, + ] + search_query = ( + "CALL { " + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node, score UNION " + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " + "YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + "RETURN n.node AS node, (n.score / max) AS score " + "} " + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + "RETURN node, score" + ) + + retriever.search(query_text=query_text, query_vector=query_vector, top_k=top_k) + + retriever.driver.execute_query.assert_called_once_with( + search_query, + { + "vector_index_name": vector_index_name, + "top_k": top_k, + "query_text": query_text, + "fulltext_index_name": fulltext_index_name, + "query_vector": query_vector, + }, + ) + custom_embeddings.embed_query.assert_not_called() + + +def test_error_when_hybrid_search_only_text_no_embedder(hybrid_retriever): + query_text = "may thy knife chip and shatter" + top_k = 5 + + with pytest.raises(ValueError, match="Embedding method required for text query."): + hybrid_retriever.search( + query_text=query_text, + top_k=top_k, + ) + + +def test_hybrid_search_retriever_search_missing_embedder_for_text( + hybrid_retriever, +): + query_text = "may thy knife chip and shatter" + top_k = 5 + + with pytest.raises(ValueError, match="Embedding method required for text query"): + hybrid_retriever.search( + query_text=query_text, + top_k=top_k, + ) + + +@patch("neo4j_genai.HybridRetriever._verify_version") +def test_hybrid_retriever_return_properties(_verify_version_mock, driver): + embed_query_vector = [1.0 for _ in range(1536)] + custom_embeddings = MagicMock() + custom_embeddings.embed_query.return_value = embed_query_vector + vector_index_name = "my-index" + fulltext_index_name = "my-fulltext-index" + query_text = "may thy knife chip and shatter" + top_k = 5 + return_properties = ["node-property-1", "node-property-2"] + retriever = HybridRetriever( + driver, + vector_index_name, + fulltext_index_name, + custom_embeddings, + return_properties, + ) + driver.execute_query.return_value = [ + [{"node": "dummy-node", "score": 1.0}], + None, + None, + ] + search_query = ( + "CALL { " + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node, score UNION " + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " + "YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + "RETURN n.node AS node, (n.score / max) AS score " + "} " + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + "YIELD node, score " + "RETURN node {.node-property-1, .node-property-2} as node, score" + ) + + records = retriever.search(query_text=query_text, top_k=top_k) + + custom_embeddings.embed_query.assert_called_once_with(query_text) + + driver.execute_query.assert_called_once_with( + search_query.rstrip(), + { + "vector_index_name": vector_index_name, + "top_k": top_k, + "query_text": query_text, + "fulltext_index_name": fulltext_index_name, + "query_vector": embed_query_vector, + }, + ) + + assert records == [{"node": "dummy-node", "score": 1.0}]