Skip to content

Commit

Permalink
Add new types for validating inputs to retrievers init
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai committed May 14, 2024
1 parent e30fa26 commit 987afc4
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 45 deletions.
76 changes: 63 additions & 13 deletions src/neo4j_genai/retrievers/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,15 @@

from neo4j_genai.embedder import Embedder
from neo4j_genai.retrievers.base import Retriever
from neo4j_genai.types import HybridSearchModel, SearchType, HybridCypherSearchModel
from neo4j_genai.types import (
HybridSearchModel,
SearchType,
HybridCypherSearchModel,
Neo4jDriverModel,
EmbedderModel,
HybridRetrieverModel,
HybridCypherRetrieverModel,
)
from neo4j_genai.neo4j_queries import get_search_query
import logging

Expand All @@ -35,11 +43,30 @@ def __init__(
embedder: Optional[Embedder] = None,
return_properties: Optional[list[str]] = None,
) -> None:
super().__init__(driver)
self.vector_index_name = vector_index_name
self.fulltext_index_name = fulltext_index_name
self.embedder = embedder
self.return_properties = return_properties
try:
driver_model = Neo4jDriverModel(driver=driver)
embedder_model = EmbedderModel(embedder=embedder) if embedder else None
validated_data = HybridRetrieverModel(
driver_model=driver_model,
vector_index_name=vector_index_name,
fulltext_index_name=fulltext_index_name,
embedder_model=embedder_model,
return_properties=return_properties,
)
except ValidationError as e:
msg = f"Validation failed: {e.errors()}"
logger.error(msg)
raise ValueError(msg)

super().__init__(validated_data.driver_model.driver)
self.vector_index_name = validated_data.vector_index_name
self.fulltext_index_name = validated_data.fulltext_index_name
self.return_properties = validated_data.return_properties
self.embedder = (
validated_data.embedder_model.embedder
if validated_data.embedder_model
else None
)

