Skip to content

Commit

Permalink
Adds custom retrieval query to similarity_serach()
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai committed Mar 15, 2024
1 parent 4633cde commit f0dc1d6
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 22 deletions.
1 change: 0 additions & 1 deletion examples/similarity_search_for_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def embed_query(self, text: str) -> List[float]:
# Initialize the client
client = GenAIClient(driver, embedder)

client.drop_index(INDEX_NAME)
# Creating the index
client.create_index(
INDEX_NAME,
Expand Down
32 changes: 16 additions & 16 deletions src/neo4j_genai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from neo4j import Driver
from neo4j.exceptions import CypherSyntaxError
from .embedder import Embedder
from .types import CreateIndexModel, SimilaritySearchModel, Neo4jRecord
from .types import CreateIndexModel, SimilaritySearchModel


class GenAIClient:
Expand Down Expand Up @@ -47,7 +47,7 @@ def _verify_version(self) -> None:

def _database_query(
self, query: str, params: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
) -> Any:
"""
This method sends a Cypher query to the connected Neo4j database
and returns the results as a list of dictionaries.
Expand Down Expand Up @@ -133,7 +133,8 @@ def similarity_search(
query_vector: Optional[List[float]] = None,
query_text: Optional[str] = None,
top_k: int = 5,
) -> List[Neo4jRecord]:
custom_retrieval_query: Optional[str] = None,
) -> Any:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
See the following documentation for more details:
Expand All @@ -145,20 +146,23 @@ def similarity_search(
query_vector (Optional[List[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None.
top_k (int, optional): The number of neighbors to return. Defaults to 5.
custom_retrieval_query (Optional[str], optional: Custom query to use as suffix for retrieval query. Defaults to None
Raises:
ValueError: If validation of the input arguments fail.
ValueError: If no embedder is provided.
Returns:
List[Neo4jRecord]: The `top_k` neighbors found in vector search with their nodes and scores.
Any: The `top_k` neighbors found in vector search with their nodes and scores.
If custom_retrieval_query is provided, this is changed.
"""
try:
validated_data = SimilaritySearchModel(
index_name=name,
top_k=top_k,
query_vector=query_vector,
query_text=query_text,
custom_retrieval_query=custom_retrieval_query,
)
except ValidationError as e:
error_details = e.errors()
Expand All @@ -173,19 +177,15 @@ def similarity_search(
parameters["query_vector"] = query_vector
del parameters["query_text"]

db_query_string = """
query_prefix = """
CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
YIELD node, score
"""
records = self._database_query(db_query_string, params=parameters)

try:
return [
Neo4jRecord(node=record["node"], score=record["score"])
for record in records
]
except ValidationError as e:
error_details = e.errors()
raise ValueError(
f"Validation failed while constructing output: {error_details}"
)
if parameters.get("custom_retrieval_query") is not None:
search_query = query_prefix + parameters["custom_retrieval_query"]
del parameters["custom_retrieval_query"]
else:
search_query = query_prefix

return self._database_query(search_query, params=parameters)
7 changes: 2 additions & 5 deletions src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from typing import List, Any, Literal, Optional
from typing import List, Literal, Optional
from pydantic import BaseModel, PositiveInt, model_validator


Neo4jRecord = dict[str, Any]
"""Type alias for data items returned from Neo4j queries"""


class EmbeddingVector(BaseModel):
vector: List[float]

Expand All @@ -23,6 +19,7 @@ class SimilaritySearchModel(BaseModel):
top_k: PositiveInt = 5
query_vector: Optional[List[float]] = None
query_text: Optional[str] = None
custom_retrieval_query: Optional[str] = None

@model_validator(mode="before")
def check_query(cls, values):
Expand Down
52 changes: 52 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,55 @@ def test_similarity_search_both_text_and_vector(client):
query_vector=query_vector,
top_k=top_k,
)


def test_custom_retrieval_query_happy_path(driver, client):
index_name = "my-index"
dimensions = 1536
query_vector = [1.0 for _ in range(dimensions)]
top_k = 5
driver.execute_query.return_value = [
[{"node": "dummy-node", "score": 1.0}],
None,
None,
]
search_query_prefix = """
CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
YIELD node, score
"""
custom_retrieval_query = "RETURN node.id as node_id, node.text as text, score"

client.similarity_search(
name=index_name,
query_vector=query_vector,
top_k=top_k,
custom_retrieval_query=custom_retrieval_query,
)

driver.execute_query.assert_called_once_with(
search_query_prefix + custom_retrieval_query,
{
"index_name": index_name,
"top_k": top_k,
"query_vector": query_vector,
},
)


def test_custom_retrieval_invalid_cypher(driver, client):
index_name = "my-index"
dimensions = 1536
query_vector = [1.0 for _ in range(dimensions)]
top_k = 5
driver.execute_query.side_effect = CypherSyntaxError
custom_retrieval_query = "not a cypher query"

with pytest.raises(ValueError) as excinfo:
client.similarity_search(
name=index_name,
query_vector=query_vector,
top_k=top_k,
custom_retrieval_query=custom_retrieval_query,
)

assert "Cypher Statement is not valid" in str(excinfo)

0 comments on commit f0dc1d6

Please sign in to comment.