diff --git a/src/neo4j_genai/__init__.py b/src/neo4j_genai/__init__.py index 5977ff3c..d702b6cb 100644 --- a/src/neo4j_genai/__init__.py +++ b/src/neo4j_genai/__init__.py @@ -1,4 +1,4 @@ -from .retrievers import VectorRetriever, GraphRetriever +from .retrievers import VectorRetriever, CypherAugmentedVectorRetriever -__all__ = ["VectorRetriever", "GraphRetriever"] +__all__ = ["VectorRetriever", "CypherAugmentedVectorRetriever"] diff --git a/src/neo4j_genai/retrievers.py b/src/neo4j_genai/retrievers.py index 3464ca81..b2e9c4ba 100644 --- a/src/neo4j_genai/retrievers.py +++ b/src/neo4j_genai/retrievers.py @@ -110,7 +110,7 @@ def search( ) -class GraphRetriever(VectorRetriever): +class CypherAugmentedVectorRetriever(VectorRetriever): """ Provides retrieval method using vector similarity and custom Cypher query """ diff --git a/tests/conftest.py b/tests/conftest.py index 68ee8ce2..9fa773ef 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ import pytest -from neo4j_genai import VectorRetriever, GraphRetriever +from neo4j_genai import VectorRetriever, CypherAugmentedVectorRetriever from neo4j import Driver from unittest.mock import MagicMock, patch @@ -16,9 +16,9 @@ def vector_retriever(_verify_version_mock, driver): @pytest.fixture -@patch("neo4j_genai.GraphRetriever._verify_version") -def graph_retriever(_verify_version_mock, driver): +@patch("neo4j_genai.CypherAugmentedVectorRetriever._verify_version") +def cyphaug_vector_retriever(_verify_version_mock, driver): custom_retrieval_query = """ RETURN node.id AS node_id, node.text AS text, score """ - return GraphRetriever(driver, "my-index", custom_retrieval_query) + return CypherAugmentedVectorRetriever(driver, "my-index", custom_retrieval_query) diff --git a/tests/test_retrievers.py b/tests/test_retrievers.py index 8c57283e..bbd00f14 100644 --- a/tests/test_retrievers.py +++ b/tests/test_retrievers.py @@ -4,7 +4,7 @@ from neo4j.exceptions import CypherSyntaxError from neo4j_genai import VectorRetriever -from neo4j_genai.retrievers import GraphRetriever +from neo4j_genai.retrievers import CypherAugmentedVectorRetriever from neo4j_genai.types import Neo4jRecord @@ -137,15 +137,17 @@ def test_vector_retriever_search_both_text_and_vector(vector_retriever): ) -def test_graph_retriever_search_missing_embedder_for_text(graph_retriever): +def test_cyphaug_vector_retriever_search_missing_embedder_for_text( + cyphaug_vector_retriever, +): query_text = "may thy knife chip and shatter" top_k = 5 with pytest.raises(ValueError, match="Embedding method required for text query"): - graph_retriever.search(query_text=query_text, top_k=top_k) + cyphaug_vector_retriever.search(query_text=query_text, top_k=top_k) -def test_graph_retriever_search_both_text_and_vector(graph_retriever): +def test_cyphaug_vector_retriever_search_both_text_and_vector(cyphaug_vector_retriever): query_text = "may thy knife chip and shatter" query_vector = [1.1, 2.2, 3.3] top_k = 5 @@ -153,7 +155,7 @@ def test_graph_retriever_search_both_text_and_vector(graph_retriever): with pytest.raises( ValueError, match="You must provide exactly one of query_vector or query_text." ): - graph_retriever.search( + cyphaug_vector_retriever.search( query_text=query_text, query_vector=query_vector, top_k=top_k, @@ -196,7 +198,7 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver): ) -@patch("neo4j_genai.GraphRetriever._verify_version") +@patch("neo4j_genai.CypherAugmentedVectorRetriever._verify_version") def test_custom_retrieval_query_happy_path(_verify_version_mock, driver): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() @@ -205,7 +207,7 @@ def test_custom_retrieval_query_happy_path(_verify_version_mock, driver): custom_retrieval_query = """ RETURN node.id AS node_id, node.text AS text, score """ - retriever = GraphRetriever( + retriever = CypherAugmentedVectorRetriever( driver, index_name, custom_retrieval_query, embedder=custom_embeddings ) query_text = "may thy knife chip and shatter" @@ -237,7 +239,7 @@ def test_custom_retrieval_query_happy_path(_verify_version_mock, driver): assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}] -@patch("neo4j_genai.GraphRetriever._verify_version") +@patch("neo4j_genai.CypherAugmentedVectorRetriever._verify_version") def test_custom_retrieval_query_with_params(_verify_version_mock, driver): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() @@ -249,7 +251,7 @@ def test_custom_retrieval_query_with_params(_verify_version_mock, driver): custom_params = { "param": "dummy-param", } - retriever = GraphRetriever( + retriever = CypherAugmentedVectorRetriever( driver, index_name, custom_retrieval_query, @@ -288,7 +290,7 @@ def test_custom_retrieval_query_with_params(_verify_version_mock, driver): assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}] -@patch("neo4j_genai.GraphRetriever._verify_version") +@patch("neo4j_genai.CypherAugmentedVectorRetriever._verify_version") def test_custom_retrieval_query_cypher_error(_verify_version_mock, driver): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() @@ -297,7 +299,7 @@ def test_custom_retrieval_query_cypher_error(_verify_version_mock, driver): custom_retrieval_query = """ this is not a cypher query """ - retriever = GraphRetriever( + retriever = CypherAugmentedVectorRetriever( driver, index_name, custom_retrieval_query, embedder=custom_embeddings ) query_text = "may thy knife chip and shatter"