diff --git a/src/neo4j_genai/retrievers.py b/src/neo4j_genai/retrievers.py index 7226b4d8c..ef1ab84db 100644 --- a/src/neo4j_genai/retrievers.py +++ b/src/neo4j_genai/retrievers.py @@ -13,10 +13,12 @@ class VectorRetriever: def __init__( self, driver: Driver, + index_name: str, embedder: Optional[Embedder] = None, ) -> None: self.driver = driver self._verify_version() + self.index_name = index_name self.embedder = embedder def _verify_version(self) -> None: @@ -48,7 +50,6 @@ def _verify_version(self) -> None: def search( self, - name: str, query_vector: Optional[List[float]] = None, query_text: Optional[str] = None, top_k: int = 5, @@ -74,7 +75,7 @@ def search( """ try: validated_data = SimilaritySearchModel( - index_name=name, + index_name=self.index_name, top_k=top_k, query_vector=query_vector, query_text=query_text, diff --git a/tests/conftest.py b/tests/conftest.py index bc181925c..d05db4bde 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,4 +12,4 @@ def driver(): @pytest.fixture @patch("neo4j_genai.VectorRetriever._verify_version") def retriever(_verify_version_mock, driver): - return VectorRetriever(driver) + return VectorRetriever(driver, "my-index") diff --git a/tests/test_retrievers.py b/tests/test_retrievers.py index a0d8ada71..cb8014e7b 100644 --- a/tests/test_retrievers.py +++ b/tests/test_retrievers.py @@ -7,14 +7,14 @@ def test_vector_retriever_supported_aura_version(driver): driver.execute_query.return_value = [[{"versions": ["5.18-aura"]}], None, None] - VectorRetriever(driver=driver) + VectorRetriever(driver=driver, index_name="my-index") def test_vector_retriever_no_supported_aura_version(driver): driver.execute_query.return_value = [[{"versions": ["5.3-aura"]}], None, None] with pytest.raises(ValueError) as excinfo: - VectorRetriever(driver=driver) + VectorRetriever(driver=driver, index_name="my-index") assert "This package only supports Neo4j version 5.18.1 or greater" in str(excinfo) @@ -22,14 +22,14 @@ def test_vector_retriever_no_supported_aura_version(driver): def test_vector_retriever_supported_version(driver): driver.execute_query.return_value = [[{"versions": ["5.19.0"]}], None, None] - VectorRetriever(driver=driver) + VectorRetriever(driver=driver, index_name="my-index") def test_vector_retriever_no_supported_version(driver): driver.execute_query.return_value = [[{"versions": ["4.3.5"]}], None, None] with pytest.raises(ValueError) as excinfo: - VectorRetriever(driver=driver) + VectorRetriever(driver=driver, index_name="my-index") assert "This package only supports Neo4j version 5.18.1 or greater" in str(excinfo) @@ -38,13 +38,13 @@ def test_vector_retriever_no_supported_version(driver): def test_similarity_search_vector_happy_path(_verify_version_mock, driver): custom_embeddings = MagicMock() - retriever = VectorRetriever(driver, custom_embeddings) - index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] top_k = 5 + retriever = VectorRetriever(driver, index_name, custom_embeddings) + retriever.driver.execute_query.return_value = [ [{"node": "dummy-node", "score": 1.0}], None, @@ -55,7 +55,7 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver): YIELD node, score """ - records = retriever.search(name=index_name, query_vector=query_vector, top_k=top_k) + records = retriever.search(query_vector=query_vector, top_k=top_k) custom_embeddings.embed_query.assert_not_called() @@ -77,12 +77,12 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector - retriever = VectorRetriever(driver, custom_embeddings) - index_name = "my-index" query_text = "may thy knife chip and shatter" top_k = 5 + retriever = VectorRetriever(driver, index_name, custom_embeddings) + driver.execute_query.return_value = [ [{"node": "dummy-node", "score": 1.0}], None, @@ -94,7 +94,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): YIELD node, score """ - records = retriever.search(name=index_name, query_text=query_text, top_k=top_k) + records = retriever.search(query_text=query_text, top_k=top_k) custom_embeddings.embed_query.assert_called_once_with(query_text) @@ -111,16 +111,14 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): def test_similarity_search_missing_embedder_for_text(retriever): - index_name = "my-index" query_text = "may thy knife chip and shatter" top_k = 5 with pytest.raises(ValueError, match="Embedding method required for text query"): - retriever.search(name=index_name, query_text=query_text, top_k=top_k) + retriever.search(query_text=query_text, top_k=top_k) def test_similarity_search_both_text_and_vector(retriever): - index_name = "my-index" query_text = "may thy knife chip and shatter" query_vector = [1.1, 2.2, 3.3] top_k = 5 @@ -129,7 +127,6 @@ def test_similarity_search_both_text_and_vector(retriever): ValueError, match="You must provide exactly one of query_vector or query_text." ): retriever.search( - name=index_name, query_text=query_text, query_vector=query_vector, top_k=top_k, @@ -140,13 +137,13 @@ def test_similarity_search_both_text_and_vector(retriever): def test_similarity_search_vector_bad_results(_verify_version_mock, driver): custom_embeddings = MagicMock() - retriever = VectorRetriever(driver, custom_embeddings) - index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] top_k = 5 + retriever = VectorRetriever(driver, index_name, custom_embeddings) + retriever.driver.execute_query.return_value = [ [{"node": "dummy-node", "score": "adsa"}], None, @@ -158,7 +155,7 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver): """ with pytest.raises(ValueError): - retriever.search(name=index_name, query_vector=query_vector, top_k=top_k) + retriever.search(query_vector=query_vector, top_k=top_k) custom_embeddings.embed_query.assert_not_called()