Skip to content

Commit

Permalink
Revert changes for vector retriever prefiltering
Browse files (browse the repository at this point in the history)
  • Loading branch information
willtai committed May 14, 2024
1 parent 987afc4 commit 99b127f
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 20 deletions.
16 changes: 4 additions & 12 deletions src/neo4j_genai/retrievers/hybrid.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -54,9 +54,7 @@ def __init__(
return_properties=return_properties,
)
except ValidationError as e:
msg = f"Validation failed: {e.errors()}"
logger.error(msg)
raise ValueError(msg)
raise ValueError(f"Validation failed: {e.errors()}")

super().__init__(validated_data.driver_model.driver)
self.vector_index_name = validated_data.vector_index_name
Expand Down Expand Up @@ -101,9 +99,7 @@ def search(
query_text=query_text,
)
except ValidationError as e:
msg = f"Validation failed: {e.errors()}"
logger.error(msg)
raise ValueError(msg)
raise ValueError(f"Validation failed: {e.errors()}")

parameters = validated_data.model_dump(exclude_none=True)

Expand Down Expand Up @@ -142,9 +138,7 @@ def __init__(
embedder_model=embedder_model,
)
except ValidationError as e:
msg = f"Validation failed: {e.errors()}"
logger.error(msg)
raise ValueError(msg)
raise ValueError(f"Validation failed: {e.errors()}")

super().__init__(validated_data.driver_model.driver)
self.vector_index_name = validated_data.vector_index_name
Expand Down Expand Up @@ -192,9 +186,7 @@ def search(
query_params=query_params,
)
except ValidationError as e:
msg = f"Validation failed: {e.errors()}"
logger.error(msg)
raise ValueError(msg)
raise ValueError(f"Validation failed: {e.errors()}")

parameters = validated_data.model_dump(exclude_none=True)

Expand Down
35 changes: 30 additions & 5 deletions src/neo4j_genai/retrievers/vector.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -43,17 +43,23 @@ def __init__(
index_name: str,
embedder: Optional[Embedder] = None,
return_properties: Optional[list[str]] = None,
filters: Optional[dict[str, Any]] = None,
) -> None:
super().__init__(driver)
self.index_name = index_name
self.return_properties = return_properties
self.embedder = embedder
self._node_label = None
self._embedding_node_property = None
self._embedding_dimension = None
self._fetch_index_infos()

def search(
self,
query_vector: Optional[list[float]] = None,
query_text: Optional[str] = None,
top_k: int = 5,
filters: Optional[dict[str, Any]] = None,
) -> list[VectorSearchRecord]:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
See the following documentation for more details:
Expand All @@ -75,7 +81,7 @@ def search(
"""
try:
validated_data = VectorSearchModel(
index_name=self.index_name,
vector_index_name=self.index_name,
top_k=top_k,
query_vector=query_vector,
query_text=query_text,
Expand All @@ -93,7 +99,15 @@ def search(
parameters["query_vector"] = query_vector
del parameters["query_text"]

search_query = get_search_query(SearchType.VECTOR, self.return_properties)
search_query, search_params = get_search_query(
SearchType.VECTOR,
self.return_properties,
node_label=self._node_label,
embedding_node_property=self._embedding_node_property,
embedding_dimension=self._embedding_dimension,
filters=filters,
)
parameters.update(search_params)

logger.debug("VectorRetriever Cypher parameters: %s", parameters)
logger.debug("VectorRetriever Cypher query: %s", search_query)
Expand Down Expand Up @@ -129,13 +143,18 @@ def __init__(
self.index_name = index_name
self.retrieval_query = retrieval_query
self.embedder = embedder
self._node_label = None
self._node_embedding_property = None
self._embedding_dimension = None
self._fetch_index_infos()

def search(
self,
query_vector: Optional[list[float]] = None,
query_text: Optional[str] = None,
top_k: int = 5,
query_params: Optional[dict[str, Any]] = None,
filters: Optional[dict[str, Any]] = None,
) -> list[neo4j.Record]:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
See the following documentation for more details:
Expand All @@ -158,7 +177,7 @@ def search(
"""
try:
validated_data = VectorCypherSearchModel(
index_name=self.index_name,
vector_index_name=self.index_name,
top_k=top_k,
query_vector=query_vector,
query_text=query_text,
Expand All @@ -181,9 +200,15 @@ def search(
parameters[key] = value
del parameters["query_params"]

search_query = get_search_query(
SearchType.VECTOR, retrieval_query=self.retrieval_query
search_query, search_params = get_search_query(
SearchType.VECTOR,
retrieval_query=self.retrieval_query,
node_label=self._node_label,
embedding_node_property=self._node_embedding_property,
embedding_dimension=self._embedding_dimension,
filters=filters,
)
parameters.update(search_params)

logger.debug("VectorCypherRetriever Cypher parameters: %s", parameters)
logger.debug("VectorCypherRetriever Cypher query: %s", search_query)
Expand Down
4 changes: 1 addition & 3 deletions tests/unit/retrievers/test_vector.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -36,9 +36,7 @@ def test_vector_retriever_bad_data_validation(driver):

def test_vector_cypher_retriever_bad_data_validation(driver):
with pytest.raises(ValueError):
VectorCypherRetriever(
driver=driver, index_name="my-index", retrieval_query=42
)
VectorCypherRetriever(driver=driver, index_name="my-index", retrieval_query=42)


def test_vector_cypher_retriever_initialization(driver):
Expand Down

0 comments on commit 99b127f

Please sign in to comment.