Skip to content

Commit

Permalink
Add data validation for vector retrievers
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai committed May 14, 2024
1 parent 6715d6a commit 7b7aef5
Showing 1 changed file with 49 additions and 7 deletions.
56 changes: 49 additions & 7 deletions src/neo4j_genai/retrievers/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
VectorSearchRecord,
VectorSearchModel,
VectorCypherSearchModel,
SearchType,
SearchType, Neo4jDriverModel, EmbedderModel, VectorRetrieverModel,
VectorCypherRetrieverModel,
)
from neo4j_genai.neo4j_queries import get_search_query
import logging
Expand All @@ -44,10 +45,26 @@ def __init__(
embedder: Optional[Embedder] = None,
return_properties: Optional[list[str]] = None,
) -> None:
try:
driver_model = Neo4jDriverModel(driver=driver)
embedder_model = EmbedderModel(embedder=embedder) if embedder else None
validated_data = VectorRetrieverModel(
driver_model=driver_model,
index_name=index_name,
embedder_model=embedder_model,
return_properties=return_properties,
)
except ValidationError as e:
raise ValueError(f"Validation failed: {e.errors()}")

super().__init__(driver)
self.index_name = index_name
self.return_properties = return_properties
self.embedder = embedder
self.index_name = validated_data.index_name
self.return_properties = validated_data.return_properties
self.embedder = (
validated_data.embedder_model.embedder
if validated_data.embedder_model
else None
)
self._node_label = None
self._embedding_node_property = None
self._embedding_dimension = None
Expand Down Expand Up @@ -138,10 +155,34 @@ def __init__(
retrieval_query: str,
embedder: Optional[Embedder] = None,
) -> None:
try:
driver_model = Neo4jDriverModel(driver=driver)
embedder_model = EmbedderModel(embedder=embedder) if embedder else None
validated_data = VectorCypherRetrieverModel(
driver_model=driver_model,
index_name=index_name,
retrieval_query=retrieval_query,
embedder_model=embedder_model,
)
except ValidationError as e:
raise ValueError(f"Validation failed: {e.errors()}")

super().__init__(driver)
self.index_name = index_name
self.retrieval_query = retrieval_query
self.embedder = embedder
self.index_name = validated_data.index_name
self.return_properties = validated_data.return_properties
self.embedder = (
validated_data.embedder_model.embedder
if validated_data.embedder_model
else None
)
super().__init__(driver)
self.index_name = validated_data.index_name
self.retrieval_query = validated_data.retrieval_query
self.embedder = (
validated_data.embedder_model.embedder
if validated_data.embedder_model
else None
)
self._node_label = None
self._node_embedding_property = None
self._embedding_dimension = None
Expand All @@ -166,6 +207,7 @@ def search(
query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None.
top_k (int, optional): The number of neighbors to return. Defaults to 5.
query_params (Optional[dict[str, Any]], optional): Parameters for the Cypher query. Defaults to None.
filters (Optional[dict[str, Any]], optional): Filters for metadata pre-filtering. Defaults to None.
Raises:
ValueError: If validation of the input arguments fails.
Expand Down

0 comments on commit 7b7aef5

Please sign in to comment.