def search(
self,
Expand Down Expand Up @@ -74,7 +101,9 @@ def search(
query_text=query_text,
)
except ValidationError as e:
raise ValueError(f"Validation failed: {e.errors()}")
msg = f"Validation failed: {e.errors()}"
logger.error(msg)
raise ValueError(msg)

parameters = validated_data.model_dump(exclude_none=True)

Expand Down Expand Up @@ -102,11 +131,30 @@ def __init__(
retrieval_query: str,
embedder: Optional[Embedder] = None,
) -> None:
super().__init__(driver)
self.vector_index_name = vector_index_name
self.fulltext_index_name = fulltext_index_name
self.retrieval_query = retrieval_query
self.embedder = embedder
try:
driver_model = Neo4jDriverModel(driver=driver)
embedder_model = EmbedderModel(embedder=embedder) if embedder else None
validated_data = HybridCypherRetrieverModel(
driver_model=driver_model,
vector_index_name=vector_index_name,
fulltext_index_name=fulltext_index_name,
retrieval_query=retrieval_query,
embedder_model=embedder_model,
)
except ValidationError as e:
msg = f"Validation failed: {e.errors()}"
logger.error(msg)
raise ValueError(msg)

super().__init__(validated_data.driver_model.driver)
self.vector_index_name = validated_data.vector_index_name
self.fulltext_index_name = validated_data.fulltext_index_name
self.retrieval_query = validated_data.retrieval_query
self.embedder = (
validated_data.embedder_model.embedder
if validated_data.embedder_model
else None
)

def search(
self,
Expand Down Expand Up @@ -144,7 +192,9 @@ def search(
query_params=query_params,
)
except ValidationError as e:
raise ValueError(f"Validation failed: {e.errors()}")
msg = f"Validation failed: {e.errors()}"
logger.error(msg)
raise ValueError(msg)

parameters = validated_data.model_dump(exclude_none=True)

Expand Down
34 changes: 5 additions & 29 deletions src/neo4j_genai/retrievers/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,12 @@ def __init__(
self.index_name = index_name
self.return_properties = return_properties
self.embedder = embedder
self._node_label = None
self._embedding_node_property = None
self._embedding_dimension = None
self._fetch_index_infos()

def search(
self,
query_vector: Optional[list[float]] = None,
query_text: Optional[str] = None,
top_k: int = 5,
filters: Optional[dict[str, Any]] = None,
) -> list[VectorSearchRecord]:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
See the following documentation for more details:
Expand All @@ -80,7 +75,7 @@ def search(
"""
try:
validated_data = VectorSearchModel(
vector_index_name=self.index_name,
index_name=self.index_name,
top_k=top_k,
query_vector=query_vector,
query_text=query_text,
Expand All @@ -98,15 +93,7 @@ def search(
parameters["query_vector"] = query_vector
del parameters["query_text"]

search_query, search_params = get_search_query(
SearchType.VECTOR,
self.return_properties,
node_label=self._node_label,
embedding_node_property=self._embedding_node_property,
embedding_dimension=self._embedding_dimension,
filters=filters,
)
parameters.update(search_params)
search_query = get_search_query(SearchType.VECTOR, self.return_properties)

logger.debug("VectorRetriever Cypher parameters: %s", parameters)
logger.debug("VectorRetriever Cypher query: %s", search_query)
Expand Down Expand Up @@ -142,18 +129,13 @@ def __init__(
self.index_name = index_name
self.retrieval_query = retrieval_query
self.embedder = embedder
self._node_label = None
self._node_embedding_property = None
self._embedding_dimension = None
self._fetch_index_infos()

def search(
self,
query_vector: Optional[list[float]] = None,
query_text: Optional[str] = None,
top_k: int = 5,
query_params: Optional[dict[str, Any]] = None,
filters: Optional[dict[str, Any]] = None,
) -> list[neo4j.Record]:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
See the following documentation for more details:
Expand All @@ -176,7 +158,7 @@ def search(
"""
try:
validated_data = VectorCypherSearchModel(
vector_index_name=self.index_name,
index_name=self.index_name,
top_k=top_k,
query_vector=query_vector,
query_text=query_text,
Expand All @@ -199,15 +181,9 @@ def search(
parameters[key] = value
del parameters["query_params"]

search_query, search_params = get_search_query(
SearchType.VECTOR,
retrieval_query=self.retrieval_query,
node_label=self._node_label,
embedding_node_property=self._node_embedding_property,
embedding_dimension=self._embedding_dimension,
filters=filters,
search_query = get_search_query(
SearchType.VECTOR, retrieval_query=self.retrieval_query
)
parameters.update(search_params)

logger.debug("VectorCypherRetriever Cypher parameters: %s", parameters)
logger.debug("VectorCypherRetriever Cypher query: %s", search_query)
Expand Down
64 changes: 63 additions & 1 deletion src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
# limitations under the License.
from enum import Enum
from typing import Any, Literal, Optional
from pydantic import BaseModel, PositiveInt, model_validator, field_validator
from pydantic import (
BaseModel,
PositiveInt,
model_validator,
field_validator,
ConfigDict,
)
import neo4j


Expand Down Expand Up @@ -93,3 +99,59 @@ class SearchType(str, Enum):

VECTOR = "vector"
HYBRID = "hybrid"


class EmbedderModel(BaseModel):
embedder: Optional[Any]
model_config = ConfigDict(arbitrary_types_allowed=True)

@field_validator("embedder")
def check_embedder(cls, value):
if not hasattr(value, "embed_query") or not callable(
getattr(value, "embed_query", None)
):
raise ValueError(
"Provided embedder object must have an 'embed_query' callable method."
)
return value


class Neo4jDriverModel(BaseModel):
driver: neo4j.Driver
model_config = ConfigDict(arbitrary_types_allowed=True)

@field_validator("driver")
def check_driver(cls, value):
if not isinstance(value, neo4j.Driver):
raise ValueError("Provided driver needs to be of type neo4j.Driver")
return value


class VectorRetrieverModel(BaseModel):
driver_model: Neo4jDriverModel
index_name: str
embedder_model: Optional[EmbedderModel] = None
return_properties: Optional[list[str]] = None


class VectorCypherRetrieverModel(BaseModel):
driver_model: Neo4jDriverModel
index_name: str
retrieval_query: str
embedder_model: Optional[EmbedderModel] = None


class HybridRetrieverModel(BaseModel):
driver_model: Neo4jDriverModel
vector_index_name: str
fulltext_index_name: str
embedder_model: Optional[EmbedderModel] = None
return_properties: Optional[list[str]] = None


class HybridCypherRetrieverModel(BaseModel):
driver_model: Neo4jDriverModel
vector_index_name: str
fulltext_index_name: str
retrieval_query: str
embedder_model: Optional[EmbedderModel] = None
15 changes: 15 additions & 0 deletions tests/unit/retrievers/test_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,21 @@ def test_vector_cypher_retriever_initialization(driver):
mock_verify.assert_called_once()


def test_hybrid_retriever_bad_data_validation(driver):
with pytest.raises(ValueError):
HybridRetriever(driver=driver, vector_index_name=42, fulltext_index_name=42)


def test_hybrid_cypher_retriever_bad_data_validation(driver):
with pytest.raises(ValueError):
HybridCypherRetriever(
driver=driver,
vector_index_name="my-index",
fulltext_index_name="fulltext-index",
retrieval_query=42,
)


@patch("neo4j_genai.HybridRetriever._verify_version")
def test_hybrid_search_text_happy_path(_verify_version_mock, driver):
embed_query_vector = [1.0 for _ in range(1536)]
Expand Down
17 changes: 15 additions & 2 deletions tests/unit/retrievers/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from neo4j.exceptions import CypherSyntaxError

from neo4j_genai import VectorRetriever, VectorCypherRetriever
from neo4j_genai.embedder import Embedder
from neo4j_genai.neo4j_queries import get_search_query
from neo4j_genai.types import SearchType, VectorSearchRecord

Expand All @@ -28,6 +29,18 @@ def test_vector_retriever_initialization(driver):
mock_verify.assert_called_once()


def test_vector_retriever_bad_data_validation(driver):
with pytest.raises(ValueError):
VectorRetriever(driver=driver, index_name=42)


def test_vector_cypher_retriever_bad_data_validation(driver):
with pytest.raises(ValueError):
VectorCypherRetriever(
driver=driver, index_name="my-index", retrieval_query=42
)


def test_vector_cypher_retriever_initialization(driver):
with patch("neo4j_genai.retrievers.base.Retriever._verify_version") as mock_verify:
VectorCypherRetriever(driver=driver, index_name="my-index", retrieval_query="")
Expand Down Expand Up @@ -70,7 +83,7 @@ def test_similarity_search_text_happy_path(
_verify_version_mock, _fetch_index_infos, driver
):
embed_query_vector = [1.0 for _ in range(1536)]
custom_embeddings = MagicMock()
custom_embeddings = MagicMock(spec=Embedder)
custom_embeddings.embed_query.return_value = embed_query_vector
index_name = "my-index"
query_text = "may thy knife chip and shatter"
Expand Down Expand Up @@ -104,7 +117,7 @@ def test_similarity_search_text_return_properties(
_verify_version_mock, _fetch_index_infos, driver
):
embed_query_vector = [1.0 for _ in range(3)]
custom_embeddings = MagicMock()
custom_embeddings = MagicMock(spec=Embedder)
custom_embeddings.embed_query.return_value = embed_query_vector
index_name = "my-index"
query_text = "may thy knife chip and shatter"
Expand Down

0 comments on commit 987afc4

Please sign in to comment.