diff --git a/src/neo4j_genai/retrievers.py b/src/neo4j_genai/retrievers.py index 2653ba66..b1f3becc 100644 --- a/src/neo4j_genai/retrievers.py +++ b/src/neo4j_genai/retrievers.py @@ -30,10 +30,12 @@ def __init__( driver: Driver, index_name: str, embedder: Optional[Embedder] = None, + return_properties: Optional[list[str]] = None, ) -> None: self.driver = driver self._verify_version() self.index_name = index_name + self.return_properties = return_properties self.embedder = embedder def _verify_version(self) -> None: @@ -111,6 +113,15 @@ def search( CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) YIELD node, score """ + + if self.return_properties: + return_properties_cypher = ", ".join( + [f".{prop}" for prop in self.return_properties] + ) + db_query_string += ( + f"RETURN node {{{return_properties_cypher}}} as node, score" + ) + records, _, _ = self.driver.execute_query(db_query_string, parameters) try: diff --git a/tests/test_retrievers.py b/tests/test_retrievers.py index 77ce30e8..d6a6acb3 100644 --- a/tests/test_retrievers.py +++ b/tests/test_retrievers.py @@ -129,6 +129,49 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): assert records == [Neo4jRecord(node="dummy-node", score=1.0)] +@patch("neo4j_genai.VectorRetriever._verify_version") +def test_similarity_search_text_return_properties(_verify_version_mock, driver): + embed_query_vector = [1.0 for _ in range(3)] + custom_embeddings = MagicMock() + custom_embeddings.embed_query.return_value = embed_query_vector + + index_name = "my-index" + query_text = "may thy knife chip and shatter" + top_k = 5 + return_properties = ["node-property-1", "node-property-2"] + + retriever = VectorRetriever( + driver, index_name, custom_embeddings, return_properties=return_properties + ) + + driver.execute_query.return_value = [ + [{"node": "dummy-node", "score": 1.0}], + None, + None, + ] + + search_query = """ + CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) + YIELD node, score + RETURN node {.node-property-1, .node-property-2} as node, score + """ + + records = retriever.search(query_text=query_text, top_k=top_k) + + custom_embeddings.embed_query.assert_called_once_with(query_text) + + driver.execute_query.assert_called_once_with( + search_query.rstrip(), + { + "index_name": index_name, + "top_k": top_k, + "query_vector": embed_query_vector, + }, + ) + + assert records == [Neo4jRecord(node="dummy-node", score=1.0)] + + def test_vector_retriever_search_missing_embedder_for_text(vector_retriever): query_text = "may thy knife chip and shatter" top_k = 5