Skip to content

Commit

Permalink
Renamed GraphRetriever to CypherAugmentedVectorRetriever
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai committed Apr 18, 2024
1 parent 943975f commit 8cc4415
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 18 deletions.
4 changes: 2 additions & 2 deletions src/neo4j_genai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .retrievers import VectorRetriever, GraphRetriever
from .retrievers import VectorRetriever, CypherAugmentedVectorRetriever


__all__ = ["VectorRetriever", "GraphRetriever"]
__all__ = ["VectorRetriever", "CypherAugmentedVectorRetriever"]
2 changes: 1 addition & 1 deletion src/neo4j_genai/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def search(
)


class GraphRetriever(VectorRetriever):
class CypherAugmentedVectorRetriever(VectorRetriever):
"""
Provides retrieval method using vector similarity and custom Cypher query
"""
Expand Down
8 changes: 4 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from neo4j_genai import VectorRetriever, GraphRetriever
from neo4j_genai import VectorRetriever, CypherAugmentedVectorRetriever
from neo4j import Driver
from unittest.mock import MagicMock, patch

Expand All @@ -16,9 +16,9 @@ def vector_retriever(_verify_version_mock, driver):


@pytest.fixture
@patch("neo4j_genai.GraphRetriever._verify_version")
def graph_retriever(_verify_version_mock, driver):
@patch("neo4j_genai.CypherAugmentedVectorRetriever._verify_version")
def cyphaug_vector_retriever(_verify_version_mock, driver):
custom_retrieval_query = """
RETURN node.id AS node_id, node.text AS text, score
"""
return GraphRetriever(driver, "my-index", custom_retrieval_query)
return CypherAugmentedVectorRetriever(driver, "my-index", custom_retrieval_query)
24 changes: 13 additions & 11 deletions tests/test_retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from neo4j.exceptions import CypherSyntaxError

from neo4j_genai import VectorRetriever
from neo4j_genai.retrievers import GraphRetriever
from neo4j_genai.retrievers import CypherAugmentedVectorRetriever
from neo4j_genai.types import Neo4jRecord


Expand Down Expand Up @@ -137,23 +137,25 @@ def test_vector_retriever_search_both_text_and_vector(vector_retriever):
)


def test_graph_retriever_search_missing_embedder_for_text(graph_retriever):
def test_cyphaug_vector_retriever_search_missing_embedder_for_text(
cyphaug_vector_retriever,
):
query_text = "may thy knife chip and shatter"
top_k = 5

with pytest.raises(ValueError, match="Embedding method required for text query"):
graph_retriever.search(query_text=query_text, top_k=top_k)
cyphaug_vector_retriever.search(query_text=query_text, top_k=top_k)


def test_graph_retriever_search_both_text_and_vector(graph_retriever):
def test_cyphaug_vector_retriever_search_both_text_and_vector(cyphaug_vector_retriever):
query_text = "may thy knife chip and shatter"
query_vector = [1.1, 2.2, 3.3]
top_k = 5

with pytest.raises(
ValueError, match="You must provide exactly one of query_vector or query_text."
):
graph_retriever.search(
cyphaug_vector_retriever.search(
query_text=query_text,
query_vector=query_vector,
top_k=top_k,
Expand Down Expand Up @@ -196,7 +198,7 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver):
)


@patch("neo4j_genai.GraphRetriever._verify_version")
@patch("neo4j_genai.CypherAugmentedVectorRetriever._verify_version")
def test_custom_retrieval_query_happy_path(_verify_version_mock, driver):
embed_query_vector = [1.0 for _ in range(1536)]
custom_embeddings = MagicMock()
Expand All @@ -205,7 +207,7 @@ def test_custom_retrieval_query_happy_path(_verify_version_mock, driver):
custom_retrieval_query = """
RETURN node.id AS node_id, node.text AS text, score
"""
retriever = GraphRetriever(
retriever = CypherAugmentedVectorRetriever(
driver, index_name, custom_retrieval_query, embedder=custom_embeddings
)
query_text = "may thy knife chip and shatter"
Expand Down Expand Up @@ -237,7 +239,7 @@ def test_custom_retrieval_query_happy_path(_verify_version_mock, driver):
assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}]


@patch("neo4j_genai.GraphRetriever._verify_version")
@patch("neo4j_genai.CypherAugmentedVectorRetriever._verify_version")
def test_custom_retrieval_query_with_params(_verify_version_mock, driver):
embed_query_vector = [1.0 for _ in range(1536)]
custom_embeddings = MagicMock()
Expand All @@ -249,7 +251,7 @@ def test_custom_retrieval_query_with_params(_verify_version_mock, driver):
custom_params = {
"param": "dummy-param",
}
retriever = GraphRetriever(
retriever = CypherAugmentedVectorRetriever(
driver,
index_name,
custom_retrieval_query,
Expand Down Expand Up @@ -288,7 +290,7 @@ def test_custom_retrieval_query_with_params(_verify_version_mock, driver):
assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}]


@patch("neo4j_genai.GraphRetriever._verify_version")
@patch("neo4j_genai.CypherAugmentedVectorRetriever._verify_version")
def test_custom_retrieval_query_cypher_error(_verify_version_mock, driver):
embed_query_vector = [1.0 for _ in range(1536)]
custom_embeddings = MagicMock()
Expand All @@ -297,7 +299,7 @@ def test_custom_retrieval_query_cypher_error(_verify_version_mock, driver):
custom_retrieval_query = """
this is not a cypher query
"""
retriever = GraphRetriever(
retriever = CypherAugmentedVectorRetriever(
driver, index_name, custom_retrieval_query, embedder=custom_embeddings
)
query_text = "may thy knife chip and shatter"
Expand Down

0 comments on commit 8cc4415

Please sign in to comment.