Skip to content

Commit

Permalink
Setup E2E Test pipeline and add E2E tests for vector and hybrid retri…
Browse files Browse the repository at this point in the history
…evers (#24)

* Added E2E tests, new GitHub workflow, and separated out unit tests

Setup neo4j db for e2e tests

* Refactor query tail generation to separate function

---------

Co-authored-by: Oskar Hane <[email protected]>
  • Loading branch information
willtai and oskarhane committed May 7, 2024
1 parent 54ebce7 commit bd8fb20
Show file tree
Hide file tree
Showing 22 changed files with 605 additions and 114 deletions.
51 changes: 51 additions & 0 deletions .github/workflows/pr-e2e-tests.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
name: 'Neo4j-GenAI PR E2E Tests'

on:
pull_request:
types: [opened, synchronize, reopened, ready_for_review]

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number }}
cancel-in-progress: true

jobs:
e2e-tests:
runs-on: ubuntu-latest
strategy:
matrix:
neo4j-version:
- 5
neo4j-edition:
- community
- enterprise
services:
neo4j:
image: neo4j:${{ matrix.neo4j-version }}-${{ matrix.neo4j-edition }}
env:
NEO4J_AUTH: neo4j/password
NEO4J_ACCEPT_LICENSE_AGREEMENT: yes
ports:
- 7687:7687
- 7474:7474

steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.9'

- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python3 -
- name: Configure Poetry
run: |
echo "$HOME/.local/bin" >> $GITHUB_PATH
poetry config virtualenvs.create false
- name: Install dependencies
run: poetry install

- name: Run tests
run: poetry run pytest ./tests/e2e
4 changes: 2 additions & 2 deletions .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
run: |
poetry run ruff format --check .
poetry run ruff check .
- name: Run tests and check coverage
- name: Run unit tests and check coverage
run: |
poetry run coverage run -m pytest
poetry run coverage run -m pytest tests/unit
poetry run coverage report --fail-under=90
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ Open a new virtual environment and then run the tests.

```bash
poetry shell
pytest
pytest tests/unit
```

## Further information
Expand Down
2 changes: 1 addition & 1 deletion examples/hybrid_cypher_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,5 @@ def embed_query(self, text: str) -> list[float]:
driver.execute_query(insert_query, parameters)

# Perform the similarity search for a text query
query_text = "Who are the fremen?"
query_text = "Find me a book about Fremen"
print(retriever.search(query_text=query_text, top_k=5))
2 changes: 1 addition & 1 deletion examples/hybrid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,5 @@ def embed_query(self, text: str) -> list[float]:
driver.execute_query(insert_query, parameters)

# Perform the similarity search for a text query
query_text = "Who are the fremen?"
query_text = "Find me a book about Fremen"
print(retriever.search(query_text=query_text, top_k=5))
2 changes: 1 addition & 1 deletion examples/openai_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,5 @@
driver.execute_query(insert_query, parameters)

# Perform the similarity search for a text query
query_text = "hello world"
query_text = "Find me a book about Fremen"
print(retriever.search(query_text=query_text, top_k=5))
2 changes: 1 addition & 1 deletion examples/similarity_search_for_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,5 @@ def embed_query(self, text: str) -> list[float]:
driver.execute_query(insert_query, parameters)

# Perform the similarity search for a text query
query_text = "hello world"
query_text = "Find me a book about Fremen"
print(retriever.search(query_text=query_text, top_k=5))
2 changes: 1 addition & 1 deletion examples/vector_cypher_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,5 @@ def random_str(n: int) -> str:
driver.execute_query(insert_query, parameters)

