Skip to content

Commit

Permalink
Option to return only specific properties from VectorRetriever.search…
Browse files Browse the repository at this point in the history
… results (#13)

Add `return_properties` argument to the VectorRetriever constructor

This makes it possible to avoid returning the full nodes on a plain vector search, for faster and leaner retrieval.
  • Loading branch information
oskarhane authored Apr 22, 2024
1 parent de74a0f commit b1f63a8
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/neo4j_genai/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@ def __init__(
driver: Driver,
index_name: str,
embedder: Optional[Embedder] = None,
return_properties: Optional[list[str]] = None,
) -> None:
self.driver = driver
self._verify_version()
self.index_name = index_name
self.return_properties = return_properties
self.embedder = embedder

def _verify_version(self) -> None:
Expand Down Expand Up @@ -111,6 +113,15 @@ def search(
CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
YIELD node, score
"""

if self.return_properties:
return_properties_cypher = ", ".join(
[f".{prop}" for prop in self.return_properties]
)
db_query_string += (
f"RETURN node {{{return_properties_cypher}}} as node, score"
)

records, _, _ = self.driver.execute_query(db_query_string, parameters)

try:
Expand Down
43 changes: 43 additions & 0 deletions tests/test_retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,49 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver):
assert records == [Neo4jRecord(node="dummy-node", score=1.0)]


# Checks that a text search on a VectorRetriever built with
# `return_properties` sends a query whose RETURN clause lists only those
# properties, and that the driver's rows come back as Neo4jRecord objects.
# `_verify_version` is patched so constructing the retriever does not run
# the version check that `__init__` calls.
# NOTE(review): `driver` is presumably a pytest fixture providing a mock
# driver -- confirm against the test suite's fixtures.
@patch("neo4j_genai.VectorRetriever._verify_version")
def test_similarity_search_text_return_properties(_verify_version_mock, driver):
# Stub embedder: embed_query always returns this fixed 3-element vector.
embed_query_vector = [1.0 for _ in range(3)]
custom_embeddings = MagicMock()
custom_embeddings.embed_query.return_value = embed_query_vector

index_name = "my-index"
query_text = "may thy knife chip and shatter"
top_k = 5
# Property names expected to appear in the generated RETURN clause.
return_properties = ["node-property-1", "node-property-2"]

retriever = VectorRetriever(
driver, index_name, custom_embeddings, return_properties=return_properties
)

# Three-element result, matching the `records, _, _ = ...` unpacking that
# search() applies to execute_query's return value.
driver.execute_query.return_value = [
[{"node": "dummy-node", "score": 1.0}],
None,
None,
]

# Expected query text. It is compared for equality below, so whitespace
# inside this literal is significant.
search_query = """
CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
YIELD node, score
RETURN node {.node-property-1, .node-property-2} as node, score
"""

records = retriever.search(query_text=query_text, top_k=top_k)

# The query text must have been embedded exactly once.
custom_embeddings.embed_query.assert_called_once_with(query_text)

# rstrip() drops the trailing newline/whitespace that the closing triple
# quote leaves at the end of the expected literal.
driver.execute_query.assert_called_once_with(
search_query.rstrip(),
{
"index_name": index_name,
"top_k": top_k,
"query_vector": embed_query_vector,
},
)

# Each row returned by the driver is expected back as a Neo4jRecord.
assert records == [Neo4jRecord(node="dummy-node", score=1.0)]


def test_vector_retriever_search_missing_embedder_for_text(vector_retriever):
query_text = "may thy knife chip and shatter"
top_k = 5
Expand Down

0 comments on commit b1f63a8

Please sign in to comment.