From 30627781c0107051de6e7ab7a5d85fbb8c1da54d Mon Sep 17 00:00:00 2001 From: Will Tai Date: Wed, 1 May 2024 11:17:02 +0100 Subject: [PATCH] Added E2E tests, new GitHub workflow, and separated out unit tests Setup neo4j db for e2e tests --- .github/workflows/pr-e2e-tests.yaml | 52 ++++++++++ .github/workflows/pr.yaml | 4 +- .gitignore | 1 + README.md | 4 +- examples/hybrid_cypher_search.py | 2 +- examples/hybrid_search.py | 2 +- examples/openai_search.py | 2 +- examples/similarity_search_for_text.py | 2 +- examples/vector_cypher_retrieval.py | 2 +- src/neo4j_genai/indexes.py | 2 +- src/neo4j_genai/neo4j_queries.py | 4 +- tests/e2e/conftest.py | 89 +++++++++++++++++ tests/e2e/test_hybrid_e2e.py | 109 +++++++++++++++++++++ tests/e2e/test_vector_e2e.py | 83 ++++++++++++++++ tests/unit/__init__.py | 14 +++ tests/{ => unit}/conftest.py | 8 +- tests/unit/retrievers/__init__.py | 14 +++ tests/{ => unit}/retrievers/test_base.py | 0 tests/{ => unit}/retrievers/test_hybrid.py | 0 tests/{ => unit}/retrievers/test_vector.py | 0 tests/{ => unit}/test_indexes.py | 2 +- tests/{ => unit}/test_neo4j_queries.py | 10 +- 22 files changed, 384 insertions(+), 22 deletions(-) create mode 100644 .github/workflows/pr-e2e-tests.yaml create mode 100644 tests/e2e/conftest.py create mode 100644 tests/e2e/test_hybrid_e2e.py create mode 100644 tests/e2e/test_vector_e2e.py create mode 100644 tests/unit/__init__.py rename tests/{ => unit}/conftest.py (91%) create mode 100644 tests/unit/retrievers/__init__.py rename tests/{ => unit}/retrievers/test_base.py (100%) rename tests/{ => unit}/retrievers/test_hybrid.py (100%) rename tests/{ => unit}/retrievers/test_vector.py (100%) rename tests/{ => unit}/test_indexes.py (98%) rename tests/{ => unit}/test_neo4j_queries.py (92%) diff --git a/.github/workflows/pr-e2e-tests.yaml b/.github/workflows/pr-e2e-tests.yaml new file mode 100644 index 000000000..05213662b --- /dev/null +++ b/.github/workflows/pr-e2e-tests.yaml @@ -0,0 +1,52 @@ +name: 'Neo4j-GenAI PR E2E Tests' + +on: + pull_request: + types: [opened, synchronize, reopened, ready_for_review] + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number }} + cancel-in-progress: true + +jobs: + e2e-tests: + runs-on: ubuntu-latest + strategy: + matrix: + neo4j-version: + - 5 + - 5.18.1 + neo4j-edition: + - community + - enterprise + services: + neo4j: + image: neo4j:${{ matrix.neo4j-version }}-${{ matrix.neo4j-edition }} + env: + NEO4J_AUTH: neo4j/password + NEO4J_ACCEPT_LICENSE_AGREEMENT: yes + ports: + - 7687:7687 + - 7474:7474 + + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.9' + + - name: Install Poetry + run: | + curl -sSL https://install.python-poetry.org | python3 - + + - name: Configure Poetry + run: | + echo "$HOME/.local/bin" >> $GITHUB_PATH + poetry config virtualenvs.create false + + - name: Install dependencies + run: poetry install + + - name: Run tests + run: poetry run pytest ./tests/e2e diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index eb1606865..f93110e38 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -28,7 +28,7 @@ jobs: run: | poetry run ruff format --check . poetry run ruff check . - - name: Run tests and check coverage + - name: Run unit tests and check coverage run: | - poetry run coverage run -m pytest + poetry run coverage run -m pytest tests/unit poetry run coverage report --fail-under=90 diff --git a/.gitignore b/.gitignore index e9cf65c65..16c7907ed 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ dist/ htmlcov/ .idea/ .env +docs/build/ diff --git a/README.md b/README.md index 8272e13ca..e35c41947 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ create_vector_index( ### Populating the Neo4j Vector Index -This library does not write to the database, that is up to you. +This library does not write to the database, that is up to you. See below for how to write using Cypher via the Neo4j driver. Assumption: Neo4j running with a defined vector index @@ -161,7 +161,7 @@ Open a new virtual environment and then run the tests. ```bash poetry shell -pytest +pytest/unit ``` ## Further information diff --git a/examples/hybrid_cypher_search.py b/examples/hybrid_cypher_search.py index a121b2eb4..9bf8ca231 100644 --- a/examples/hybrid_cypher_search.py +++ b/examples/hybrid_cypher_search.py @@ -58,5 +58,5 @@ def embed_query(self, text: str) -> list[float]: driver.execute_query(insert_query, parameters) # Perform the similarity search for a text query -query_text = "Who are the fremen?" +query_text = "Find me a book about Fremen" print(retriever.search(query_text=query_text, top_k=5)) diff --git a/examples/hybrid_search.py b/examples/hybrid_search.py index 704a3841a..f035c28c2 100644 --- a/examples/hybrid_search.py +++ b/examples/hybrid_search.py @@ -55,5 +55,5 @@ def embed_query(self, text: str) -> list[float]: driver.execute_query(insert_query, parameters) # Perform the similarity search for a text query -query_text = "Who are the fremen?" +query_text = "Find me a book about Fremen" print(retriever.search(query_text=query_text, top_k=5)) diff --git a/examples/openai_search.py b/examples/openai_search.py index 87864bdd9..f5e6760d8 100644 --- a/examples/openai_search.py +++ b/examples/openai_search.py @@ -48,5 +48,5 @@ driver.execute_query(insert_query, parameters) # Perform the similarity search for a text query -query_text = "hello world" +query_text = "Find me a book about Fremen" print(retriever.search(query_text=query_text, top_k=5)) diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index e282b7d57..203c2965e 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -51,5 +51,5 @@ def embed_query(self, text: str) -> list[float]: driver.execute_query(insert_query, parameters) # Perform the similarity search for a text query -query_text = "hello world" +query_text = "Find me a book about Fremen" print(retriever.search(query_text=query_text, top_k=5)) diff --git a/examples/vector_cypher_retrieval.py b/examples/vector_cypher_retrieval.py index 8ca829e2c..63a818607 100644 --- a/examples/vector_cypher_retrieval.py +++ b/examples/vector_cypher_retrieval.py @@ -63,5 +63,5 @@ def random_str(n: int) -> str: driver.execute_query(insert_query, parameters) # Perform the search -query_text = "Find me the closest text" +query_text = "Find me a book about Fremen" print(retriever.search(query_text=query_text, top_k=1)) diff --git a/src/neo4j_genai/indexes.py b/src/neo4j_genai/indexes.py index b09901e08..132cc1448 100644 --- a/src/neo4j_genai/indexes.py +++ b/src/neo4j_genai/indexes.py @@ -115,7 +115,7 @@ def drop_index(driver: Driver, name: str) -> None: driver (Driver): Neo4j Python driver instance. name (str): The name of the index to delete. """ - query = "DROP INDEX $name" + query = "DROP INDEX $name IF EXISTS" parameters = { "name": name, } diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index 752677d5f..86dec75d7 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -25,6 +25,7 @@ def get_search_query( query_map = { SearchType.VECTOR: ( "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + "YIELD node, score " ), SearchType.HYBRID: ( "CALL { " @@ -48,9 +49,8 @@ def get_search_query( additional_query += retrieval_query elif return_properties: return_properties_cypher = ", ".join([f".{prop}" for prop in return_properties]) - additional_query += "YIELD node, score " additional_query += f"RETURN node {{{return_properties_cypher}}} as node, score" - else: + elif search_type == SearchType.HYBRID: additional_query += "RETURN node, score" return base_query + additional_query diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py new file mode 100644 index 000000000..96505f938 --- /dev/null +++ b/tests/e2e/conftest.py @@ -0,0 +1,89 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import string +import random + +import pytest +from neo4j import GraphDatabase +from neo4j_genai.embedder import Embedder +from neo4j_genai.indexes import drop_index, create_vector_index, create_fulltext_index + + +@pytest.fixture(scope="module") +def driver(): + uri = "neo4j://localhost:7687" + auth = ("neo4j", "password") + driver = GraphDatabase.driver(uri, auth=auth) + yield driver + driver.close() + + +@pytest.fixture(scope="module") +def custom_embedder(): + class CustomEmbedder(Embedder): + def embed_query(self, text: str) -> list[float]: + return [random.random() for _ in range(1536)] + + return CustomEmbedder() + + +@pytest.fixture(scope="module") +def setup_neo4j(driver): + vector_index_name = "vector-index-name" + fulltext_index_name = "fulltext-index-name" + + # Delete data and drop indexes to prevent data leakage + driver.execute_query("MATCH (n) DETACH DELETE n") + drop_index(driver, vector_index_name) + drop_index(driver, fulltext_index_name) + + # Create a vector index + create_vector_index( + driver, + vector_index_name, + label="Document", + property="propertyKey", + dimensions=1536, + similarity_fn="euclidean", + ) + + # Create a fulltext index + create_fulltext_index( + driver, fulltext_index_name, label="Document", node_properties=["propertyKey"] + ) + + # Insert 10 vectors and authors + vector = [random.random() for _ in range(1536)] + + def random_str(n: int) -> str: + return "".join([random.choice(string.ascii_letters) for _ in range(n)]) + + for i in range(10): + insert_query = ( + "MERGE (doc:Document {id: $id})" + "WITH doc " + "CALL db.create.setNodeVectorProperty(doc, 'propertyKey', $vector)" + "WITH doc " + "MERGE (author:Author {name: $authorName})" + "MERGE (doc)-[:AUTHORED_BY]->(author)" + "RETURN doc, author" + ) + + parameters = { + "id": random.randint(0, 10000), + "vector": vector, + "authorName": random_str(10), + } + driver.execute_query(insert_query, parameters) diff --git a/tests/e2e/test_hybrid_e2e.py b/tests/e2e/test_hybrid_e2e.py new file mode 100644 index 000000000..431f82761 --- /dev/null +++ b/tests/e2e/test_hybrid_e2e.py @@ -0,0 +1,109 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j import Record + +from neo4j_genai import ( + HybridRetriever, + HybridCypherRetriever, +) + + +@pytest.mark.usefixtures("setup_neo4j") +def test_hybrid_retriever_search_text(driver, custom_embedder): + retriever = HybridRetriever( + driver, "vector-index-name", "fulltext-index-name", custom_embedder + ) + + top_k = 5 + results = retriever.search(query_text="Find me a book about Fremen", top_k=top_k) + + assert isinstance(results, list) + assert len(results) == 5 + for result in results: + assert isinstance(result, Record) + + +@pytest.mark.usefixtures("setup_neo4j") +def test_hybrid_cypher_retriever_search_text(driver, custom_embedder): + retrieval_query = ( + "MATCH (node)-[:AUTHORED_BY]->(author:Author) " "RETURN author.name" + ) + retriever = HybridCypherRetriever( + driver, + "vector-index-name", + "fulltext-index-name", + retrieval_query, + custom_embedder, + ) + + top_k = 5 + results = retriever.search(query_text="Find me a book about Fremen", top_k=top_k) + + assert isinstance(results, list) + assert len(results) == 5 + for record in results: + assert isinstance(record, Record) + assert "author.name" in record.keys() + + +@pytest.mark.usefixtures("setup_neo4j") +def test_hybrid_retriever_search_vector(driver): + retriever = HybridRetriever( + driver, + "vector-index-name", + "fulltext-index-name", + ) + + top_k = 5 + results = retriever.search( + query_text="Find me a book about Fremen", + query_vector=[1.0 for _ in range(1536)], + top_k=top_k, + ) + + assert isinstance(results, list) + assert len(results) == 5 + for result in results: + assert isinstance(result, Record) + + +@pytest.mark.usefixtures("setup_neo4j") +def test_hybrid_cypher_retriever_search_vector(driver): + retrieval_query = ( + "MATCH (node)-[:AUTHORED_BY]->(author:Author) " "RETURN author.name" + ) + retriever = HybridCypherRetriever( + driver, + "vector-index-name", + "fulltext-index-name", + retrieval_query, + ) + + top_k = 5 + results = retriever.search( + query_text="Find me a book about Fremen", + query_vector=[1.0 for _ in range(1536)], + top_k=top_k, + ) + + assert isinstance(results, list) + assert len(results) == 5 + for record in results: + assert isinstance(record, Record) + assert "author.name" in record.keys() diff --git a/tests/e2e/test_vector_e2e.py b/tests/e2e/test_vector_e2e.py new file mode 100644 index 000000000..1f8e2abb7 --- /dev/null +++ b/tests/e2e/test_vector_e2e.py @@ -0,0 +1,83 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from neo4j import Record + +from neo4j_genai import VectorRetriever, VectorCypherRetriever +from neo4j_genai.types import VectorSearchRecord + + +@pytest.mark.usefixtures("setup_neo4j") +def test_vector_retriever_search_text(driver, custom_embedder): + retriever = VectorRetriever(driver, "vector-index-name", custom_embedder) + + top_k = 5 + results = retriever.search(query_text="Find me a book about Fremen", top_k=top_k) + + assert isinstance(results, list) + assert len(results) == 5 + for result in results: + assert isinstance(result, VectorSearchRecord) + + +@pytest.mark.usefixtures("setup_neo4j") +def test_vector_cypher_retriever_search_text(driver, custom_embedder): + retrieval_query = ( + "MATCH (node)-[:AUTHORED_BY]->(author:Author) " "RETURN author.name" + ) + retriever = VectorCypherRetriever( + driver, "vector-index-name", retrieval_query, custom_embedder + ) + + top_k = 5 + results = retriever.search(query_text="Find me a book about Fremen", top_k=top_k) + + assert isinstance(results, list) + assert len(results) == 5 + for record in results: + assert isinstance(record, Record) + assert "author.name" in record.keys() + + +@pytest.mark.usefixtures("setup_neo4j") +def test_vector_retriever_search_vector(driver): + retriever = VectorRetriever(driver, "vector-index-name") + + top_k = 5 + results = retriever.search(query_vector=[1.0 for _ in range(1536)], top_k=top_k) + + assert isinstance(results, list) + assert len(results) == 5 + for result in results: + assert isinstance(result, VectorSearchRecord) + + +@pytest.mark.usefixtures("setup_neo4j") +def test_vector_cypher_retriever_search_vector(driver): + retrieval_query = ( + "MATCH (node)-[:AUTHORED_BY]->(author:Author) " "RETURN author.name" + ) + retriever = VectorCypherRetriever(driver, "vector-index-name", retrieval_query) + + top_k = 5 + results = retriever.search(query_vector=[1.0 for _ in range(1536)], top_k=top_k) + + assert isinstance(results, list) + assert len(results) == 5 + for record in results: + assert isinstance(record, Record) + assert "author.name" in record.keys() diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 000000000..c0199c144 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/conftest.py b/tests/unit/conftest.py similarity index 91% rename from tests/conftest.py rename to tests/unit/conftest.py index b0359ec0b..b22e58fc6 100644 --- a/tests/conftest.py +++ b/tests/unit/conftest.py @@ -19,18 +19,18 @@ from unittest.mock import MagicMock, patch -@pytest.fixture +@pytest.fixture(scope="function") def driver(): return MagicMock(spec=Driver) -@pytest.fixture +@pytest.fixture(scope="function") @patch("neo4j_genai.VectorRetriever._verify_version") def vector_retriever(_verify_version_mock, driver): return VectorRetriever(driver, "my-index") -@pytest.fixture +@pytest.fixture(scope="function") @patch("neo4j_genai.VectorCypherRetriever._verify_version") def vector_cypher_retriever(_verify_version_mock, driver): retrieval_query = """ @@ -39,7 +39,7 @@ def vector_cypher_retriever(_verify_version_mock, driver): return VectorCypherRetriever(driver, "my-index", retrieval_query) -@pytest.fixture +@pytest.fixture(scope="function") @patch("neo4j_genai.HybridRetriever._verify_version") def hybrid_retriever(_verify_version_mock, driver): return HybridRetriever(driver, "my-index", "my-fulltext-index") diff --git a/tests/unit/retrievers/__init__.py b/tests/unit/retrievers/__init__.py new file mode 100644 index 000000000..c0199c144 --- /dev/null +++ b/tests/unit/retrievers/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/retrievers/test_base.py b/tests/unit/retrievers/test_base.py similarity index 100% rename from tests/retrievers/test_base.py rename to tests/unit/retrievers/test_base.py diff --git a/tests/retrievers/test_hybrid.py b/tests/unit/retrievers/test_hybrid.py similarity index 100% rename from tests/retrievers/test_hybrid.py rename to tests/unit/retrievers/test_hybrid.py diff --git a/tests/retrievers/test_vector.py b/tests/unit/retrievers/test_vector.py similarity index 100% rename from tests/retrievers/test_vector.py rename to tests/unit/retrievers/test_vector.py diff --git a/tests/test_indexes.py b/tests/unit/test_indexes.py similarity index 98% rename from tests/test_indexes.py rename to tests/unit/test_indexes.py index c624607d5..841226845 100644 --- a/tests/test_indexes.py +++ b/tests/unit/test_indexes.py @@ -75,7 +75,7 @@ def test_create_vector_index_validation_error_similarity_fn(driver): def test_drop_index(driver): - drop_query = "DROP INDEX $name" + drop_query = "DROP INDEX $name IF EXISTS" drop_index(driver, "my-index") diff --git a/tests/test_neo4j_queries.py b/tests/unit/test_neo4j_queries.py similarity index 92% rename from tests/test_neo4j_queries.py rename to tests/unit/test_neo4j_queries.py index 555388ae2..8a6d3fbf4 100644 --- a/tests/test_neo4j_queries.py +++ b/tests/unit/test_neo4j_queries.py @@ -20,10 +20,10 @@ def test_vector_search_basic(): expected = ( "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " - "RETURN node, score" + "YIELD node, score" ) result = get_search_query(SearchType.VECTOR) - assert result == expected + assert result.strip() == expected.strip() def test_hybrid_search_basic(): @@ -42,7 +42,7 @@ def test_hybrid_search_basic(): "RETURN node, score" ) result = get_search_query(SearchType.HYBRID) - assert result == expected + assert result.strip() == expected.strip() def test_vector_search_with_properties(): @@ -53,7 +53,7 @@ def test_vector_search_with_properties(): "RETURN node {.name, .age} as node, score" ) result = get_search_query(SearchType.VECTOR, return_properties=properties) - assert result == expected + assert result.strip() == expected.strip() def test_hybrid_search_with_retrieval_query(): @@ -73,4 +73,4 @@ def test_hybrid_search_with_retrieval_query(): + retrieval_query ) result = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) - assert result == expected + assert result.strip() == expected.strip()