From 987afc456952bc6a603829a3bf317788c670fbc4 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Thu, 9 May 2024 17:54:50 +0100 Subject: [PATCH] Add new types for validating inputs to retrievers init --- src/neo4j_genai/retrievers/hybrid.py | 76 +++++++++++++++++++++++----- src/neo4j_genai/retrievers/vector.py | 34 ++----------- src/neo4j_genai/types.py | 64 ++++++++++++++++++++++- tests/unit/retrievers/test_hybrid.py | 15 ++++++ tests/unit/retrievers/test_vector.py | 17 ++++++- 5 files changed, 161 insertions(+), 45 deletions(-) diff --git a/src/neo4j_genai/retrievers/hybrid.py b/src/neo4j_genai/retrievers/hybrid.py index fea96a2d4..c9230f463 100644 --- a/src/neo4j_genai/retrievers/hybrid.py +++ b/src/neo4j_genai/retrievers/hybrid.py @@ -19,7 +19,15 @@ from neo4j_genai.embedder import Embedder from neo4j_genai.retrievers.base import Retriever -from neo4j_genai.types import HybridSearchModel, SearchType, HybridCypherSearchModel +from neo4j_genai.types import ( + HybridSearchModel, + SearchType, + HybridCypherSearchModel, + Neo4jDriverModel, + EmbedderModel, + HybridRetrieverModel, + HybridCypherRetrieverModel, +) from neo4j_genai.neo4j_queries import get_search_query import logging @@ -35,11 +43,30 @@ def __init__( embedder: Optional[Embedder] = None, return_properties: Optional[list[str]] = None, ) -> None: - super().__init__(driver) - self.vector_index_name = vector_index_name - self.fulltext_index_name = fulltext_index_name - self.embedder = embedder - self.return_properties = return_properties + try: + driver_model = Neo4jDriverModel(driver=driver) + embedder_model = EmbedderModel(embedder=embedder) if embedder else None + validated_data = HybridRetrieverModel( + driver_model=driver_model, + vector_index_name=vector_index_name, + fulltext_index_name=fulltext_index_name, + embedder_model=embedder_model, + return_properties=return_properties, + ) + except ValidationError as e: + msg = f"Validation failed: {e.errors()}" + logger.error(msg) + raise ValueError(msg) + + super().__init__(validated_data.driver_model.driver) + self.vector_index_name = validated_data.vector_index_name + self.fulltext_index_name = validated_data.fulltext_index_name + self.return_properties = validated_data.return_properties + self.embedder = ( + validated_data.embedder_model.embedder + if validated_data.embedder_model + else None + ) def search( self, @@ -74,7 +101,9 @@ def search( query_text=query_text, ) except ValidationError as e: - raise ValueError(f"Validation failed: {e.errors()}") + msg = f"Validation failed: {e.errors()}" + logger.error(msg) + raise ValueError(msg) parameters = validated_data.model_dump(exclude_none=True) @@ -102,11 +131,30 @@ def __init__( retrieval_query: str, embedder: Optional[Embedder] = None, ) -> None: - super().__init__(driver) - self.vector_index_name = vector_index_name - self.fulltext_index_name = fulltext_index_name - self.retrieval_query = retrieval_query - self.embedder = embedder + try: + driver_model = Neo4jDriverModel(driver=driver) + embedder_model = EmbedderModel(embedder=embedder) if embedder else None + validated_data = HybridCypherRetrieverModel( + driver_model=driver_model, + vector_index_name=vector_index_name, + fulltext_index_name=fulltext_index_name, + retrieval_query=retrieval_query, + embedder_model=embedder_model, + ) + except ValidationError as e: + msg = f"Validation failed: {e.errors()}" + logger.error(msg) + raise ValueError(msg) + + super().__init__(validated_data.driver_model.driver) + self.vector_index_name = validated_data.vector_index_name + self.fulltext_index_name = validated_data.fulltext_index_name + self.retrieval_query = validated_data.retrieval_query + self.embedder = ( + validated_data.embedder_model.embedder + if validated_data.embedder_model + else None + ) def search( self, @@ -144,7 +192,9 @@ def search( query_params=query_params, ) except ValidationError as e: - raise ValueError(f"Validation failed: {e.errors()}") + msg = f"Validation failed: {e.errors()}" + logger.error(msg) + raise ValueError(msg) parameters = validated_data.model_dump(exclude_none=True) diff --git a/src/neo4j_genai/retrievers/vector.py b/src/neo4j_genai/retrievers/vector.py index 323148012..af3a60685 100644 --- a/src/neo4j_genai/retrievers/vector.py +++ b/src/neo4j_genai/retrievers/vector.py @@ -48,17 +48,12 @@ def __init__( self.index_name = index_name self.return_properties = return_properties self.embedder = embedder - self._node_label = None - self._embedding_node_property = None - self._embedding_dimension = None - self._fetch_index_infos() def search( self, query_vector: Optional[list[float]] = None, query_text: Optional[str] = None, top_k: int = 5, - filters: Optional[dict[str, Any]] = None, ) -> list[VectorSearchRecord]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -80,7 +75,7 @@ def search( """ try: validated_data = VectorSearchModel( - vector_index_name=self.index_name, + index_name=self.index_name, top_k=top_k, query_vector=query_vector, query_text=query_text, @@ -98,15 +93,7 @@ def search( parameters["query_vector"] = query_vector del parameters["query_text"] - search_query, search_params = get_search_query( - SearchType.VECTOR, - self.return_properties, - node_label=self._node_label, - embedding_node_property=self._embedding_node_property, - embedding_dimension=self._embedding_dimension, - filters=filters, - ) - parameters.update(search_params) + search_query = get_search_query(SearchType.VECTOR, self.return_properties) logger.debug("VectorRetriever Cypher parameters: %s", parameters) logger.debug("VectorRetriever Cypher query: %s", search_query) @@ -142,10 +129,6 @@ def __init__( self.index_name = index_name self.retrieval_query = retrieval_query self.embedder = embedder - self._node_label = None - self._node_embedding_property = None - self._embedding_dimension = None - self._fetch_index_infos() def search( self, @@ -153,7 +136,6 @@ def search( query_text: Optional[str] = None, top_k: int = 5, query_params: Optional[dict[str, Any]] = None, - filters: Optional[dict[str, Any]] = None, ) -> list[neo4j.Record]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -176,7 +158,7 @@ def search( """ try: validated_data = VectorCypherSearchModel( - vector_index_name=self.index_name, + index_name=self.index_name, top_k=top_k, query_vector=query_vector, query_text=query_text, @@ -199,15 +181,9 @@ def search( parameters[key] = value del parameters["query_params"] - search_query, search_params = get_search_query( - SearchType.VECTOR, - retrieval_query=self.retrieval_query, - node_label=self._node_label, - embedding_node_property=self._node_embedding_property, - embedding_dimension=self._embedding_dimension, - filters=filters, + search_query = get_search_query( + SearchType.VECTOR, retrieval_query=self.retrieval_query ) - parameters.update(search_params) logger.debug("VectorCypherRetriever Cypher parameters: %s", parameters) logger.debug("VectorCypherRetriever Cypher query: %s", search_query) diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index 357bb44e6..f13e67483 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -14,7 +14,13 @@ # limitations under the License. from enum import Enum from typing import Any, Literal, Optional -from pydantic import BaseModel, PositiveInt, model_validator, field_validator +from pydantic import ( + BaseModel, + PositiveInt, + model_validator, + field_validator, + ConfigDict, +) import neo4j @@ -93,3 +99,59 @@ class SearchType(str, Enum): VECTOR = "vector" HYBRID = "hybrid" + + +class EmbedderModel(BaseModel): + embedder: Optional[Any] + model_config = ConfigDict(arbitrary_types_allowed=True) + + @field_validator("embedder") + def check_embedder(cls, value): + if not hasattr(value, "embed_query") or not callable( + getattr(value, "embed_query", None) + ): + raise ValueError( + "Provided embedder object must have an 'embed_query' callable method." + ) + return value + + +class Neo4jDriverModel(BaseModel): + driver: neo4j.Driver + model_config = ConfigDict(arbitrary_types_allowed=True) + + @field_validator("driver") + def check_driver(cls, value): + if not isinstance(value, neo4j.Driver): + raise ValueError("Provided driver needs to be of type neo4j.Driver") + return value + + +class VectorRetrieverModel(BaseModel): + driver_model: Neo4jDriverModel + index_name: str + embedder_model: Optional[EmbedderModel] = None + return_properties: Optional[list[str]] = None + + +class VectorCypherRetrieverModel(BaseModel): + driver_model: Neo4jDriverModel + index_name: str + retrieval_query: str + embedder_model: Optional[EmbedderModel] = None + + +class HybridRetrieverModel(BaseModel): + driver_model: Neo4jDriverModel + vector_index_name: str + fulltext_index_name: str + embedder_model: Optional[EmbedderModel] = None + return_properties: Optional[list[str]] = None + + +class HybridCypherRetrieverModel(BaseModel): + driver_model: Neo4jDriverModel + vector_index_name: str + fulltext_index_name: str + retrieval_query: str + embedder_model: Optional[EmbedderModel] = None diff --git a/tests/unit/retrievers/test_hybrid.py b/tests/unit/retrievers/test_hybrid.py index 794868355..dff2ebcff 100644 --- a/tests/unit/retrievers/test_hybrid.py +++ b/tests/unit/retrievers/test_hybrid.py @@ -43,6 +43,21 @@ def test_vector_cypher_retriever_initialization(driver): mock_verify.assert_called_once() +def test_hybrid_retriever_bad_data_validation(driver): + with pytest.raises(ValueError): + HybridRetriever(driver=driver, vector_index_name=42, fulltext_index_name=42) + + +def test_hybrid_cypher_retriever_bad_data_validation(driver): + with pytest.raises(ValueError): + HybridCypherRetriever( + driver=driver, + vector_index_name="my-index", + fulltext_index_name="fulltext-index", + retrieval_query=42, + ) + + @patch("neo4j_genai.HybridRetriever._verify_version") def test_hybrid_search_text_happy_path(_verify_version_mock, driver): embed_query_vector = [1.0 for _ in range(1536)] diff --git a/tests/unit/retrievers/test_vector.py b/tests/unit/retrievers/test_vector.py index c3fd1ade5..98a892398 100644 --- a/tests/unit/retrievers/test_vector.py +++ b/tests/unit/retrievers/test_vector.py @@ -18,6 +18,7 @@ from neo4j.exceptions import CypherSyntaxError from neo4j_genai import VectorRetriever, VectorCypherRetriever +from neo4j_genai.embedder import Embedder from neo4j_genai.neo4j_queries import get_search_query from neo4j_genai.types import SearchType, VectorSearchRecord @@ -28,6 +29,18 @@ def test_vector_retriever_initialization(driver): mock_verify.assert_called_once() +def test_vector_retriever_bad_data_validation(driver): + with pytest.raises(ValueError): + VectorRetriever(driver=driver, index_name=42) + + +def test_vector_cypher_retriever_bad_data_validation(driver): + with pytest.raises(ValueError): + VectorCypherRetriever( + driver=driver, index_name="my-index", retrieval_query=42 + ) + + def test_vector_cypher_retriever_initialization(driver): with patch("neo4j_genai.retrievers.base.Retriever._verify_version") as mock_verify: VectorCypherRetriever(driver=driver, index_name="my-index", retrieval_query="") @@ -70,7 +83,7 @@ def test_similarity_search_text_happy_path( _verify_version_mock, _fetch_index_infos, driver ): embed_query_vector = [1.0 for _ in range(1536)] - custom_embeddings = MagicMock() + custom_embeddings = MagicMock(spec=Embedder) custom_embeddings.embed_query.return_value = embed_query_vector index_name = "my-index" query_text = "may thy knife chip and shatter" @@ -104,7 +117,7 @@ def test_similarity_search_text_return_properties( _verify_version_mock, _fetch_index_infos, driver ): embed_query_vector = [1.0 for _ in range(3)] - custom_embeddings = MagicMock() + custom_embeddings = MagicMock(spec=Embedder) custom_embeddings.embed_query.return_value = embed_query_vector index_name = "my-index" query_text = "may thy knife chip and shatter"