From 7ae97ddc89e045d37d4b5f6b907e9852137902f8 Mon Sep 17 00:00:00 2001 From: Estelle Scifo Date: Mon, 25 Nov 2024 13:33:23 +0100 Subject: [PATCH] Use routing_=READ for all read queries (#217) * Use routing_=neo4j.RoutingControl.READ for all READ queries * Update CHANGELOG --- CHANGELOG.md | 2 ++ .../experimental/components/neo4j_reader.py | 1 + src/neo4j_graphrag/experimental/components/resolver.py | 3 ++- src/neo4j_graphrag/retrievers/base.py | 9 +++++++-- .../retrievers/external/pinecone/pinecone.py | 5 ++++- .../retrievers/external/qdrant/qdrant.py | 5 ++++- .../retrievers/external/weaviate/weaviate.py | 5 ++++- src/neo4j_graphrag/retrievers/hybrid.py | 10 ++++++++-- src/neo4j_graphrag/retrievers/text2cypher.py | 4 +++- src/neo4j_graphrag/retrievers/vector.py | 10 ++++++++-- .../unit/experimental/components/test_neo4j_reader.py | 3 +++ tests/unit/retrievers/external/test_pinecone.py | 3 +++ tests/unit/retrievers/external/test_qdrant.py | 3 +++ tests/unit/retrievers/external/test_weaviate.py | 3 +++ tests/unit/retrievers/test_hybrid.py | 5 +++++ tests/unit/retrievers/test_text2cypher.py | 5 ++++- tests/unit/retrievers/test_vector.py | 6 ++++++ 17 files changed, 70 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c99d49b8..c7977da4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ ### Changed - Updated all examples to use `neo4j_database` parameter instead of an undocumented neo4j driver constructor. +- All `READ` queries are now routed to a reader replica (for clusters). This impacts all retrievers, the `Neo4jChunkReader` and `SinglePropertyExactMatchResolver` components. + ## 1.2.0 diff --git a/src/neo4j_graphrag/experimental/components/neo4j_reader.py b/src/neo4j_graphrag/experimental/components/neo4j_reader.py index 8aee5d1c..1cb6ff4a 100644 --- a/src/neo4j_graphrag/experimental/components/neo4j_reader.py +++ b/src/neo4j_graphrag/experimental/components/neo4j_reader.py @@ -97,6 +97,7 @@ async def run( result, _, _ = self.driver.execute_query( query, database_=self.neo4j_database, + routing_=neo4j.RoutingControl.READ, ) chunks = [] for record in result: diff --git a/src/neo4j_graphrag/experimental/components/resolver.py b/src/neo4j_graphrag/experimental/components/resolver.py index 14216640..f2da0bff 100644 --- a/src/neo4j_graphrag/experimental/components/resolver.py +++ b/src/neo4j_graphrag/experimental/components/resolver.py @@ -94,7 +94,8 @@ async def run(self) -> ResolutionStats: match_query += self.filter_query stat_query = f"{match_query} RETURN count(entity) as c" records, _, _ = self.driver.execute_query( - stat_query, database_=self.neo4j_database + stat_query, + database_=self.neo4j_database, ) number_of_nodes_to_resolve = records[0].get("c") if number_of_nodes_to_resolve == 0: diff --git a/src/neo4j_graphrag/retrievers/base.py b/src/neo4j_graphrag/retrievers/base.py index dcfe0771..55ae06ef 100644 --- a/src/neo4j_graphrag/retrievers/base.py +++ b/src/neo4j_graphrag/retrievers/base.py @@ -89,7 +89,9 @@ def __init__(self, driver: neo4j.Driver, neo4j_database: Optional[str] = None): def _get_version(self) -> tuple[tuple[int, ...], bool]: records, _, _ = self.driver.execute_query( - "CALL dbms.components()", database_=self.neo4j_database + "CALL dbms.components()", + database_=self.neo4j_database, + routing_=neo4j.RoutingControl.READ, ) version = records[0]["versions"][0] # drop everything after the '-' first @@ -145,7 +147,10 @@ def _fetch_index_infos(self, vector_index_name: str) -> None: "options.indexConfig.`vector.dimensions` as dimensions" ) query_result = self.driver.execute_query( - query, {"index_name": vector_index_name}, database_=self.neo4j_database + query, + {"index_name": vector_index_name}, + database_=self.neo4j_database, + routing_=neo4j.RoutingControl.READ, ) try: result = query_result.records[0] diff --git a/src/neo4j_graphrag/retrievers/external/pinecone/pinecone.py b/src/neo4j_graphrag/retrievers/external/pinecone/pinecone.py index 58879112..16b3ea3d 100644 --- a/src/neo4j_graphrag/retrievers/external/pinecone/pinecone.py +++ b/src/neo4j_graphrag/retrievers/external/pinecone/pinecone.py @@ -234,7 +234,10 @@ def get_search_results( logger.debug("Pinecone Store Cypher query: %s", search_query) records, _, _ = self.driver.execute_query( - search_query, parameters, database_=self.neo4j_database + search_query, + parameters, + database_=self.neo4j_database, + routing_=neo4j.RoutingControl.READ, ) return RawSearchResult(records=records) diff --git a/src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py b/src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py index a38b322f..f0e366aa 100644 --- a/src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py +++ b/src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py @@ -228,7 +228,10 @@ def get_search_results( logger.debug("Qdrant Store Cypher query: %s", search_query) records, _, _ = self.driver.execute_query( - search_query, parameters, database_=self.neo4j_database + search_query, + parameters, + database_=self.neo4j_database, + routing_=neo4j.RoutingControl.READ, ) return RawSearchResult(records=records) diff --git a/src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py b/src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py index a6f28b7e..cad21443 100644 --- a/src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py +++ b/src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py @@ -245,7 +245,10 @@ def get_search_results( logger.debug("Weaviate Store Cypher query: %s", search_query) records, _, _ = self.driver.execute_query( - search_query, parameters, database_=self.neo4j_database + search_query, + parameters, + database_=self.neo4j_database, + routing_=neo4j.RoutingControl.READ, ) return RawSearchResult(records=records) diff --git a/src/neo4j_graphrag/retrievers/hybrid.py b/src/neo4j_graphrag/retrievers/hybrid.py index c1b97442..4634b8a0 100644 --- a/src/neo4j_graphrag/retrievers/hybrid.py +++ b/src/neo4j_graphrag/retrievers/hybrid.py @@ -201,7 +201,10 @@ def get_search_results( logger.debug("HybridRetriever Cypher query: %s", search_query) records, _, _ = self.driver.execute_query( - search_query, parameters, database_=self.neo4j_database + search_query, + parameters, + database_=self.neo4j_database, + routing_=neo4j.RoutingControl.READ, ) return RawSearchResult( records=records, @@ -358,7 +361,10 @@ def get_search_results( logger.debug("HybridRetriever Cypher query: %s", search_query) records, _, _ = self.driver.execute_query( - search_query, parameters, database_=self.neo4j_database + search_query, + parameters, + database_=self.neo4j_database, + routing_=neo4j.RoutingControl.READ, ) return RawSearchResult( records=records, diff --git a/src/neo4j_graphrag/retrievers/text2cypher.py b/src/neo4j_graphrag/retrievers/text2cypher.py index 8297f123..039f42f0 100644 --- a/src/neo4j_graphrag/retrievers/text2cypher.py +++ b/src/neo4j_graphrag/retrievers/text2cypher.py @@ -167,7 +167,9 @@ def get_search_results( t2c_query = llm_result.content logger.debug("Text2CypherRetriever Cypher query: %s", t2c_query) records, _, _ = self.driver.execute_query( - query_=t2c_query, database_=self.neo4j_database + query_=t2c_query, + database_=self.neo4j_database, + routing_=neo4j.RoutingControl.READ, ) except CypherSyntaxError as e: raise Text2CypherRetrievalError( diff --git a/src/neo4j_graphrag/retrievers/vector.py b/src/neo4j_graphrag/retrievers/vector.py index cfec8365..2f85576d 100644 --- a/src/neo4j_graphrag/retrievers/vector.py +++ b/src/neo4j_graphrag/retrievers/vector.py @@ -207,7 +207,10 @@ def get_search_results( logger.debug("VectorRetriever Cypher query: %s", search_query) records, _, _ = self.driver.execute_query( - search_query, parameters, database_=self.neo4j_database + search_query, + parameters, + database_=self.neo4j_database, + routing_=neo4j.RoutingControl.READ, ) return RawSearchResult(records=records) @@ -363,7 +366,10 @@ def get_search_results( logger.debug("VectorCypherRetriever Cypher query: %s", search_query) records, _, _ = self.driver.execute_query( - search_query, parameters, database_=self.neo4j_database + search_query, + parameters, + database_=self.neo4j_database, + routing_=neo4j.RoutingControl.READ, ) return RawSearchResult( records=records, diff --git a/tests/unit/experimental/components/test_neo4j_reader.py b/tests/unit/experimental/components/test_neo4j_reader.py index fec8adde..d80bc9a8 100644 --- a/tests/unit/experimental/components/test_neo4j_reader.py +++ b/tests/unit/experimental/components/test_neo4j_reader.py @@ -35,6 +35,7 @@ async def test_neo4j_chunk_reader(driver: Mock) -> None: driver.execute_query.assert_called_once_with( "MATCH (c:`Chunk`) RETURN c { .*, embedding: null } as chunk ORDER BY c.index", database_="mydb", + routing_=neo4j.RoutingControl.READ, ) assert isinstance(res, TextChunks) @@ -75,6 +76,7 @@ async def test_neo4j_chunk_reader_custom_lg_config(driver: Mock) -> None: driver.execute_query.assert_called_once_with( "MATCH (c:`Page`) RETURN c { .*, embedding: null } as chunk ORDER BY c.k", database_=None, + routing_=neo4j.RoutingControl.READ, ) assert isinstance(res, TextChunks) @@ -110,6 +112,7 @@ async def test_neo4j_chunk_reader_fetch_embedding(driver: Mock) -> None: driver.execute_query.assert_called_once_with( "MATCH (c:`Chunk`) RETURN c { .* } as chunk ORDER BY c.index", database_=None, + routing_=neo4j.RoutingControl.READ, ) assert isinstance(res, TextChunks) diff --git a/tests/unit/retrievers/external/test_pinecone.py b/tests/unit/retrievers/external/test_pinecone.py index dd6def4f..9c09b0fb 100644 --- a/tests/unit/retrievers/external/test_pinecone.py +++ b/tests/unit/retrievers/external/test_pinecone.py @@ -99,6 +99,7 @@ def test_pinecone_retriever_search_happy_path( "id_property": "sync_id", }, database_=None, + routing_=neo4j.RoutingControl.READ, ) assert records == RetrieverResult( @@ -171,6 +172,7 @@ def test_pinecone_retriever_search_return_properties( "id_property": "sync_id", }, database_=None, + routing_=neo4j.RoutingControl.READ, ) assert records == RetrieverResult( @@ -230,6 +232,7 @@ def test_pinecone_retriever_search_retrieval_query( "id_property": "sync_id", }, database_=None, + routing_=neo4j.RoutingControl.READ, ) assert records == RetrieverResult( diff --git a/tests/unit/retrievers/external/test_qdrant.py b/tests/unit/retrievers/external/test_qdrant.py index e5b678a6..6b010038 100644 --- a/tests/unit/retrievers/external/test_qdrant.py +++ b/tests/unit/retrievers/external/test_qdrant.py @@ -74,6 +74,7 @@ def test_qdrant_retriever_search_happy_path( "id_property": "sync_id", }, database_=None, + routing_=neo4j.RoutingControl.READ, ) assert records == RetrieverResult( @@ -152,6 +153,7 @@ def test_qdrant_retriever_search_return_properties( "id_property": "sync_id", }, database_=None, + routing_=neo4j.RoutingControl.READ, ) assert records == RetrieverResult( @@ -217,6 +219,7 @@ def test_qdrant_retriever_search_retrieval_query( "id_property": "sync_id", }, database_=None, + routing_=neo4j.RoutingControl.READ, ) assert records == RetrieverResult( diff --git a/tests/unit/retrievers/external/test_weaviate.py b/tests/unit/retrievers/external/test_weaviate.py index e8b9af7f..af3da75c 100644 --- a/tests/unit/retrievers/external/test_weaviate.py +++ b/tests/unit/retrievers/external/test_weaviate.py @@ -80,6 +80,7 @@ def test_text_search_remote_vector_store_happy_path(driver: MagicMock) -> None: "id_property": "sync_id", }, database_=None, + routing_=neo4j.RoutingControl.READ, ) assert records == RetrieverResult( items=[ @@ -146,6 +147,7 @@ def test_text_search_remote_vector_store_return_properties(driver: MagicMock) -> "id_property": "sync_id", }, database_=None, + routing_=neo4j.RoutingControl.READ, ) assert records == RetrieverResult( items=[ @@ -193,6 +195,7 @@ def test_text_search_remote_vector_store_retrieval_query(driver: MagicMock) -> N "id_property": "sync_id", }, database_=None, + routing_=neo4j.RoutingControl.READ, ) assert records == RetrieverResult( diff --git a/tests/unit/retrievers/test_hybrid.py b/tests/unit/retrievers/test_hybrid.py index 1e11b680..c66bdf08 100644 --- a/tests/unit/retrievers/test_hybrid.py +++ b/tests/unit/retrievers/test_hybrid.py @@ -15,6 +15,7 @@ from unittest.mock import MagicMock, patch +import neo4j import pytest from neo4j_graphrag.exceptions import ( EmbeddingRequiredError, @@ -204,6 +205,7 @@ def test_hybrid_search_text_happy_path( "query_vector": embed_query_vector, }, database_=None, + routing_=neo4j.RoutingControl.READ, ) embedder.embed_query.assert_called_once_with(query_text) assert records == RetrieverResult( @@ -262,6 +264,7 @@ def test_hybrid_search_favors_query_vector_over_embedding_vector( "query_vector": query_vector, }, database_=database, + routing_=neo4j.RoutingControl.READ, ) embedder.embed_query.assert_not_called() @@ -344,6 +347,7 @@ def test_hybrid_retriever_return_properties( "query_vector": embed_query_vector, }, database_=None, + routing_=neo4j.RoutingControl.READ, ) assert records == RetrieverResult( items=[ @@ -410,6 +414,7 @@ def test_hybrid_cypher_retrieval_query_with_params( "param": "dummy-param", }, database_=None, + routing_=neo4j.RoutingControl.READ, ) assert records == RetrieverResult( diff --git a/tests/unit/retrievers/test_text2cypher.py b/tests/unit/retrievers/test_text2cypher.py index a23f12f4..4d110c8e 100644 --- a/tests/unit/retrievers/test_text2cypher.py +++ b/tests/unit/retrievers/test_text2cypher.py @@ -14,6 +14,7 @@ # limitations under the License. from unittest.mock import MagicMock, patch +import neo4j import pytest from neo4j.exceptions import CypherSyntaxError, Neo4jError from neo4j_graphrag.exceptions import ( @@ -139,7 +140,9 @@ def test_t2c_retriever_happy_path( retriever.search(query_text=query_text) llm.invoke.assert_called_once_with(prompt) driver.execute_query.assert_called_once_with( - query_=t2c_query, database_=neo4j_database + query_=t2c_query, + database_=neo4j_database, + routing_=neo4j.RoutingControl.READ, ) diff --git a/tests/unit/retrievers/test_vector.py b/tests/unit/retrievers/test_vector.py index 2d074d34..0fe59cca 100644 --- a/tests/unit/retrievers/test_vector.py +++ b/tests/unit/retrievers/test_vector.py @@ -137,6 +137,7 @@ def test_similarity_search_vector_happy_path( "query_vector": query_vector, }, database_=database, + routing_=neo4j.RoutingControl.READ, ) assert records == RetrieverResult( items=[ @@ -182,6 +183,7 @@ def test_similarity_search_text_happy_path( "query_vector": embed_query_vector, }, database_=None, + routing_=neo4j.RoutingControl.READ, ) assert records == RetrieverResult( items=[ @@ -234,6 +236,7 @@ def test_similarity_search_text_return_properties( "query_vector": embed_query_vector, }, database_=None, + routing_=neo4j.RoutingControl.READ, ) assert records == RetrieverResult( items=[ @@ -397,6 +400,7 @@ def test_retrieval_query_happy_path( "query_vector": embed_query_vector, }, database_=database, + routing_=neo4j.RoutingControl.READ, ) assert records == RetrieverResult( items=[ @@ -458,6 +462,7 @@ def test_retrieval_query_with_result_format_function( "query_vector": embed_query_vector, }, database_=None, + routing_=neo4j.RoutingControl.READ, ) assert records == RetrieverResult( items=[ @@ -520,6 +525,7 @@ def test_retrieval_query_with_params( "param": "dummy-param", }, database_=None, + routing_=neo4j.RoutingControl.READ, ) assert records == RetrieverResult(