Skip to content

Commit

Permalink
Separated similarity search and custom similarity search
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai committed Mar 27, 2024
1 parent 537d08f commit 0d5e488
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 50 deletions.
2 changes: 1 addition & 1 deletion examples/similarity_search_for_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,4 @@ def embed_query(self, text: str) -> List[float]:

# Perform the similarity search for a text query
query_text = "hello world"
print(client.similarity_search(INDEX_NAME, query_text=query_text, top_k=5))
print(client.search_similar_vectors(INDEX_NAME, query_text=query_text, top_k=5))
2 changes: 1 addition & 1 deletion examples/similarity_search_for_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@

# Perform the similarity search for a vector query
query_vector = [random() for _ in range(DIMENSION)]
print(client.similarity_search(INDEX_NAME, query_vector=query_vector, top_k=5))
print(client.search_similar_vectors(INDEX_NAME, query_vector=query_vector, top_k=5))
116 changes: 92 additions & 24 deletions src/neo4j_genai/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from typing import List, Optional, Any
from typing import List, Optional, Any, Dict
from pydantic import ValidationError
from neo4j import Driver
from .embedder import Embedder
from .types import CreateIndexModel, SimilaritySearchModel
from .types import (
CreateIndexModel,
SimilaritySearchModel,
Neo4jRecord,
CustomSimilaritySearchModel,
)


class GenAIClient:
Expand Down Expand Up @@ -41,7 +46,6 @@ def _verify_version(self) -> None:
version_tuple = tuple(map(int, version.split(".")))
target_version = (5, 18, 1)


