Skip to content

Commit

Permalink
Use routing_=READ for all read queries (#217)
Browse files Browse the repository at this point in the history
* Use routing_=neo4j.RoutingControl.READ for all READ queries

* Update CHANGELOG
  • Loading branch information
stellasia authored Nov 25, 2024
1 parent 99315c4 commit 7ae97dd
Show file tree
Hide file tree
Showing 17 changed files with 70 additions and 12 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

### Changed
- Updated all examples to use the `neo4j_database` parameter instead of an undocumented neo4j driver constructor.
- All `READ` queries are now routed to a reader replica (for clusters). This impacts all retrievers, as well as the `Neo4jChunkReader` and `SinglePropertyExactMatchResolver` components.


## 1.2.0

Expand Down
1 change: 1 addition & 0 deletions src/neo4j_graphrag/experimental/components/neo4j_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ async def run(
result, _, _ = self.driver.execute_query(
query,
database_=self.neo4j_database,
routing_=neo4j.RoutingControl.READ,
)
chunks = []
for record in result:
Expand Down
3 changes: 2 additions & 1 deletion src/neo4j_graphrag/experimental/components/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ async def run(self) -> ResolutionStats:
match_query += self.filter_query
stat_query = f"{match_query} RETURN count(entity) as c"
records, _, _ = self.driver.execute_query(
stat_query, database_=self.neo4j_database
stat_query,
database_=self.neo4j_database,
)
number_of_nodes_to_resolve = records[0].get("c")
if number_of_nodes_to_resolve == 0:
Expand Down
9 changes: 7 additions & 2 deletions src/neo4j_graphrag/retrievers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def __init__(self, driver: neo4j.Driver, neo4j_database: Optional[str] = None):

def _get_version(self) -> tuple[tuple[int, ...], bool]:
records, _, _ = self.driver.execute_query(
"CALL dbms.components()", database_=self.neo4j_database
"CALL dbms.components()",
database_=self.neo4j_database,
routing_=neo4j.RoutingControl.READ,
)
version = records[0]["versions"][0]
# drop everything after the '-' first
Expand Down Expand Up @@ -145,7 +147,10 @@ def _fetch_index_infos(self, vector_index_name: str) -> None:
"options.indexConfig.`vector.dimensions` as dimensions"
)
query_result = self.driver.execute_query(
query, {"index_name": vector_index_name}, database_=self.neo4j_database
query,
{"index_name": vector_index_name},
database_=self.neo4j_database,
routing_=neo4j.RoutingControl.READ,
)
try:
result = query_result.records[0]
Expand Down
5 changes: 4 additions & 1 deletion src/neo4j_graphrag/retrievers/external/pinecone/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,10 @@ def get_search_results(
logger.debug("Pinecone Store Cypher query: %s", search_query)

records, _, _ = self.driver.execute_query(
search_query, parameters, database_=self.neo4j_database
search_query,
parameters,
database_=self.neo4j_database,
routing_=neo4j.RoutingControl.READ,
)

return RawSearchResult(records=records)
5 changes: 4 additions & 1 deletion src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,10 @@ def get_search_results(
logger.debug("Qdrant Store Cypher query: %s", search_query)

records, _, _ = self.driver.execute_query(
search_query, parameters, database_=self.neo4j_database
search_query,
parameters,
database_=self.neo4j_database,
routing_=neo4j.RoutingControl.READ,
)

return RawSearchResult(records=records)
5 changes: 4 additions & 1 deletion src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,10 @@ def get_search_results(
logger.debug("Weaviate Store Cypher query: %s", search_query)

records, _, _ = self.driver.execute_query(
search_query, parameters, database_=self.neo4j_database
search_query,
parameters,
database_=self.neo4j_database,
routing_=neo4j.RoutingControl.READ,
)

return RawSearchResult(records=records)
10 changes: 8 additions & 2 deletions src/neo4j_graphrag/retrievers/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,10 @@ def get_search_results(
logger.debug("HybridRetriever Cypher query: %s", search_query)

records, _, _ = self.driver.execute_query(
search_query, parameters, database_=self.neo4j_database
search_query,
parameters,
database_=self.neo4j_database,
routing_=neo4j.RoutingControl.READ,
)
return RawSearchResult(
records=records,
Expand Down Expand Up @@ -358,7 +361,10 @@ def get_search_results(
logger.debug("HybridRetriever Cypher query: %s", search_query)

records, _, _ = self.driver.execute_query(
search_query, parameters, database_=self.neo4j_database
search_query,
parameters,
database_=self.neo4j_database,
routing_=neo4j.RoutingControl.READ,
)
return RawSearchResult(
records=records,
Expand Down
4 changes: 3 additions & 1 deletion src/neo4j_graphrag/retrievers/text2cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,9 @@ def get_search_results(
t2c_query = llm_result.content
logger.debug("Text2CypherRetriever Cypher query: %s", t2c_query)
records, _, _ = self.driver.execute_query(
query_=t2c_query, database_=self.neo4j_database
query_=t2c_query,
database_=self.neo4j_database,
routing_=neo4j.RoutingControl.READ,
)
except CypherSyntaxError as e:
raise Text2CypherRetrievalError(
Expand Down
10 changes: 8 additions & 2 deletions src/neo4j_graphrag/retrievers/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,10 @@ def get_search_results(
logger.debug("VectorRetriever Cypher query: %s", search_query)

records, _, _ = self.driver.execute_query(
search_query, parameters, database_=self.neo4j_database
search_query,
parameters,
database_=self.neo4j_database,
routing_=neo4j.RoutingControl.READ,
)
return RawSearchResult(records=records)

Expand Down Expand Up @@ -363,7 +366,10 @@ def get_search_results(
logger.debug("VectorCypherRetriever Cypher query: %s", search_query)

records, _, _ = self.driver.execute_query(
search_query, parameters, database_=self.neo4j_database
search_query,
parameters,
database_=self.neo4j_database,
routing_=neo4j.RoutingControl.READ,
)
return RawSearchResult(
records=records,
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/experimental/components/test_neo4j_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ async def test_neo4j_chunk_reader(driver: Mock) -> None:
driver.execute_query.assert_called_once_with(
"MATCH (c:`Chunk`) RETURN c { .*, embedding: null } as chunk ORDER BY c.index",
database_="mydb",
routing_=neo4j.RoutingControl.READ,
)

assert isinstance(res, TextChunks)
Expand Down Expand Up @@ -75,6 +76,7 @@ async def test_neo4j_chunk_reader_custom_lg_config(driver: Mock) -> None:
driver.execute_query.assert_called_once_with(
"MATCH (c:`Page`) RETURN c { .*, embedding: null } as chunk ORDER BY c.k",
database_=None,
routing_=neo4j.RoutingControl.READ,
)

assert isinstance(res, TextChunks)
Expand Down Expand Up @@ -110,6 +112,7 @@ async def test_neo4j_chunk_reader_fetch_embedding(driver: Mock) -> None:
driver.execute_query.assert_called_once_with(
"MATCH (c:`Chunk`) RETURN c { .* } as chunk ORDER BY c.index",
database_=None,
routing_=neo4j.RoutingControl.READ,
)

assert isinstance(res, TextChunks)
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/retrievers/external/test_pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def test_pinecone_retriever_search_happy_path(
"id_property": "sync_id",
},
database_=None,
routing_=neo4j.RoutingControl.READ,
)

assert records == RetrieverResult(
Expand Down Expand Up @@ -171,6 +172,7 @@ def test_pinecone_retriever_search_return_properties(
"id_property": "sync_id",
},
database_=None,
routing_=neo4j.RoutingControl.READ,
)

assert records == RetrieverResult(
Expand Down Expand Up @@ -230,6 +232,7 @@ def test_pinecone_retriever_search_retrieval_query(
"id_property": "sync_id",
},
database_=None,
routing_=neo4j.RoutingControl.READ,
)

assert records == RetrieverResult(
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/retrievers/external/test_qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def test_qdrant_retriever_search_happy_path(
"id_property": "sync_id",
},
database_=None,
routing_=neo4j.RoutingControl.READ,
)

assert records == RetrieverResult(
Expand Down Expand Up @@ -152,6 +153,7 @@ def test_qdrant_retriever_search_return_properties(
"id_property": "sync_id",
},
database_=None,
routing_=neo4j.RoutingControl.READ,
)

assert records == RetrieverResult(
Expand Down Expand Up @@ -217,6 +219,7 @@ def test_qdrant_retriever_search_retrieval_query(
"id_property": "sync_id",
},
database_=None,
routing_=neo4j.RoutingControl.READ,
)

assert records == RetrieverResult(
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/retrievers/external/test_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def test_text_search_remote_vector_store_happy_path(driver: MagicMock) -> None:
"id_property": "sync_id",
},
database_=None,
routing_=neo4j.RoutingControl.READ,
)
assert records == RetrieverResult(
items=[
Expand Down Expand Up @@ -146,6 +147,7 @@ def test_text_search_remote_vector_store_return_properties(driver: MagicMock) ->
"id_property": "sync_id",
},
database_=None,
routing_=neo4j.RoutingControl.READ,
)
assert records == RetrieverResult(
items=[
Expand Down Expand Up @@ -193,6 +195,7 @@ def test_text_search_remote_vector_store_retrieval_query(driver: MagicMock) -> N
"id_property": "sync_id",
},
database_=None,
routing_=neo4j.RoutingControl.READ,
)

assert records == RetrieverResult(
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/retrievers/test_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from unittest.mock import MagicMock, patch

import neo4j
import pytest
from neo4j_graphrag.exceptions import (
EmbeddingRequiredError,
Expand Down Expand Up @@ -204,6 +205,7 @@ def test_hybrid_search_text_happy_path(
"query_vector": embed_query_vector,
},
database_=None,
routing_=neo4j.RoutingControl.READ,
)
embedder.embed_query.assert_called_once_with(query_text)
assert records == RetrieverResult(
Expand Down Expand Up @@ -262,6 +264,7 @@ def test_hybrid_search_favors_query_vector_over_embedding_vector(
"query_vector": query_vector,
},
database_=database,
routing_=neo4j.RoutingControl.READ,
)
embedder.embed_query.assert_not_called()

Expand Down Expand Up @@ -344,6 +347,7 @@ def test_hybrid_retriever_return_properties(
"query_vector": embed_query_vector,
},
database_=None,
routing_=neo4j.RoutingControl.READ,
)
assert records == RetrieverResult(
items=[
Expand Down Expand Up @@ -410,6 +414,7 @@ def test_hybrid_cypher_retrieval_query_with_params(
"param": "dummy-param",
},
database_=None,
routing_=neo4j.RoutingControl.READ,
)

assert records == RetrieverResult(
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/retrievers/test_text2cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
from unittest.mock import MagicMock, patch

import neo4j
import pytest
from neo4j.exceptions import CypherSyntaxError, Neo4jError
from neo4j_graphrag.exceptions import (
Expand Down Expand Up @@ -139,7 +140,9 @@ def test_t2c_retriever_happy_path(
retriever.search(query_text=query_text)
llm.invoke.assert_called_once_with(prompt)
driver.execute_query.assert_called_once_with(
query_=t2c_query, database_=neo4j_database
query_=t2c_query,
database_=neo4j_database,
routing_=neo4j.RoutingControl.READ,
)


Expand Down
6 changes: 6 additions & 0 deletions tests/unit/retrievers/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def test_similarity_search_vector_happy_path(
"query_vector": query_vector,
},
database_=database,
routing_=neo4j.RoutingControl.READ,
)
assert records == RetrieverResult(
items=[
Expand Down Expand Up @@ -182,6 +183,7 @@ def test_similarity_search_text_happy_path(
"query_vector": embed_query_vector,
},
database_=None,
routing_=neo4j.RoutingControl.READ,
)
assert records == RetrieverResult(
items=[
Expand Down Expand Up @@ -234,6 +236,7 @@ def test_similarity_search_text_return_properties(
"query_vector": embed_query_vector,
},
database_=None,
routing_=neo4j.RoutingControl.READ,
)
assert records == RetrieverResult(
items=[
Expand Down Expand Up @@ -397,6 +400,7 @@ def test_retrieval_query_happy_path(
"query_vector": embed_query_vector,
},
database_=database,
routing_=neo4j.RoutingControl.READ,
)
assert records == RetrieverResult(
items=[
Expand Down Expand Up @@ -458,6 +462,7 @@ def test_retrieval_query_with_result_format_function(
"query_vector": embed_query_vector,
},
database_=None,
routing_=neo4j.RoutingControl.READ,
)
assert records == RetrieverResult(
items=[
Expand Down Expand Up @@ -520,6 +525,7 @@ def test_retrieval_query_with_params(
"param": "dummy-param",
},
database_=None,
routing_=neo4j.RoutingControl.READ,
)

assert records == RetrieverResult(
Expand Down

0 comments on commit 7ae97dd

Please sign in to comment.