# Perform the search
query_text = "Find me the closest text"
query_text = "Find me a book about Fremen"
print(retriever.search(query_text=query_text, top_k=1))
2 changes: 1 addition & 1 deletion src/neo4j_genai/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def drop_index(driver: Driver, name: str) -> None:
driver (Driver): Neo4j Python driver instance.
name (str): The name of the index to delete.
"""
query = "DROP INDEX $name"
query = "DROP INDEX $name IF EXISTS"
parameters = {
"name": name,
}
Expand Down
57 changes: 33 additions & 24 deletions src/neo4j_genai/neo4j_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,34 +23,43 @@ def get_search_query(
retrieval_query: Optional[str] = None,
):
query_map = {
SearchType.VECTOR: (
"CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) "
SearchType.VECTOR: "".join(
[
"CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) ",
"YIELD node, score ",
get_query_tail(retrieval_query, return_properties),
]
),
SearchType.HYBRID: (
"CALL { "
"CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) "
"YIELD node, score "
"RETURN node, score UNION "
"CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) "
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / max) AS score "
"} "
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
SearchType.HYBRID: "".join(
[
"CALL { ",
"CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) ",
"YIELD node, score ",
"RETURN node, score UNION ",
"CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) ",
"YIELD node, score ",
"WITH collect({node:node, score:score}) AS nodes, max(score) AS max ",
"UNWIND nodes AS n ",
"RETURN n.node AS node, (n.score / max) AS score ",
"} ",
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k ",
get_query_tail(
retrieval_query, return_properties, "RETURN node, score"
),
]
),
}
return query_map[search_type]

base_query = query_map[search_type]
additional_query = ""

def get_query_tail(
retrieval_query: Optional[str] = None,
return_properties: Optional[list[str]] = None,
fallback_return: Optional[str] = None,
) -> str:
if retrieval_query:
additional_query += retrieval_query
elif return_properties:
return retrieval_query
if return_properties:
return_properties_cypher = ", ".join([f".{prop}" for prop in return_properties])
additional_query += "YIELD node, score "
additional_query += f"RETURN node {{{return_properties_cypher}}} as node, score"
else:
additional_query += "RETURN node, score"

return base_query + additional_query
return f"RETURN node {{{return_properties_cypher}}} as node, score"
return fallback_return if fallback_return else ""
90 changes: 90 additions & 0 deletions tests/e2e/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import string
import random
import uuid

import pytest
from neo4j import GraphDatabase
from neo4j_genai.embedder import Embedder
from neo4j_genai.indexes import drop_index, create_vector_index, create_fulltext_index


@pytest.fixture(scope="module")
def driver():
uri = "neo4j://localhost:7687"
auth = ("neo4j", "password")
driver = GraphDatabase.driver(uri, auth=auth)
yield driver
driver.close()


@pytest.fixture(scope="module")
def custom_embedder():
class CustomEmbedder(Embedder):
def embed_query(self, text: str) -> list[float]:
return [random.random() for _ in range(1536)]

return CustomEmbedder()


@pytest.fixture(scope="module")
def setup_neo4j(driver):
vector_index_name = "vector-index-name"
fulltext_index_name = "fulltext-index-name"

# Delete data and drop indexes to prevent data leakage
driver.execute_query("MATCH (n) DETACH DELETE n")
drop_index(driver, vector_index_name)
drop_index(driver, fulltext_index_name)

# Create a vector index
create_vector_index(
driver,
vector_index_name,
label="Document",
property="propertyKey",
dimensions=1536,
similarity_fn="euclidean",
)

# Create a fulltext index
create_fulltext_index(
driver, fulltext_index_name, label="Document", node_properties=["propertyKey"]
)

# Insert 10 vectors and authors
vector = [random.random() for _ in range(1536)]

def random_str(n: int) -> str:
return "".join([random.choice(string.ascii_letters) for _ in range(n)])

for i in range(10):
insert_query = (
"MERGE (doc:Document {id: $id})"
"WITH doc "
"CALL db.create.setNodeVectorProperty(doc, 'propertyKey', $vector)"
"WITH doc "
"MERGE (author:Author {name: $authorName})"
"MERGE (doc)-[:AUTHORED_BY]->(author)"
"RETURN doc, author"
)

parameters = {
"id": str(uuid.uuid4()),
"vector": vector,
"authorName": random_str(10),
}
driver.execute_query(insert_query, parameters)
Loading

0 comments on commit bd8fb20

Please sign in to comment.