if version_tuple < target_version:
raise ValueError(
"This package only supports Neo4j version 5.18.1 or greater"
Expand Down Expand Up @@ -73,21 +77,26 @@ def create_index(
ValueError: If validation of the input arguments fail.
"""
try:
CreateIndexModel(**{
"name": name,
"label": label,
"property": property,
"dimensions": dimensions,
"similarity_fn": similarity_fn,
})
CreateIndexModel(
**{
"name": name,
"label": label,
"property": property,
"dimensions": dimensions,
"similarity_fn": similarity_fn,
}
)
except ValidationError as e:
raise ValueError(f"Error for inputs to create_index {str(e)}")

query = (
f"CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:{label}) ON n.{property} OPTIONS "
"{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }"
)
self.driver.execute_query(query, {"name": name, "dimensions": dimensions, "similarity_fn": similarity_fn})
self.driver.execute_query(
query,
{"name": name, "dimensions": dimensions, "similarity_fn": similarity_fn},
)

def drop_index(self, name: str) -> None:
"""
Expand All @@ -104,13 +113,12 @@ def drop_index(self, name: str) -> None:
}
self.driver.execute_query(query, parameters)

def similarity_search(
def search_similar_vectors(
self,
name: str,
query_vector: Optional[List[float]] = None,
query_text: Optional[str] = None,
top_k: int = 5,
custom_retrieval_query: Optional[str] = None,
) -> Any:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
See the following documentation for more details:
Expand All @@ -123,7 +131,6 @@ def similarity_search(
query_vector (Optional[List[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None.
top_k (int, optional): The number of neighbors to return. Defaults to 5.
custom_retrieval_query (Optional[str], optional: Custom query to use as suffix for retrieval query. Defaults to None
Raises:
ValueError: If validation of the input arguments fail.
Expand All @@ -139,31 +146,92 @@ def similarity_search(
top_k=top_k,
query_vector=query_vector,
query_text=query_text,
custom_retrieval_query=custom_retrieval_query,
)
except ValidationError as e:
raise ValueError(f"Validation failed: {e.errors()}")

parameters = validated_data.model_dump(exclude_none=True)

if query_text:
if not self.embedder:
raise ValueError("Embedding method required for text query.")
parameters["query_vector"] = self.embedder.embed_query(query_text)
del parameters["query_text"]

search_query = """
CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
YIELD node, score
"""
records, _, _ = self.driver.execute_query(search_query, parameters)

try:
return [
Neo4jRecord(node=record["node"], score=record["score"])
for record in records
]
except ValidationError as e:
error_details = e.errors()
raise ValueError(f"Validation failed: {error_details}")
raise ValueError(
f"Validation failed while constructing output: {error_details}"
)

def custom_search_similar_vectors(
    self,
    name: str,
    custom_retrieval_query: str,
    query_vector: Optional[List[float]] = None,
    query_text: Optional[str] = None,
    top_k: int = 5,
    custom_params: Optional[Dict[str, Any]] = None,
) -> Any:
    """Query a vector index and post-process the hits with a custom Cypher suffix.

    The query that is executed is the fixed ``db.index.vector.queryNodes`` call
    (yielding ``node`` and ``score``) with ``custom_retrieval_query`` appended
    verbatim. Because the suffix is concatenated into the query text, it must
    come from a trusted source; pass any values through ``custom_params``.

    See the following documentation for more details:
    - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query)
    - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes)

    Args:
        name (str): Refers to the unique name of the vector index to query.
        custom_retrieval_query (str): Cypher appended after the vector index call.
        query_vector (Optional[List[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
        query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None.
        top_k (int, optional): The number of neighbors to return. Defaults to 5.
        custom_params (Optional[Dict[str, Any]], optional): Extra Cypher parameters made
            available to ``custom_retrieval_query``. Keys that collide with the
            built-in parameters are ignored. Defaults to None.

    Raises:
        ValueError: If validation of the input arguments fail.
        ValueError: If ``query_text`` is given but no embedder is configured.

    Returns:
        Any: The records returned by the driver for the combined query; their
        shape is determined by ``custom_retrieval_query``.
    """
    try:
        validated_data = CustomSimilaritySearchModel(
            index_name=name,
            top_k=top_k,
            query_vector=query_vector,
            query_text=query_text,
            custom_retrieval_query=custom_retrieval_query,
            custom_params=custom_params,
        )
    except ValidationError as e:
        # Keep the pydantic error as the cause so callers can inspect it.
        raise ValueError(f"Validation failed: {e.errors()}") from e

    parameters = validated_data.model_dump(exclude_none=True)

    # The suffix becomes part of the query text; it is not a Cypher parameter.
    retrieval_suffix = parameters.pop("custom_retrieval_query")
    # Never forward the raw mapping itself; its entries are merged in below.
    parameters.pop("custom_params", None)

    if query_text:
        if not self.embedder:
            raise ValueError("Embedding method required for text query.")
        parameters["query_vector"] = self.embedder.embed_query(query_text)
        del parameters["query_text"]

    # Expose the caller's extra parameters to the custom query. The built-in
    # parameters (index_name, top_k, query_vector) always take precedence.
    for key, value in (custom_params or {}).items():
        parameters.setdefault(key, value)

    query_prefix = """
        CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
        YIELD node, score
        """
    search_query = query_prefix + retrieval_suffix

    records, _, _ = self.driver.execute_query(search_query, parameters)
    return records
26 changes: 23 additions & 3 deletions src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from typing import List, Literal, Optional
from typing import List, Literal, Optional, Any, Dict
from pydantic import BaseModel, PositiveInt, model_validator


class Neo4jRecord(BaseModel):
    """One result row of a vector index query: a node and its score."""

    # Whatever the query yielded as `node`; deliberately left untyped.
    node: Any
    # The value yielded as `score`, validated as a float.
    score: float


class EmbeddingVector(BaseModel):
vector: List[float]

Expand All @@ -19,10 +24,9 @@ class SimilaritySearchModel(BaseModel):
top_k: PositiveInt = 5
query_vector: Optional[List[float]] = None
query_text: Optional[str] = None
custom_retrieval_query: Optional[str] = None

@model_validator(mode="before")
def check_query(cls, values):
def check_only_either_vector_or_text(cls, values):
"""
Validates that one of either query_vector or query_text is provided exclusively.
"""
Expand All @@ -32,3 +36,19 @@ def check_query(cls, values):
"You must provide exactly one of query_vector or query_text."
)
return values


class CustomSimilaritySearchModel(SimilaritySearchModel):
    """Validated inputs for a similarity search with a custom retrieval query.

    Extends ``SimilaritySearchModel`` with the Cypher suffix to append and an
    optional mapping of extra query parameters. The entries of
    ``custom_params`` are flattened into the model's top-level data so that
    ``model_dump()`` yields a single flat parameter mapping.
    """

    # The flattened custom parameters are not declared fields. pydantic's
    # default (extra="ignore") would silently discard them, so keep extras.
    model_config = {"extra": "allow"}

    custom_retrieval_query: str
    custom_params: Optional[Dict[str, Any]] = None

    @model_validator(mode="before")
    @classmethod
    def combine_custom_params(cls, values):
        """
        Combine custom_params dict into the main model's fields.

        Keys already present in the input are never overwritten. The input
        mapping itself is left untouched; a merged copy is returned.
        """
        if not isinstance(values, dict):
            # Not keyword-style input; let pydantic's own validation handle it.
            return values

        merged = dict(values)
        custom_params = merged.get("custom_params")
        # Only flatten a real mapping. Anything else (including None) stays in
        # place so regular field validation reports a proper error for it.
        if isinstance(custom_params, dict):
            del merged["custom_params"]
            for key, value in custom_params.items():
                merged.setdefault(key, value)
        return merged
57 changes: 36 additions & 21 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from unittest.mock import patch, MagicMock
from neo4j.exceptions import CypherSyntaxError

from neo4j_genai.types import Neo4jRecord


def test_genai_client_supported_aura_version(driver):
driver.execute_query.return_value = [[{"versions": ["5.18-aura"]}], None, None]
Expand All @@ -16,9 +18,7 @@ def test_genai_client_no_supported_aura_version(driver):
with pytest.raises(ValueError) as excinfo:
GenAIClient(driver=driver)

assert "This package only supports Neo4j version 5.18.1 or greater" in str(
excinfo
)
assert "This package only supports Neo4j version 5.18.1 or greater" in str(excinfo)


def test_genai_client_supported_version(driver):
Expand All @@ -33,28 +33,41 @@ def test_genai_client_no_supported_version(driver):
with pytest.raises(ValueError) as excinfo:
GenAIClient(driver=driver)

assert "This package only supports Neo4j version 5.18.1 or greater" in str(
excinfo
)
assert "This package only supports Neo4j version 5.18.1 or greater" in str(excinfo)


def test_create_index_happy_path(driver, client):
    """create_index sends one CREATE VECTOR INDEX query with bound parameters."""
    driver.execute_query.return_value = [None, None, None]

    expected_params = {
        "name": "my-index",
        "dimensions": 2048,
        "similarity_fn": "cosine",
    }
    # Label and property are interpolated into the query text; the rest is bound.
    statement_head = (
        "CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:People) ON n.name OPTIONS "
    )
    statement_options = "{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }"
    expected_query = statement_head + statement_options

    client.create_index("my-index", "People", "name", 2048, "cosine")

    driver.execute_query.assert_called_once_with(expected_query, expected_params)


def test_create_index_ensure_escaping(driver, client):
    """An index name containing a backtick is passed as a parameter, not inlined."""
    driver.execute_query.return_value = [None, None, None]

    tricky_name = "my-complicated-`-index"
    expected_params = {
        "name": tricky_name,
        "dimensions": 2048,
        "similarity_fn": "cosine",
    }
    # The query text must be unaffected by the index name's content.
    statement_head = (
        "CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:People) ON n.name OPTIONS "
    )
    statement_options = "{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }"
    expected_query = statement_head + statement_options

    client.create_index(tricky_name, "People", "name", 2048, "cosine")

    driver.execute_query.assert_called_once_with(expected_query, expected_params)


def test_create_index_validation_error_dimensions_negative_integer(client):
Expand Down Expand Up @@ -108,7 +121,7 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver):
YIELD node, score
"""

records = client.similarity_search(
records = client.search_similar_vectors(
name=index_name, query_vector=query_vector, top_k=top_k
)

Expand All @@ -123,7 +136,7 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver):
},
)

assert records == [{"node": "dummy-node", "score": 1.0}]
assert records == [Neo4jRecord(node="dummy-node", score=1.0)]


@patch("neo4j_genai.GenAIClient._verify_version")
Expand All @@ -149,7 +162,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver):
YIELD node, score
"""

records = client.similarity_search(
records = client.search_similar_vectors(
name=index_name, query_text=query_text, top_k=top_k
)

Expand All @@ -164,7 +177,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver):
},
)

assert records == [{"node": "dummy-node", "score": 1.0}]
assert records == [Neo4jRecord(node="dummy-node", score=1.0)]


def test_similarity_search_missing_embedder_for_text(client):
Expand All @@ -173,7 +186,9 @@ def test_similarity_search_missing_embedder_for_text(client):
top_k = 5

with pytest.raises(ValueError, match="Embedding method required for text query"):
client.similarity_search(name=index_name, query_text=query_text, top_k=top_k)
client.search_similar_vectors(
name=index_name, query_text=query_text, top_k=top_k
)


def test_similarity_search_both_text_and_vector(client):
Expand All @@ -185,7 +200,7 @@ def test_similarity_search_both_text_and_vector(client):
with pytest.raises(
ValueError, match="You must provide exactly one of query_vector or query_text."
):
client.similarity_search(
client.search_similar_vectors(
name=index_name,
query_text=query_text,
query_vector=query_vector,
Expand All @@ -211,7 +226,7 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver):
"""

with pytest.raises(ValueError):
client.similarity_search(
client.search_similar_vectors(
name=index_name, query_vector=query_vector, top_k=top_k
)

Expand Down Expand Up @@ -253,7 +268,7 @@ def test_custom_retrieval_query_happy_path(_verify_version_mock, driver):
RETURN node.id as node_id, node.text as text, score
"""

records = client.similarity_search(
records = client.custom_search_similar_vectors(
name=index_name,
query_text=query_text,
top_k=top_k,
Expand Down Expand Up @@ -293,7 +308,7 @@ def test_custom_retrieval_query_cypher_error(_verify_version_mock, driver):
"""

with pytest.raises(CypherSyntaxError):
client.similarity_search(
client.custom_search_similar_vectors(
name=index_name,
query_text=query_text,
top_k=top_k,
Expand Down

0 comments on commit 0d5e488

Please sign in to comment.