From 6492ba666024702e882dd3f55f6c45d16693b7d4 Mon Sep 17 00:00:00 2001 From: willtai Date: Fri, 3 May 2024 11:08:29 +0100 Subject: [PATCH] Setup E2E Test pipeline and add E2E tests for vector and hybrid retrievers (#24) * Added E2E tests, new GitHub workflow, and separated out unit tests Setup neo4j db for e2e tests * Refactor query tail generation to separate function --------- Co-authored-by: Oskar Hane --- .github/workflows/pr-e2e-tests.yaml | 51 +++++++ .github/workflows/pr.yaml | 4 +- .gitignore | 1 + README.md | 4 +- examples/hybrid_cypher_search.py | 2 +- examples/hybrid_search.py | 2 +- examples/openai_search.py | 2 +- examples/similarity_search_for_text.py | 2 +- examples/vector_cypher_retrieval.py | 2 +- src/neo4j_genai/indexes.py | 2 +- src/neo4j_genai/neo4j_queries.py | 57 ++++---- tests/e2e/conftest.py | 90 ++++++++++++ tests/e2e/test_hybrid_e2e.py | 132 ++++++++++++++++++ tests/e2e/test_vector_e2e.py | 104 ++++++++++++++ tests/test_neo4j_queries.py | 76 ---------- tests/unit/__init__.py | 14 ++ tests/{ => unit}/conftest.py | 8 +- tests/unit/retrievers/__init__.py | 14 ++ tests/{ => unit}/retrievers/test_base.py | 0 tests/{ => unit}/retrievers/test_hybrid.py | 0 tests/{ => unit}/retrievers/test_vector.py | 0 tests/{ => unit}/test_indexes.py | 2 +- tests/unit/test_neo4j_queries.py | 153 +++++++++++++++++++++ 23 files changed, 607 insertions(+), 115 deletions(-) create mode 100644 .github/workflows/pr-e2e-tests.yaml create mode 100644 tests/e2e/conftest.py create mode 100644 tests/e2e/test_hybrid_e2e.py create mode 100644 tests/e2e/test_vector_e2e.py delete mode 100644 tests/test_neo4j_queries.py create mode 100644 tests/unit/__init__.py rename tests/{ => unit}/conftest.py (91%) create mode 100644 tests/unit/retrievers/__init__.py rename tests/{ => unit}/retrievers/test_base.py (100%) rename tests/{ => unit}/retrievers/test_hybrid.py (100%) rename tests/{ => unit}/retrievers/test_vector.py (100%) rename tests/{ => unit}/test_indexes.py (98%) create mode 100644 tests/unit/test_neo4j_queries.py diff --git a/.github/workflows/pr-e2e-tests.yaml b/.github/workflows/pr-e2e-tests.yaml new file mode 100644 index 00000000..b09a148e --- /dev/null +++ b/.github/workflows/pr-e2e-tests.yaml @@ -0,0 +1,51 @@ +name: 'Neo4j-GenAI PR E2E Tests' + +on: + pull_request: + types: [opened, synchronize, reopened, ready_for_review] + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number }} + cancel-in-progress: true + +jobs: + e2e-tests: + runs-on: ubuntu-latest + strategy: + matrix: + neo4j-version: + - 5 + neo4j-edition: + - community + - enterprise + services: + neo4j: + image: neo4j:${{ matrix.neo4j-version }}-${{ matrix.neo4j-edition }} + env: + NEO4J_AUTH: neo4j/password + NEO4J_ACCEPT_LICENSE_AGREEMENT: yes + ports: + - 7687:7687 + - 7474:7474 + + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.9' + + - name: Install Poetry + run: | + curl -sSL https://install.python-poetry.org | python3 - + + - name: Configure Poetry + run: | + echo "$HOME/.local/bin" >> $GITHUB_PATH + poetry config virtualenvs.create false + + - name: Install dependencies + run: poetry install + + - name: Run tests + run: poetry run pytest ./tests/e2e diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index eb160686..f93110e3 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -28,7 +28,7 @@ jobs: run: | poetry run ruff format --check . poetry run ruff check . - - name: Run tests and check coverage + - name: Run unit tests and check coverage run: | - poetry run coverage run -m pytest + poetry run coverage run -m pytest tests/unit poetry run coverage report --fail-under=90 diff --git a/.gitignore b/.gitignore index e9cf65c6..16c7907e 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ dist/ htmlcov/ .idea/ .env +docs/build/ diff --git a/README.md b/README.md index 75016337..bc8eb161 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ create_vector_index( ### Populating the Neo4j Vector Index -This library does not write to the database, that is up to you. +This library does not write to the database, that is up to you. See below for how to write using Cypher via the Neo4j driver. Assumption: Neo4j running with a defined vector index @@ -165,7 +165,7 @@ Open a new virtual environment and then run the tests. ```bash poetry shell -pytest +pytest tests/unit ``` ## Further information diff --git a/examples/hybrid_cypher_search.py b/examples/hybrid_cypher_search.py index a121b2eb..9bf8ca23 100644 --- a/examples/hybrid_cypher_search.py +++ b/examples/hybrid_cypher_search.py @@ -58,5 +58,5 @@ def embed_query(self, text: str) -> list[float]: driver.execute_query(insert_query, parameters) # Perform the similarity search for a text query -query_text = "Who are the fremen?" +query_text = "Find me a book about Fremen" print(retriever.search(query_text=query_text, top_k=5)) diff --git a/examples/hybrid_search.py b/examples/hybrid_search.py index 704a3841..f035c28c 100644 --- a/examples/hybrid_search.py +++ b/examples/hybrid_search.py @@ -55,5 +55,5 @@ def embed_query(self, text: str) -> list[float]: driver.execute_query(insert_query, parameters) # Perform the similarity search for a text query -query_text = "Who are the fremen?" +query_text = "Find me a book about Fremen" print(retriever.search(query_text=query_text, top_k=5)) diff --git a/examples/openai_search.py b/examples/openai_search.py index 87864bdd..f5e6760d 100644 --- a/examples/openai_search.py +++ b/examples/openai_search.py @@ -48,5 +48,5 @@ driver.execute_query(insert_query, parameters) # Perform the similarity search for a text query -query_text = "hello world" +query_text = "Find me a book about Fremen" print(retriever.search(query_text=query_text, top_k=5)) diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index e282b7d5..203c2965 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -51,5 +51,5 @@ def embed_query(self, text: str) -> list[float]: driver.execute_query(insert_query, parameters) # Perform the similarity search for a text query -query_text = "hello world" +query_text = "Find me a book about Fremen" print(retriever.search(query_text=query_text, top_k=5)) diff --git a/examples/vector_cypher_retrieval.py b/examples/vector_cypher_retrieval.py index 8ca829e2..63a81860 100644 --- a/examples/vector_cypher_retrieval.py +++ b/examples/vector_cypher_retrieval.py @@ -63,5 +63,5 @@ def random_str(n: int) -> str: driver.execute_query(insert_query, parameters) # Perform the search -query_text = "Find me the closest text" +query_text = "Find me a book about Fremen" print(retriever.search(query_text=query_text, top_k=1)) diff --git a/src/neo4j_genai/indexes.py b/src/neo4j_genai/indexes.py index b09901e0..132cc144 100644 --- a/src/neo4j_genai/indexes.py +++ b/src/neo4j_genai/indexes.py @@ -115,7 +115,7 @@ def drop_index(driver: Driver, name: str) -> None: driver (Driver): Neo4j Python driver instance. name (str): The name of the index to delete. """ - query = "DROP INDEX $name" + query = "DROP INDEX $name IF EXISTS" parameters = { "name": name, } diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index 752677d5..b9ab366a 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -23,34 +23,43 @@ def get_search_query( retrieval_query: Optional[str] = None, ): query_map = { - SearchType.VECTOR: ( - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + SearchType.VECTOR: "".join( + [ + "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) ", + "YIELD node, score ", + get_query_tail(retrieval_query, return_properties), + ] ), - SearchType.HYBRID: ( - "CALL { " - "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " - "YIELD node, score " - "RETURN node, score UNION " - "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " - "YIELD node, score " - "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " - "UNWIND nodes AS n " - "RETURN n.node AS node, (n.score / max) AS score " - "} " - "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + SearchType.HYBRID: "".join( + [ + "CALL { ", + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) ", + "YIELD node, score ", + "RETURN node, score UNION ", + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) ", + "YIELD node, score ", + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max ", + "UNWIND nodes AS n ", + "RETURN n.node AS node, (n.score / max) AS score ", + "} ", + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k ", + get_query_tail( + retrieval_query, return_properties, "RETURN node, score" + ), + ] ), } + return query_map[search_type] - base_query = query_map[search_type] - additional_query = "" +def get_query_tail( + retrieval_query: Optional[str] = None, + return_properties: Optional[list[str]] = None, + fallback_return: Optional[str] = None, +) -> str: if retrieval_query: - additional_query += retrieval_query - elif return_properties: + return retrieval_query + if return_properties: return_properties_cypher = ", ".join([f".{prop}" for prop in return_properties]) - additional_query += "YIELD node, score " - additional_query += f"RETURN node {{{return_properties_cypher}}} as node, score" - else: - additional_query += "RETURN node, score" - - return base_query + additional_query + return f"RETURN node {{{return_properties_cypher}}} as node, score" + return fallback_return if fallback_return else "" diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py new file mode 100644 index 00000000..64cd6504 --- /dev/null +++ b/tests/e2e/conftest.py @@ -0,0 +1,90 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import string +import random +import uuid + +import pytest +from neo4j import GraphDatabase +from neo4j_genai.embedder import Embedder +from neo4j_genai.indexes import drop_index, create_vector_index, create_fulltext_index + + +@pytest.fixture(scope="module") +def driver(): + uri = "neo4j://localhost:7687" + auth = ("neo4j", "password") + driver = GraphDatabase.driver(uri, auth=auth) + yield driver + driver.close() + + +@pytest.fixture(scope="module") +def custom_embedder(): + class CustomEmbedder(Embedder): + def embed_query(self, text: str) -> list[float]: + return [random.random() for _ in range(1536)] + + return CustomEmbedder() + + +@pytest.fixture(scope="module") +def setup_neo4j(driver): + vector_index_name = "vector-index-name" + fulltext_index_name = "fulltext-index-name" + + # Delete data and drop indexes to prevent data leakage + driver.execute_query("MATCH (n) DETACH DELETE n") + drop_index(driver, vector_index_name) + drop_index(driver, fulltext_index_name) + + # Create a vector index + create_vector_index( + driver, + vector_index_name, + label="Document", + property="propertyKey", + dimensions=1536, + similarity_fn="euclidean", + ) + + # Create a fulltext index + create_fulltext_index( + driver, fulltext_index_name, label="Document", node_properties=["propertyKey"] + ) + + # Insert 10 vectors and authors + vector = [random.random() for _ in range(1536)] + + def random_str(n: int) -> str: + return "".join([random.choice(string.ascii_letters) for _ in range(n)]) + + for i in range(10): + insert_query = ( + "MERGE (doc:Document {id: $id})" + "WITH doc " + "CALL db.create.setNodeVectorProperty(doc, 'propertyKey', $vector)" + "WITH doc " + "MERGE (author:Author {name: $authorName})" + "MERGE (doc)-[:AUTHORED_BY]->(author)" + "RETURN doc, author" + ) + + parameters = { + "id": str(uuid.uuid4()), + "vector": vector, + "authorName": random_str(10), + } + driver.execute_query(insert_query, parameters) diff --git a/tests/e2e/test_hybrid_e2e.py b/tests/e2e/test_hybrid_e2e.py new file mode 100644 index 00000000..f8f54466 --- /dev/null +++ b/tests/e2e/test_hybrid_e2e.py @@ -0,0 +1,132 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j import Record + +from neo4j_genai import ( + HybridRetriever, + HybridCypherRetriever, +) + + +@pytest.mark.usefixtures("setup_neo4j") +def test_hybrid_retriever_search_text(driver, custom_embedder): + retriever = HybridRetriever( + driver, "vector-index-name", "fulltext-index-name", custom_embedder + ) + + top_k = 5 + results = retriever.search(query_text="Find me a book about Fremen", top_k=top_k) + + assert isinstance(results, list) + assert len(results) == 5 + for result in results: + assert isinstance(result, Record) + + +@pytest.mark.usefixtures("setup_neo4j") +def test_hybrid_cypher_retriever_search_text(driver, custom_embedder): + retrieval_query = ( + "MATCH (node)-[:AUTHORED_BY]->(author:Author) " "RETURN author.name" + ) + retriever = HybridCypherRetriever( + driver, + "vector-index-name", + "fulltext-index-name", + retrieval_query, + custom_embedder, + ) + + top_k = 5 + results = retriever.search(query_text="Find me a book about Fremen", top_k=top_k) + + assert isinstance(results, list) + assert len(results) == 5 + for record in results: + assert isinstance(record, Record) + assert "author.name" in record.keys() + + +@pytest.mark.usefixtures("setup_neo4j") +def test_hybrid_retriever_search_vector(driver): + retriever = HybridRetriever( + driver, + "vector-index-name", + "fulltext-index-name", + ) + + top_k = 5 + results = retriever.search( + query_text="Find me a book about Fremen", + query_vector=[1.0 for _ in range(1536)], + top_k=top_k, + ) + + assert isinstance(results, list) + assert len(results) == 5 + for result in results: + assert isinstance(result, Record) + + +@pytest.mark.usefixtures("setup_neo4j") +def test_hybrid_cypher_retriever_search_vector(driver): + retrieval_query = ( + "MATCH (node)-[:AUTHORED_BY]->(author:Author) " "RETURN author.name" + ) + retriever = HybridCypherRetriever( + driver, + "vector-index-name", + "fulltext-index-name", + retrieval_query, + ) + + top_k = 5 + results = retriever.search( + query_text="Find me a book about Fremen", + query_vector=[1.0 for _ in range(1536)], + top_k=top_k, + ) + + assert isinstance(results, list) + assert len(results) == 5 + for record in results: + assert isinstance(record, Record) + assert "author.name" in record.keys() + + +@pytest.mark.usefixtures("setup_neo4j") +def test_hybrid_retriever_return_properties(driver): + properties = ["name", "age"] + retriever = HybridRetriever( + driver, + "vector-index-name", + "fulltext-index-name", + return_properties=properties, + ) + + top_k = 5 + results = retriever.search( + query_text="Find me a book about Fremen", + query_vector=[1.0 for _ in range(1536)], + top_k=top_k, + ) + + assert isinstance(results, list) + assert len(results) == 5 + for result in results: + assert isinstance(result, Record) diff --git a/tests/e2e/test_vector_e2e.py b/tests/e2e/test_vector_e2e.py new file mode 100644 index 00000000..9bf3f5a4 --- /dev/null +++ b/tests/e2e/test_vector_e2e.py @@ -0,0 +1,104 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from neo4j import Record + +from neo4j_genai import VectorRetriever, VectorCypherRetriever +from neo4j_genai.types import VectorSearchRecord + + +@pytest.mark.usefixtures("setup_neo4j") +def test_vector_retriever_search_text(driver, custom_embedder): + retriever = VectorRetriever(driver, "vector-index-name", custom_embedder) + + top_k = 5 + results = retriever.search(query_text="Find me a book about Fremen", top_k=top_k) + + assert isinstance(results, list) + assert len(results) == 5 + for result in results: + assert isinstance(result, VectorSearchRecord) + + +@pytest.mark.usefixtures("setup_neo4j") +def test_vector_cypher_retriever_search_text(driver, custom_embedder): + retrieval_query = ( + "MATCH (node)-[:AUTHORED_BY]->(author:Author) " "RETURN author.name" + ) + retriever = VectorCypherRetriever( + driver, "vector-index-name", retrieval_query, custom_embedder + ) + + top_k = 5 + results = retriever.search(query_text="Find me a book about Fremen", top_k=top_k) + + assert isinstance(results, list) + assert len(results) == 5 + for record in results: + assert isinstance(record, Record) + assert "author.name" in record.keys() + + +@pytest.mark.usefixtures("setup_neo4j") +def test_vector_retriever_search_vector(driver): + retriever = VectorRetriever(driver, "vector-index-name") + + top_k = 5 + results = retriever.search(query_vector=[1.0 for _ in range(1536)], top_k=top_k) + + assert isinstance(results, list) + assert len(results) == 5 + for result in results: + assert isinstance(result, VectorSearchRecord) + + +@pytest.mark.usefixtures("setup_neo4j") +def test_vector_cypher_retriever_search_vector(driver): + retrieval_query = ( + "MATCH (node)-[:AUTHORED_BY]->(author:Author) " "RETURN author.name" + ) + retriever = VectorCypherRetriever(driver, "vector-index-name", retrieval_query) + + top_k = 5 + results = retriever.search(query_vector=[1.0 for _ in range(1536)], top_k=top_k) + + assert isinstance(results, list) + assert len(results) == 5 + for record in results: + assert isinstance(record, Record) + assert "author.name" in record.keys() + + +@pytest.mark.usefixtures("setup_neo4j") +def test_vector_retriever_return_properties(driver): + properties = ["name", "age"] + retriever = VectorRetriever( + driver, + "vector-index-name", + return_properties=properties, + ) + + top_k = 5 + results = retriever.search( + query_vector=[1.0 for _ in range(1536)], + top_k=top_k, + ) + + assert isinstance(results, list) + assert len(results) == 5 + for result in results: + assert isinstance(result, VectorSearchRecord) diff --git a/tests/test_neo4j_queries.py b/tests/test_neo4j_queries.py deleted file mode 100644 index 555388ae..00000000 --- a/tests/test_neo4j_queries.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# https://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from neo4j_genai.neo4j_queries import get_search_query -from neo4j_genai.types import SearchType - - -def test_vector_search_basic(): - expected = ( - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " - "RETURN node, score" - ) - result = get_search_query(SearchType.VECTOR) - assert result == expected - - -def test_hybrid_search_basic(): - expected = ( - "CALL { " - "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " - "YIELD node, score " - "RETURN node, score UNION " - "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " - "YIELD node, score " - "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " - "UNWIND nodes AS n " - "RETURN n.node AS node, (n.score / max) AS score " - "} " - "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " - "RETURN node, score" - ) - result = get_search_query(SearchType.HYBRID) - assert result == expected - - -def test_vector_search_with_properties(): - properties = ["name", "age"] - expected = ( - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " - "YIELD node, score " - "RETURN node {.name, .age} as node, score" - ) - result = get_search_query(SearchType.VECTOR, return_properties=properties) - assert result == expected - - -def test_hybrid_search_with_retrieval_query(): - retrieval_query = "MATCH (n) RETURN n LIMIT 10" - expected = ( - "CALL { " - "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " - "YIELD node, score " - "RETURN node, score UNION " - "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " - "YIELD node, score " - "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " - "UNWIND nodes AS n " - "RETURN n.node AS node, (n.score / max) AS score " - "} " - "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " - + retrieval_query - ) - result = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) - assert result == expected diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 00000000..c0199c14 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/conftest.py b/tests/unit/conftest.py similarity index 91% rename from tests/conftest.py rename to tests/unit/conftest.py index b0359ec0..b22e58fc 100644 --- a/tests/conftest.py +++ b/tests/unit/conftest.py @@ -19,18 +19,18 @@ from unittest.mock import MagicMock, patch -@pytest.fixture +@pytest.fixture(scope="function") def driver(): return MagicMock(spec=Driver) -@pytest.fixture +@pytest.fixture(scope="function") @patch("neo4j_genai.VectorRetriever._verify_version") def vector_retriever(_verify_version_mock, driver): return VectorRetriever(driver, "my-index") -@pytest.fixture +@pytest.fixture(scope="function") @patch("neo4j_genai.VectorCypherRetriever._verify_version") def vector_cypher_retriever(_verify_version_mock, driver): retrieval_query = """ @@ -39,7 +39,7 @@ def vector_cypher_retriever(_verify_version_mock, driver): return VectorCypherRetriever(driver, "my-index", retrieval_query) -@pytest.fixture +@pytest.fixture(scope="function") @patch("neo4j_genai.HybridRetriever._verify_version") def hybrid_retriever(_verify_version_mock, driver): return HybridRetriever(driver, "my-index", "my-fulltext-index") diff --git a/tests/unit/retrievers/__init__.py b/tests/unit/retrievers/__init__.py new file mode 100644 index 00000000..c0199c14 --- /dev/null +++ b/tests/unit/retrievers/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/retrievers/test_base.py b/tests/unit/retrievers/test_base.py similarity index 100% rename from tests/retrievers/test_base.py rename to tests/unit/retrievers/test_base.py diff --git a/tests/retrievers/test_hybrid.py b/tests/unit/retrievers/test_hybrid.py similarity index 100% rename from tests/retrievers/test_hybrid.py rename to tests/unit/retrievers/test_hybrid.py diff --git a/tests/retrievers/test_vector.py b/tests/unit/retrievers/test_vector.py similarity index 100% rename from tests/retrievers/test_vector.py rename to tests/unit/retrievers/test_vector.py diff --git a/tests/test_indexes.py b/tests/unit/test_indexes.py similarity index 98% rename from tests/test_indexes.py rename to tests/unit/test_indexes.py index c624607d..84122684 100644 --- a/tests/test_indexes.py +++ b/tests/unit/test_indexes.py @@ -75,7 +75,7 @@ def test_create_vector_index_validation_error_similarity_fn(driver): def test_drop_index(driver): - drop_query = "DROP INDEX $name" + drop_query = "DROP INDEX $name IF EXISTS" drop_index(driver, "my-index") diff --git a/tests/unit/test_neo4j_queries.py b/tests/unit/test_neo4j_queries.py new file mode 100644 index 00000000..3ce7c774 --- /dev/null +++ b/tests/unit/test_neo4j_queries.py @@ -0,0 +1,153 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from neo4j_genai.neo4j_queries import get_search_query, get_query_tail +from neo4j_genai.types import SearchType + + +def test_vector_search_basic(): + expected = ( + "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + "YIELD node, score" + ) + result = get_search_query(SearchType.VECTOR) + assert result.strip() == expected.strip() + + +def test_hybrid_search_basic(): + expected = ( + "CALL { " + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node, score UNION " + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " + "YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + "RETURN n.node AS node, (n.score / max) AS score " + "} " + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + "RETURN node, score" + ) + result = get_search_query(SearchType.HYBRID) + assert result.strip() == expected.strip() + + +def test_vector_search_with_properties(): + properties = ["name", "age"] + expected = ( + "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node {.name, .age} as node, score" + ) + result = get_search_query(SearchType.VECTOR, return_properties=properties) + assert result.strip() == expected.strip() + + +def test_vector_search_with_retrieval_query(): + retrieval_query = "MATCH (n) RETURN n LIMIT 10" + expected = ( + "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + "YIELD node, score " + retrieval_query + ) + result = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) + assert result.strip() == expected.strip() + + +def test_hybrid_search_with_retrieval_query(): + retrieval_query = "MATCH (n) RETURN n LIMIT 10" + expected = ( + "CALL { " + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node, score UNION " + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " + "YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + "RETURN n.node AS node, (n.score / max) AS score " + "} " + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + + retrieval_query + ) + result = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) + assert result.strip() == expected.strip() + + +def test_hybrid_search_with_properties(): + properties = ["name", "age"] + expected = ( + "CALL { " + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node, score UNION " + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " + "YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + "RETURN n.node AS node, (n.score / max) AS score " + "} " + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + "RETURN node {.name, .age} as node, score" + ) + result = get_search_query(SearchType.HYBRID, return_properties=properties) + assert result.strip() == expected.strip() + + +def test_get_query_tail_with_retrieval_query(): + retrieval_query = "MATCH (n) RETURN n LIMIT 10" + expected = retrieval_query + result = get_query_tail(retrieval_query=retrieval_query) + assert result.strip() == expected.strip() + + +def test_get_query_tail_with_properties(): + properties = ["name", "age"] + expected = "RETURN node {.name, .age} as node, score" + result = get_query_tail(return_properties=properties) + assert result.strip() == expected.strip() + + +def test_get_query_tail_with_fallback(): + fallback = "HELLO" + expected = fallback + result = get_query_tail(fallback_return=fallback) + assert result.strip() == expected.strip() + + +def test_get_query_tail_ordering_all(): + retrieval_query = "MATCH (n) RETURN n LIMIT 10" + properties = ["name", "age"] + fallback = "HELLO" + + expected = retrieval_query + result = get_query_tail( + retrieval_query=retrieval_query, + return_properties=properties, + fallback_return=fallback, + ) + assert result.strip() == expected.strip() + + +def test_get_query_tail_ordering_no_retrieval_query(): + properties = ["name", "age"] + fallback = "HELLO" + + expected = "RETURN node {.name, .age} as node, score" + result = get_query_tail( + return_properties=properties, + fallback_return=fallback, + ) + assert result.strip() == expected.strip()