diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index 40c094922..3efc6f19a 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -50,4 +50,4 @@ def embed_query(self, text: str) -> List[float]: # Perform the similarity search for a text query query_text = "hello world" -print(client.similarity_search(INDEX_NAME, query_text=query_text, top_k=5)) +print(client.search_similar_vectors(INDEX_NAME, query_text=query_text, top_k=5)) diff --git a/examples/similarity_search_for_vector.py b/examples/similarity_search_for_vector.py index 6c4c9c3aa..a25aa6d98 100644 --- a/examples/similarity_search_for_vector.py +++ b/examples/similarity_search_for_vector.py @@ -40,4 +40,4 @@ # Perform the similarity search for a vector query query_vector = [random() for _ in range(DIMENSION)] -print(client.similarity_search(INDEX_NAME, query_vector=query_vector, top_k=5)) +print(client.search_similar_vectors(INDEX_NAME, query_vector=query_vector, top_k=5)) diff --git a/src/neo4j_genai/client.py b/src/neo4j_genai/client.py index 77ff98603..73786e9bc 100644 --- a/src/neo4j_genai/client.py +++ b/src/neo4j_genai/client.py @@ -1,8 +1,13 @@ -from typing import List, Optional, Any +from typing import List, Optional, Any, Dict from pydantic import ValidationError from neo4j import Driver from .embedder import Embedder -from .types import CreateIndexModel, SimilaritySearchModel +from .types import ( + CreateIndexModel, + SimilaritySearchModel, + Neo4jRecord, + CustomSimilaritySearchModel, +) class GenAIClient: @@ -41,7 +46,6 @@ def _verify_version(self) -> None: version_tuple = tuple(map(int, version.split("."))) target_version = (5, 18, 1) - if version_tuple < target_version: raise ValueError( "This package only supports Neo4j version 5.18.1 or greater" @@ -73,13 +77,15 @@ def create_index( ValueError: If validation of the input arguments fail. """ try: - CreateIndexModel(**{ - "name": name, - "label": label, - "property": property, - "dimensions": dimensions, - "similarity_fn": similarity_fn, - }) + CreateIndexModel( + **{ + "name": name, + "label": label, + "property": property, + "dimensions": dimensions, + "similarity_fn": similarity_fn, + } + ) except ValidationError as e: raise ValueError(f"Error for inputs to create_index {str(e)}") @@ -87,7 +93,10 @@ def create_index( f"CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:{label}) ON n.{property} OPTIONS " "{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }" ) - self.driver.execute_query(query, {"name": name, "dimensions": dimensions, "similarity_fn": similarity_fn}) + self.driver.execute_query( + query, + {"name": name, "dimensions": dimensions, "similarity_fn": similarity_fn}, + ) def drop_index(self, name: str) -> None: """ @@ -104,13 +113,12 @@ def drop_index(self, name: str) -> None: } self.driver.execute_query(query, parameters) - def similarity_search( + def search_similar_vectors( self, name: str, query_vector: Optional[List[float]] = None, query_text: Optional[str] = None, top_k: int = 5, - custom_retrieval_query: Optional[str] = None, ) -> Any: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -123,7 +131,6 @@ def similarity_search( query_vector (Optional[List[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None. query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None. top_k (int, optional): The number of neighbors to return. Defaults to 5. - custom_retrieval_query (Optional[str], optional: Custom query to use as suffix for retrieval query. Defaults to None Raises: ValueError: If validation of the input arguments fail. @@ -139,31 +146,92 @@ def similarity_search( top_k=top_k, query_vector=query_vector, query_text=query_text, - custom_retrieval_query=custom_retrieval_query, ) + except ValidationError as e: + raise ValueError(f"Validation failed: {e.errors()}") + + parameters = validated_data.model_dump(exclude_none=True) + + if query_text: + if not self.embedder: + raise ValueError("Embedding method required for text query.") + parameters["query_vector"] = self.embedder.embed_query(query_text) + del parameters["query_text"] + + search_query = """ + CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) + YIELD node, score + """ + records, _, _ = self.driver.execute_query(search_query, parameters) + + try: + return [ + Neo4jRecord(node=record["node"], score=record["score"]) + for record in records + ] except ValidationError as e: error_details = e.errors() - raise ValueError(f"Validation failed: {error_details}") + raise ValueError( + f"Validation failed while constructing output: {error_details}" + ) + + def custom_search_similar_vectors( + self, + name: str, + custom_retrieval_query: str, + query_vector: Optional[List[float]] = None, + query_text: Optional[str] = None, + top_k: int = 5, + custom_params: Optional[Dict[str, Any]] = None, + ) -> Any: + """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. + See the following documentation for more details: + + - [Query a vector index](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-query) + - [db.index.vector.queryNodes()](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_queryNodes) + + Args: + name (str): Refers to the unique name of the vector index to query. + query_vector (Optional[List[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None. + query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None. + top_k (int, optional): The number of neighbors to return. Defaults to 5. + custom_retrieval_query (Optional[str], optional: Custom query to use as suffix for retrieval query. Defaults to None + custom_params (Optional[str], optional: Custom query to use as suffix for retrieval query. Defaults to None + + Raises: + ValueError: If validation of the input arguments fail. + ValueError: If no embedder is provided. + + Returns: + Any: The `top_k` neighbors found in vector search with their nodes and scores. + If custom_retrieval_query is provided, this is changed. + """ + try: + validated_data = CustomSimilaritySearchModel( + index_name=name, + top_k=top_k, + query_vector=query_vector, + query_text=query_text, + custom_retrieval_query=custom_retrieval_query, + custom_params=custom_params, + ) + except ValidationError as e: + raise ValueError(f"Validation failed: {e.errors()}") parameters = validated_data.model_dump(exclude_none=True) if query_text: if not self.embedder: raise ValueError("Embedding method required for text query.") - query_vector = self.embedder.embed_query(query_text) - parameters["query_vector"] = query_vector + parameters["query_vector"] = self.embedder.embed_query(query_text) del parameters["query_text"] query_prefix = """ CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) YIELD node, score """ - - if parameters.get("custom_retrieval_query") is not None: - search_query = query_prefix + parameters["custom_retrieval_query"] - del parameters["custom_retrieval_query"] - else: - search_query = query_prefix + search_query = query_prefix + parameters["custom_retrieval_query"] + del parameters["custom_retrieval_query"] records, _, _ = self.driver.execute_query(search_query, parameters) return records diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index e9fc80d4f..4ded24db8 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -1,7 +1,12 @@ -from typing import List, Literal, Optional +from typing import List, Literal, Optional, Any, Dict from pydantic import BaseModel, PositiveInt, model_validator +class Neo4jRecord(BaseModel): + node: Any + score: float + + class EmbeddingVector(BaseModel): vector: List[float] @@ -19,10 +24,9 @@ class SimilaritySearchModel(BaseModel): top_k: PositiveInt = 5 query_vector: Optional[List[float]] = None query_text: Optional[str] = None - custom_retrieval_query: Optional[str] = None @model_validator(mode="before") - def check_query(cls, values): + def check_only_either_vector_or_text(cls, values): """ Validates that one of either query_vector or query_text is provided exclusively. """ @@ -32,3 +36,19 @@ def check_query(cls, values): "You must provide exactly one of query_vector or query_text." ) return values + + +class CustomSimilaritySearchModel(SimilaritySearchModel): + custom_retrieval_query: str + custom_params: Optional[Dict[str, Any]] = None + + @model_validator(mode="before") + def combine_custom_params(cls, values): + """ + Combine custom_params dict into the main model's fields. + """ + custom_params = values.pop("custom_params", None) or {} + for key, value in custom_params.items(): + if key not in values: + values[key] = value + return values diff --git a/tests/test_client.py b/tests/test_client.py index faab9da18..57328c03c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -3,6 +3,8 @@ from unittest.mock import patch, MagicMock from neo4j.exceptions import CypherSyntaxError +from neo4j_genai.types import Neo4jRecord + def test_genai_client_supported_aura_version(driver): driver.execute_query.return_value = [[{"versions": ["5.18-aura"]}], None, None] @@ -16,9 +18,7 @@ def test_genai_client_no_supported_aura_version(driver): with pytest.raises(ValueError) as excinfo: GenAIClient(driver=driver) - assert "This package only supports Neo4j version 5.18.1 or greater" in str( - excinfo - ) + assert "This package only supports Neo4j version 5.18.1 or greater" in str(excinfo) def test_genai_client_supported_version(driver): @@ -33,28 +33,41 @@ def test_genai_client_no_supported_version(driver): with pytest.raises(ValueError) as excinfo: GenAIClient(driver=driver) - assert "This package only supports Neo4j version 5.18.1 or greater" in str( - excinfo - ) + assert "This package only supports Neo4j version 5.18.1 or greater" in str(excinfo) def test_create_index_happy_path(driver, client): driver.execute_query.return_value = [None, None, None] - create_query = ("CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:People) ON n.name OPTIONS " - "{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }") + create_query = ( + "CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:People) ON n.name OPTIONS " + "{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }" + ) client.create_index("my-index", "People", "name", 2048, "cosine") - driver.execute_query.assert_called_once_with(create_query, {"name": "my-index", "dimensions": 2048, "similarity_fn": "cosine"}) + driver.execute_query.assert_called_once_with( + create_query, + {"name": "my-index", "dimensions": 2048, "similarity_fn": "cosine"}, + ) + def test_create_index_ensure_escaping(driver, client): driver.execute_query.return_value = [None, None, None] - create_query = ("CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:People) ON n.name OPTIONS " - "{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }") + create_query = ( + "CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:People) ON n.name OPTIONS " + "{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }" + ) client.create_index("my-complicated-`-index", "People", "name", 2048, "cosine") - driver.execute_query.assert_called_once_with(create_query, {"name": "my-complicated-`-index", "dimensions": 2048, "similarity_fn": "cosine"}) + driver.execute_query.assert_called_once_with( + create_query, + { + "name": "my-complicated-`-index", + "dimensions": 2048, + "similarity_fn": "cosine", + }, + ) def test_create_index_validation_error_dimensions_negative_integer(client): @@ -108,7 +121,7 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver): YIELD node, score """ - records = client.similarity_search( + records = client.search_similar_vectors( name=index_name, query_vector=query_vector, top_k=top_k ) @@ -123,7 +136,7 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver): }, ) - assert records == [{"node": "dummy-node", "score": 1.0}] + assert records == [Neo4jRecord(node="dummy-node", score=1.0)] @patch("neo4j_genai.GenAIClient._verify_version") @@ -149,7 +162,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): YIELD node, score """ - records = client.similarity_search( + records = client.search_similar_vectors( name=index_name, query_text=query_text, top_k=top_k ) @@ -164,7 +177,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): }, ) - assert records == [{"node": "dummy-node", "score": 1.0}] + assert records == [Neo4jRecord(node="dummy-node", score=1.0)] def test_similarity_search_missing_embedder_for_text(client): @@ -173,7 +186,9 @@ def test_similarity_search_missing_embedder_for_text(client): top_k = 5 with pytest.raises(ValueError, match="Embedding method required for text query"): - client.similarity_search(name=index_name, query_text=query_text, top_k=top_k) + client.search_similar_vectors( + name=index_name, query_text=query_text, top_k=top_k + ) def test_similarity_search_both_text_and_vector(client): @@ -185,7 +200,7 @@ def test_similarity_search_both_text_and_vector(client): with pytest.raises( ValueError, match="You must provide exactly one of query_vector or query_text." ): - client.similarity_search( + client.search_similar_vectors( name=index_name, query_text=query_text, query_vector=query_vector, @@ -211,7 +226,7 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver): """ with pytest.raises(ValueError): - client.similarity_search( + client.search_similar_vectors( name=index_name, query_vector=query_vector, top_k=top_k ) @@ -253,7 +268,7 @@ def test_custom_retrieval_query_happy_path(_verify_version_mock, driver): RETURN node.id as node_id, node.text as text, score """ - records = client.similarity_search( + records = client.custom_search_similar_vectors( name=index_name, query_text=query_text, top_k=top_k, @@ -293,7 +308,7 @@ def test_custom_retrieval_query_cypher_error(_verify_version_mock, driver): """ with pytest.raises(CypherSyntaxError): - client.similarity_search( + client.custom_search_similar_vectors( name=index_name, query_text=query_text, top_k=top_k,