From e9786f46c185b91ec37347ddf15e60b0ea383a48 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 6 May 2024 17:34:50 +0200 Subject: [PATCH 01/38] Support for filters in Vector and VectorCypher retrievers --- src/neo4j_genai/neo4j_queries.py | 112 ++++++--- src/neo4j_genai/retrievers/base.py | 16 ++ src/neo4j_genai/retrievers/filters.py | 315 ++++++++++++++++++++++++++ src/neo4j_genai/retrievers/hybrid.py | 4 +- src/neo4j_genai/retrievers/vector.py | 37 ++- src/neo4j_genai/types.py | 2 +- tests/e2e/conftest.py | 7 +- tests/e2e/test_vector_e2e.py | 21 ++ tests/unit/retrievers/test_filters.py | 162 +++++++++++++ tests/unit/retrievers/test_hybrid.py | 8 +- tests/unit/retrievers/test_vector.py | 45 ++-- tests/unit/test_neo4j_queries.py | 71 ++---- 12 files changed, 676 insertions(+), 124 deletions(-) create mode 100644 src/neo4j_genai/retrievers/filters.py create mode 100644 tests/unit/retrievers/test_filters.py diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index b9ab366a..e3cf4149 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -12,47 +12,93 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Any from neo4j_genai.types import SearchType +from neo4j_genai.retrievers.filters import construct_metadata_filter + + +VECTOR_INDEX_QUERY = ( + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " + "YIELD node, score" +) + +VECTOR_EXACT_QUERY = ( + "WITH node, " + "vector.similarity.cosine(node.`{embedding_node_property}`, $query_vector) AS score " + "ORDER BY score DESC LIMIT $top_k" +) + +BASE_VECTOR_EXACT_QUERY = ( + "MATCH (node:`{node_label}`) " + "WHERE node.`{embedding_node_property}` IS NOT NULL " + "AND size(node.`{embedding_node_property}`) = toInteger($embedding_dimension)" +) + +FULL_TEXT_SEARCH_QUERY = ( + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " + "YIELD node, score" +) + + +def _get_hybrid_query() -> str: + return ( + f"CALL {{ {VECTOR_INDEX_QUERY} " + f"RETURN node, score " + f"UNION " + f"{FULL_TEXT_SEARCH_QUERY} " + f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS max " + f"UNWIND nodes AS n " + f"RETURN n.node AS node, (n.score / max) AS score }} " + f"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k" + ) + + +def _get_filtered_vector_query(filters: dict[str, Any], node_label: str, embedding_node_property: str, embedding_dimension: int) -> tuple[str, dict[str, Any]]: + where_filters, query_params = construct_metadata_filter(filters, node_alias="node") + base_query = BASE_VECTOR_EXACT_QUERY.format( + node_label=node_label, + embedding_node_property=embedding_node_property, + ) + vector_query = VECTOR_EXACT_QUERY.format( + embedding_node_property=embedding_node_property, + ) + query_params["embedding_dimension"] = embedding_dimension + return f"""{base_query} + AND ({where_filters}) + {vector_query} + """, query_params + + +def _get_vector_query(filters: dict[str, Any], node_label: str, embedding_node_property: str, embedding_dimension: int) -> tuple[str, dict[str, Any]]: + if filters: + return _get_filtered_vector_query(filters, node_label, embedding_node_property, embedding_dimension) + return VECTOR_INDEX_QUERY, {} def get_search_query( search_type: SearchType, return_properties: Optional[list[str]] = None, retrieval_query: Optional[str] = None, -): - query_map = { - SearchType.VECTOR: "".join( - [ - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) ", - "YIELD node, score ", - get_query_tail(retrieval_query, return_properties), - ] - ), - SearchType.HYBRID: "".join( - [ - "CALL { ", - "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) ", - "YIELD node, score ", - "RETURN node, score UNION ", - "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) ", - "YIELD node, score ", - "WITH collect({node:node, score:score}) AS nodes, max(score) AS max ", - "UNWIND nodes AS n ", - "RETURN n.node AS node, (n.score / max) AS score ", - "} ", - "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k ", - get_query_tail( - retrieval_query, return_properties, "RETURN node, score" - ), - ] - ), - } - return query_map[search_type] - - -def get_query_tail( + node_label: Optional[str] = None, + embedding_node_property: Optional[str] = None, + embedding_dimension: Optional[int] = None, + filters: Optional[dict[str, Any]] = None, +) -> tuple[str, dict[str, Any]]: + if search_type == SearchType.HYBRID: + if filters: + raise Exception("Filters is not supported with Hybrid Search") + query = _get_hybrid_query() + params = {} + elif search_type == SearchType.VECTOR: + query, params = _get_vector_query(filters, node_label, embedding_node_property, embedding_dimension) + else: + raise ValueError(f"Search type is not supported: {search_type}") + query_tail = _get_query_tail(retrieval_query, return_properties, fallback_return="RETURN node, score") + return " ".join([query, query_tail]), params + + +def _get_query_tail( retrieval_query: Optional[str] = None, return_properties: Optional[list[str]] = None, fallback_return: Optional[str] = None, diff --git a/src/neo4j_genai/retrievers/base.py b/src/neo4j_genai/retrievers/base.py index dc483eb6..24257429 100644 --- a/src/neo4j_genai/retrievers/base.py +++ b/src/neo4j_genai/retrievers/base.py @@ -57,3 +57,19 @@ def _verify_version(self) -> None: @abstractmethod def search(self, *args, **kwargs) -> Any: pass + + def _fetch_index_infos(self): + """Fetch the node label and embedding property from the index definition""" + query = """SHOW VECTOR INDEXES +YIELD name, labelsOrTypes, properties, options +WHERE name = $index_name +RETURN labelsOrTypes as labels, properties, options.indexConfig.`vector.dimensions` as dimensions + """ + result = self.driver.execute_query(query, {"index_name": self.index_name}) + try: + result = result.records[0] + except IndexError: + raise Exception(f"No index with name {self.index_name} found") + self._node_label = result["labels"][0] + self._embedding_node_property = result["properties"][0] + self._embedding_dimension = result["dimensions"] diff --git a/src/neo4j_genai/retrievers/filters.py b/src/neo4j_genai/retrievers/filters.py new file mode 100644 index 00000000..358a92fc --- /dev/null +++ b/src/neo4j_genai/retrievers/filters.py @@ -0,0 +1,315 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Filters format: +{"property_name": "property_value"} + + +""" +from typing import Any, Type +from collections import Counter + + +DEFAULT_NODE_ALIAS = "node" + + +class Operator: + """Operator classes are helper classes to build the Cypher queries + from a filter like {"field_name": "field_value"} + They implement two important methods: + - lhs: (left hand side): the node + property to be filtered on + + optional operations on it (see ILikeOperator for instance) + - cleaned_value: a method to make sure the provided parameter values are + consistent with the operator (e.g. LIKE operator only works with string values) + """ + CYPHER_OPERATOR = None + + def __init__(self, node_alias=DEFAULT_NODE_ALIAS): + self.node_alias = node_alias + + def lhs(self, field): + return f"{self.node_alias}.`{field}`" + + def cleaned_value(self, value): + return value + + +class EqOperator(Operator): + CYPHER_OPERATOR = "=" + + +class NeqOperator(Operator): + CYPHER_OPERATOR = "<>" + + +class LtOperator(Operator): + CYPHER_OPERATOR = "<" + + +class GtOperator(Operator): + CYPHER_OPERATOR = ">" + + +class LteOperator(Operator): + CYPHER_OPERATOR = "<=" + + +class GteOperator(Operator): + CYPHER_OPERATOR = ">=" + + +class InOperator(Operator): + CYPHER_OPERATOR = "IN" + + def cleaned_value(self, value): + for val in value: + if not isinstance(val, (str, int, float)): + raise NotImplementedError( + f"Unsupported type: {type(val)} for value: {val}" + ) + return value + + +class NinOperator(InOperator): + CYPHER_OPERATOR = "NOT IN" + + +class LikeOperator(Operator): + CYPHER_OPERATOR = "CONTAINS" + + def cleaned_value(self, value): + if not isinstance(value, str): + raise ValueError(f"Expected string value, got {type(value)}: {value}") + return value.rstrip("%") + + +class ILikeOperator(LikeOperator): + + def lhs(self, field): + return f"toLower({self.node_alias}.`{field}`)" + + def cleaned_value(self, value): + value = super().cleaned_value(value) + return value.lower() + + +OPERATOR_PREFIX = "$" + +OPERATOR_EQ = "$eq" +OPERATOR_NE = "$ne" +OPERATOR_LT = "$lt" +OPERATOR_LTE = "$lte" +OPERATOR_GT = "$gt" +OPERATOR_GTE = "$gte" +OPERATOR_BETWEEN = "$between" +OPERATOR_IN = "$in" +OPERATOR_NIN = "$nin" +OPERATOR_LIKE = "$like" +OPERATOR_ILIKE = "$ilike" + +OPERATOR_AND = "$and" +OPERATOR_OR = "$or" + +COMPARISONS_TO_NATIVE = { + OPERATOR_EQ: EqOperator, + OPERATOR_NE: NeqOperator, + OPERATOR_LT: LtOperator, + OPERATOR_LTE: LteOperator, + OPERATOR_GT: GtOperator, + OPERATOR_GTE: GteOperator, + OPERATOR_IN: InOperator, + OPERATOR_NIN: NinOperator, + OPERATOR_LIKE: LikeOperator, + OPERATOR_ILIKE: ILikeOperator, +} + + +LOGICAL_OPERATORS = {OPERATOR_AND, OPERATOR_OR} + +SUPPORTED_OPERATORS = ( + set(COMPARISONS_TO_NATIVE) + .union(LOGICAL_OPERATORS) + .union({OPERATOR_BETWEEN}) +) + + +class ParameterStore: + """ + Store parameters for a given query. + Determine the parameter name depending on a parameter counter + """ + + def __init__(self): + self._counter = Counter() + self.params = {} + + def _get_params_name(self, key="param"): + """NB: the counter parameter is there in purpose, will be modified in the function + to remember the count of each parameter + + :param p: + :param counter: + :return: + """ + # key = slugify(key.replace(".", "_"), separator="_") + param_name = f"{key}_{self._counter[key]}" + self._counter[key] += 1 + return param_name + + def add(self, key, value): + param_name = self._get_params_name() + self.params[param_name] = value + return param_name + + +def _single_condition_cypher(field: str, native_operator_class: Type[Operator], value: Any, param_store: ParameterStore, node_alias: str) -> str: + """Return Cypher for field operator value + NB: the param_store argument is mutable, it will be updated in this function + """ + native_op = native_operator_class() + param_name = param_store.add(field, native_op.cleaned_value(value)) + query_snippet = f"{native_op.lhs(field)} {native_op.CYPHER_OPERATOR} ${param_name}" + return query_snippet + + +def _handle_field_filter( + field: str, value: Any, param_store: ParameterStore, + node_alias: str = DEFAULT_NODE_ALIAS +) -> str: + """Create a filter for a specific field. + + Args: + field: name of field + value: value to filter + If provided as is then this will be an equality filter + If provided as a dictionary then this will be a filter, the key + will be the operator and the value will be the value to filter by + param_store: + node_alias: + + Returns + - Cypher filter snippet* + + NB: the param_store argument is mutable, it will be updated in this function + """ + # first, perform some sanity checks + if not isinstance(field, str): + raise ValueError( + f"Field should be a string but got: {type(field)} with value: {field}" + ) + + if field.startswith(OPERATOR_PREFIX): + raise ValueError( + f"Invalid filter condition. Expected a field but got an operator: " + f"{field}" + ) + + # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters + if not field.isidentifier(): + raise ValueError(f"Invalid field name: {field}. Expected a valid identifier.") + + if isinstance(value, dict): + # This is a filter specification e.g. {"$gte": 0} + if len(value) != 1: + raise ValueError( + "Invalid filter condition. Expected a value which " + "is a dictionary with a single key that corresponds to an operator " + f"but got a dictionary with {len(value)} keys. The first few " + f"keys are: {list(value.keys())[:3]}" + ) + operator, filter_value = list(value.items())[0] + operator = operator.lower() + # Verify that that operator is an operator + if operator not in SUPPORTED_OPERATORS: + raise ValueError( + f"Invalid operator: {operator}. " + f"Expected one of {SUPPORTED_OPERATORS}" + ) + else: # if value is not dict, then we assume an equality operator + operator = OPERATOR_EQ + filter_value = value + + # now everything is set, we can start and build the query + # special case for the BETWEEN operator that requires + # two tests (lower_bound <= value <= higher_bound) + if operator == OPERATOR_BETWEEN: + low, high = filter_value + param_name_low = param_store.add(field, low) + param_name_high = param_store.add(field, high) + query_snippet = ( + f"${param_name_low} <= {DEFAULT_NODE_ALIAS}.`{field}` <= ${param_name_high}" + ) + return query_snippet + # all the other operators are handled through their own classes: + native_op_class = COMPARISONS_TO_NATIVE[operator] + return _single_condition_cypher(field, native_op_class, filter_value, param_store, node_alias) + + +def _construct_metadata_filter(filter: dict[str, Any], param_store: ParameterStore, node_alias: str) -> str: + """Construct a metadata filter. This is a recursive function parsing the filter dict + + Args: + filter: A dictionary representing the filter condition. + param_store: A ParamStore object that will deal with parameter naming and saving along the process + node_alias: a string used as alias for the node the filters will be applied to (must come from earlier in the query) + + Returns: + str + + NB: the param_store argument is mutable, it will be updated in this function + """ + + if not isinstance(filter, dict): + raise ValueError() + # if we have more than one entry, this is an implicit "AND" filter + if len(filter) > 1: + return _construct_metadata_filter({OPERATOR_AND: [{k: v} for k, v in filter.items()]}, param_store, node_alias) + # The only operators allowed at the top level are $AND and $OR + # First check if an operator or a field + key, value = list(filter.items())[0] + if not key.startswith("$"): + # it's not an operator, must be a field + return _handle_field_filter(key, filter[key], param_store, node_alias=node_alias) + + # Here we handle the $and and $or operators + if not isinstance(value, list): + raise ValueError( + f"Expected a list, but got {type(value)} for value: {value}" + ) + if key.lower() == OPERATOR_AND: + cypher_operator = " AND " + elif key.lower() == OPERATOR_OR: + cypher_operator = " OR " + else: + raise ValueError(f"Unsupported filter {filter}") + query = cypher_operator.join( + [f"({ _construct_metadata_filter(el, param_store, node_alias)})" for el in value] + ) + return query + + +def construct_metadata_filter(filter: dict[str, Any], node_alias: str = DEFAULT_NODE_ALIAS) -> tuple[str, dict]: + """Construct the cypher filter snippet based on a filter dict + + Args: + filter: a dict of filters + node_alias: the node the filters must be applied on + + Return: + A tuple of str, dict where the string is the cypher query and the dict + contains the query parameters + """ + param_store = ParameterStore() + return _construct_metadata_filter(filter, param_store, node_alias=node_alias), param_store.params diff --git a/src/neo4j_genai/retrievers/hybrid.py b/src/neo4j_genai/retrievers/hybrid.py index a9311d06..7d98a925 100644 --- a/src/neo4j_genai/retrievers/hybrid.py +++ b/src/neo4j_genai/retrievers/hybrid.py @@ -81,7 +81,7 @@ def search( query_vector = self.embedder.embed_query(query_text) parameters["query_vector"] = query_vector - search_query = get_search_query(SearchType.HYBRID, self.return_properties) + search_query, _ = get_search_query(SearchType.HYBRID, self.return_properties) records, _, _ = self.driver.execute_query(search_query, parameters) return records @@ -154,7 +154,7 @@ def search( parameters[key] = value del parameters["query_params"] - search_query = get_search_query( + search_query, _ = get_search_query( SearchType.HYBRID, retrieval_query=self.retrieval_query ) records, _, _ = self.driver.execute_query(search_query, parameters) diff --git a/src/neo4j_genai/retrievers/vector.py b/src/neo4j_genai/retrievers/vector.py index 771ce352..b8c07032 100644 --- a/src/neo4j_genai/retrievers/vector.py +++ b/src/neo4j_genai/retrievers/vector.py @@ -45,12 +45,17 @@ def __init__( self.index_name = index_name self.return_properties = return_properties self.embedder = embedder + self._node_label = None + self._embedding_node_property = None + self._embedding_dimension = None + self._fetch_index_infos() def search( self, query_vector: Optional[list[float]] = None, query_text: Optional[str] = None, top_k: int = 5, + filters: Optional[dict[str, Any]] = None, ) -> list[VectorSearchRecord]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -72,7 +77,7 @@ def search( """ try: validated_data = VectorSearchModel( - index_name=self.index_name, + vector_index_name=self.index_name, top_k=top_k, query_vector=query_vector, query_text=query_text, @@ -90,7 +95,17 @@ def search( parameters["query_vector"] = query_vector del parameters["query_text"] - search_query = get_search_query(SearchType.VECTOR, self.return_properties) + search_query, search_params = get_search_query( + SearchType.VECTOR, + self.return_properties, + node_label=self._node_label, + embedding_node_property=self._embedding_node_property, + embedding_dimension=self._embedding_dimension, + filters=filters, + ) + parameters.update(search_params) + + print(search_query, parameters) records, _, _ = self.driver.execute_query(search_query, parameters) @@ -123,6 +138,10 @@ def __init__( self.index_name = index_name self.retrieval_query = retrieval_query self.embedder = embedder + self._node_label = None + self._node_embedding_property = None + self._embedding_dimension = None + self._fetch_index_infos() def search( self, @@ -130,6 +149,7 @@ def search( query_text: Optional[str] = None, top_k: int = 5, query_params: Optional[dict[str, Any]] = None, + filters: Optional[dict[str, Any]] = None, ) -> list[Record]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -152,7 +172,7 @@ def search( """ try: validated_data = VectorCypherSearchModel( - index_name=self.index_name, + vector_index_name=self.index_name, top_k=top_k, query_vector=query_vector, query_text=query_text, @@ -175,8 +195,15 @@ def search( parameters[key] = value del parameters["query_params"] - search_query = get_search_query( - SearchType.VECTOR, retrieval_query=self.retrieval_query + search_query, search_params = get_search_query( + SearchType.VECTOR, + retrieval_query=self.retrieval_query, + node_label=self._node_label, + embedding_node_property=self._node_embedding_property, + embedding_dimension=self._embedding_dimension, + filters=filters, ) + parameters.update(search_params) + records, _, _ = self.driver.execute_query(search_query, parameters) return records diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index 67a31175..285c00e9 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -54,7 +54,7 @@ def check_node_properties_not_empty(cls, v): class VectorSearchModel(BaseModel): - index_name: str + vector_index_name: str top_k: PositiveInt = 5 query_vector: Optional[list[float]] = None query_text: Optional[str] = None diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 64cd6504..442a51a0 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -56,7 +56,7 @@ def setup_neo4j(driver): vector_index_name, label="Document", property="propertyKey", - dimensions=1536, + dimensions=10, similarity_fn="euclidean", ) @@ -66,7 +66,7 @@ def setup_neo4j(driver): ) # Insert 10 vectors and authors - vector = [random.random() for _ in range(1536)] + vector = [random.random() for _ in range(10)] def random_str(n: int) -> str: return "".join([random.choice(string.ascii_letters) for _ in range(n)]) @@ -74,6 +74,8 @@ def random_str(n: int) -> str: for i in range(10): insert_query = ( "MERGE (doc:Document {id: $id})" + "ON CREATE SET doc.int_property = $i, " + " doc.short_text_property = toString($i)" "WITH doc " "CALL db.create.setNodeVectorProperty(doc, 'propertyKey', $vector)" "WITH doc " @@ -84,6 +86,7 @@ def random_str(n: int) -> str: parameters = { "id": str(uuid.uuid4()), + "i": i, "vector": vector, "authorName": random_str(10), } diff --git a/tests/e2e/test_vector_e2e.py b/tests/e2e/test_vector_e2e.py index 9bf3f5a4..baeae191 100644 --- a/tests/e2e/test_vector_e2e.py +++ b/tests/e2e/test_vector_e2e.py @@ -102,3 +102,24 @@ def test_vector_retriever_return_properties(driver): assert len(results) == 5 for result in results: assert isinstance(result, VectorSearchRecord) + + +@pytest.mark.usefixtures("setup_neo4j") +def test_vector_retriever_filters(driver): + retriever = VectorRetriever( + driver, + "vector-index-name", + ) + + top_k = 2 + results = retriever.search( + query_vector=[1.0 for _ in range(10)], + filters={"int_property": {"$gt": 2}}, + top_k=top_k, + ) + + assert isinstance(results, list) + assert len(results) == 2 + for result in results: + assert isinstance(result, VectorSearchRecord) + assert result.node["int_property"] > 2 diff --git a/tests/unit/retrievers/test_filters.py b/tests/unit/retrievers/test_filters.py new file mode 100644 index 00000000..536f9491 --- /dev/null +++ b/tests/unit/retrievers/test_filters.py @@ -0,0 +1,162 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from neo4j_genai.retrievers.filters import construct_metadata_filter + + +def test_filter_single_field_string(): + filters = {"field": "string_value"} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` = $param_0" + assert params == {"param_0": "string_value"} + + +def test_filter_single_field_int(): + filters = {"field": 28} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` = $param_0" + assert params == {"param_0": 28} + + +def test_filter_single_field_bool(): + filters = {"field": False} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` = $param_0" + assert params == {"param_0": False} + + +def test_filter_explicit_eq_operator(): + filters = {"field": {"$eq": "string_value"}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` = $param_0" + assert params == {"param_0": "string_value"} + + +def test_filter_neq_operator(): + filters = {"field": {"$ne": "string_value"}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` <> $param_0" + assert params == {"param_0": "string_value"} + + +def test_filter_lt_operator(): + filters = {"field": {"$lt": 1}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` < $param_0" + assert params == {"param_0": 1} + + +def test_filter_gt_operator(): + filters = {"field": {"$gt": 1}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` > $param_0" + assert params == {"param_0": 1} + + +def test_filter_lte_operator(): + filters = {"field": {"$lte": 1}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` <= $param_0" + assert params == {"param_0": 1} + + +def test_filter_gte_operator(): + filters = {"field": {"$gte": 1}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` >= $param_0" + assert params == {"param_0": 1} + + +def test_filter_in_operator(): + filters = {"field": {"$in": ["a", "b"]}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` IN $param_0" + assert params == {"param_0": ["a", "b"]} + + +def test_filter_not_in_operator(): + filters = {"field": {"$nin": ["a", "b"]}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` NOT IN $param_0" + assert params == {"param_0": ["a", "b"]} + + +def test_filter_like_operator(): + filters = {"field": {"$like": "some_value"}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` CONTAINS $param_0" + assert params == {"param_0": "some_value"} + + +def test_filter_ilike_operator(): + filters = {"field": {"$ilike": "Some Value"}} + query, params = construct_metadata_filter(filters) + assert query == "toLower(node.`field`) CONTAINS $param_0" + assert params == {"param_0": "some value"} + + +def test_filter_between_operator(): + filters = {"field": {"$between": [0, 1]}} + query, params = construct_metadata_filter(filters) + assert query == "$param_0 <= node.`field` <= $param_1" + assert params == {"param_0": 0, "param_1": 1} + + +def test_filter_implicit_and_condition(): + filters = {"field_1": "string_value", "field_2": True} + query, params = construct_metadata_filter(filters) + assert query == "(node.`field_1` = $param_0) AND (node.`field_2` = $param_1)" + assert params == {"param_0": "string_value", "param_1": True} + + +def test_filter_explicit_and_condition(): + filters = {"$and": [{"field_1": "string_value"}, {"field_2": True}]} + query, params = construct_metadata_filter(filters) + assert query == "(node.`field_1` = $param_0) AND (node.`field_2` = $param_1)" + assert params == {"param_0": "string_value", "param_1": True} + + +def test_filter_or_condition(): + filters = {"$or": [{"field_1": "string_value"}, {"field_2": True}]} + query, params = construct_metadata_filter(filters) + assert query == "(node.`field_1` = $param_0) OR (node.`field_2` = $param_1)" + assert params == {"param_0": "string_value", "param_1": True} + + +def test_filter_and_or_combined(): + filters = {"$and": [{"$or": [{"field_1": "string_value"}, {"field_2": True}]}, {"field_3": 11}]} + query, params = construct_metadata_filter(filters) + assert query == "((node.`field_1` = $param_0) OR (node.`field_2` = $param_1)) AND (node.`field_3` = $param_2)" + assert params == {"param_0": "string_value", "param_1": True, "param_2": 11} + + +# now testing bad filters +def test_field_name_with_dollar_sign(): + filters = {"$field": "value"} + with pytest.raises(ValueError): + construct_metadata_filter(filters) + + +def test_and_no_list(): + filters = {"$and": {}} + with pytest.raises(ValueError): + construct_metadata_filter(filters) + + +def test_unsupported_operator(): + filters = {"field": {"$unsupported": "value"}} + with pytest.raises(ValueError): + construct_metadata_filter(filters) diff --git a/tests/unit/retrievers/test_hybrid.py b/tests/unit/retrievers/test_hybrid.py index b55e3c54..093364a6 100644 --- a/tests/unit/retrievers/test_hybrid.py +++ b/tests/unit/retrievers/test_hybrid.py @@ -60,7 +60,7 @@ def test_hybrid_search_text_happy_path(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.HYBRID) + search_query, _ = get_search_query(SearchType.HYBRID) records = retriever.search(query_text=query_text, top_k=top_k) @@ -98,7 +98,7 @@ def test_hybrid_search_favors_query_vector_over_embedding_vector( None, None, ] - search_query = get_search_query(SearchType.HYBRID) + search_query, _ = get_search_query(SearchType.HYBRID) retriever.search(query_text=query_text, query_vector=query_vector, top_k=top_k) @@ -161,7 +161,7 @@ def test_hybrid_retriever_return_properties(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.HYBRID, return_properties) + search_query, _ = get_search_query(SearchType.HYBRID, return_properties) records = retriever.search(query_text=query_text, top_k=top_k) @@ -206,7 +206,7 @@ def test_hybrid_cypher_retrieval_query_with_params(_verify_version_mock, driver) None, None, ] - search_query = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) + search_query, _ = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) records = retriever.search( query_text=query_text, diff --git a/tests/unit/retrievers/test_vector.py b/tests/unit/retrievers/test_vector.py index 69c1f615..9be9da60 100644 --- a/tests/unit/retrievers/test_vector.py +++ b/tests/unit/retrievers/test_vector.py @@ -34,8 +34,9 @@ def test_vector_cypher_retriever_initialization(driver): mock_verify.assert_called_once() +@patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_vector_happy_path(_verify_version_mock, driver): +def test_similarity_search_vector_happy_path(_verify_version_mock, _fetch_index_infos, driver): index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] @@ -46,14 +47,14 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR) + search_query, _ = get_search_query(SearchType.VECTOR) records = retriever.search(query_vector=query_vector, top_k=top_k) retriever.driver.execute_query.assert_called_once_with( search_query, { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": query_vector, }, @@ -61,8 +62,9 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver): assert records == [VectorSearchRecord(node="dummy-node", score=1.0)] +@patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_text_happy_path(_verify_version_mock, driver): +def test_similarity_search_text_happy_path(_verify_version_mock, _fetch_index_infos, driver): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -75,7 +77,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR) + search_query, _ = get_search_query(SearchType.VECTOR) records = retriever.search(query_text=query_text, top_k=top_k) @@ -83,7 +85,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): driver.execute_query.assert_called_once_with( search_query, { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": embed_query_vector, }, @@ -92,8 +94,9 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): assert records == [VectorSearchRecord(node="dummy-node", score=1.0)] +@patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_text_return_properties(_verify_version_mock, driver): +def test_similarity_search_text_return_properties(_verify_version_mock, _fetch_index_infos, driver): embed_query_vector = [1.0 for _ in range(3)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -111,7 +114,7 @@ def test_similarity_search_text_return_properties(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR, return_properties) + search_query, _ = get_search_query(SearchType.VECTOR, return_properties) records = retriever.search(query_text=query_text, top_k=top_k) @@ -119,7 +122,7 @@ def test_similarity_search_text_return_properties(_verify_version_mock, driver): driver.execute_query.assert_called_once_with( search_query.rstrip(), { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": embed_query_vector, }, @@ -175,8 +178,9 @@ def test_vector_cypher_retriever_search_both_text_and_vector(vector_cypher_retri ) +@patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_vector_bad_results(_verify_version_mock, driver): +def test_similarity_search_vector_bad_results(_verify_version_mock, _fetch_index_infos, driver): index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] @@ -187,7 +191,7 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR) + search_query, _ = get_search_query(SearchType.VECTOR) with pytest.raises(ValueError): retriever.search(query_vector=query_vector, top_k=top_k) @@ -195,15 +199,16 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver): retriever.driver.execute_query.assert_called_once_with( search_query, { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": query_vector, }, ) +@patch("neo4j_genai.VectorCypherRetriever._fetch_index_infos") @patch("neo4j_genai.VectorCypherRetriever._verify_version") -def test_retrieval_query_happy_path(_verify_version_mock, driver): +def test_retrieval_query_happy_path(_verify_version_mock, _fetch_index_infos, driver): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -221,7 +226,7 @@ def test_retrieval_query_happy_path(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) + search_query, _ = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) records = retriever.search( query_text=query_text, @@ -232,7 +237,7 @@ def test_retrieval_query_happy_path(_verify_version_mock, driver): driver.execute_query.assert_called_once_with( search_query, { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": embed_query_vector, }, @@ -240,8 +245,9 @@ def test_retrieval_query_happy_path(_verify_version_mock, driver): assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}] +@patch("neo4j_genai.VectorCypherRetriever._fetch_index_infos") @patch("neo4j_genai.VectorCypherRetriever._verify_version") -def test_retrieval_query_with_params(_verify_version_mock, driver): +def test_retrieval_query_with_params(_verify_version_mock, _fetch_index_infos, driver): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -265,7 +271,7 @@ def test_retrieval_query_with_params(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) + search_query, _ = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) records = retriever.search( query_text=query_text, @@ -278,7 +284,7 @@ def test_retrieval_query_with_params(_verify_version_mock, driver): driver.execute_query.assert_called_once_with( search_query, { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": embed_query_vector, "param": "dummy-param", @@ -288,8 +294,9 @@ def test_retrieval_query_with_params(_verify_version_mock, driver): assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}] +@patch("neo4j_genai.VectorCypherRetriever._fetch_index_infos") @patch("neo4j_genai.VectorCypherRetriever._verify_version") -def test_retrieval_query_cypher_error(_verify_version_mock, driver): +def test_retrieval_query_cypher_error(_verify_version_mock, _fetch_index_infos, driver): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector diff --git a/tests/unit/test_neo4j_queries.py b/tests/unit/test_neo4j_queries.py index 3ce7c774..d20185b2 100644 --- a/tests/unit/test_neo4j_queries.py +++ b/tests/unit/test_neo4j_queries.py @@ -13,17 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from neo4j_genai.neo4j_queries import get_search_query, get_query_tail +from neo4j_genai.neo4j_queries import get_search_query from neo4j_genai.types import SearchType def test_vector_search_basic(): expected = ( - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " - "YIELD node, score" + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node, score" ) - result = get_search_query(SearchType.VECTOR) + result, params = get_search_query(SearchType.VECTOR) assert result.strip() == expected.strip() + assert params == {} def test_hybrid_search_basic(): @@ -41,28 +43,28 @@ def test_hybrid_search_basic(): "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " "RETURN node, score" ) - result = get_search_query(SearchType.HYBRID) + result, _ = get_search_query(SearchType.HYBRID) assert result.strip() == expected.strip() def test_vector_search_with_properties(): properties = ["name", "age"] expected = ( - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " "YIELD node, score " "RETURN node {.name, .age} as node, score" ) - result = get_search_query(SearchType.VECTOR, return_properties=properties) + result, _ = get_search_query(SearchType.VECTOR, return_properties=properties) assert result.strip() == expected.strip() def test_vector_search_with_retrieval_query(): retrieval_query = "MATCH (n) RETURN n LIMIT 10" expected = ( - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " "YIELD node, score " + retrieval_query ) - result = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) + result, _ = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) assert result.strip() == expected.strip() @@ -82,7 +84,7 @@ def test_hybrid_search_with_retrieval_query(): "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + retrieval_query ) - result = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) + result, _ = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) assert result.strip() == expected.strip() @@ -102,52 +104,5 @@ def test_hybrid_search_with_properties(): "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " "RETURN node {.name, .age} as node, score" ) - result = get_search_query(SearchType.HYBRID, return_properties=properties) - assert result.strip() == expected.strip() - - -def test_get_query_tail_with_retrieval_query(): - retrieval_query = "MATCH (n) RETURN n LIMIT 10" - expected = retrieval_query - result = get_query_tail(retrieval_query=retrieval_query) - assert result.strip() == expected.strip() - - -def test_get_query_tail_with_properties(): - properties = ["name", "age"] - expected = "RETURN node {.name, .age} as node, score" - result = get_query_tail(return_properties=properties) - assert result.strip() == expected.strip() - - -def test_get_query_tail_with_fallback(): - fallback = "HELLO" - expected = fallback - result = get_query_tail(fallback_return=fallback) - assert result.strip() == expected.strip() - - -def test_get_query_tail_ordering_all(): - retrieval_query = "MATCH (n) RETURN n LIMIT 10" - properties = ["name", "age"] - fallback = "HELLO" - - expected = retrieval_query - result = get_query_tail( - retrieval_query=retrieval_query, - return_properties=properties, - fallback_return=fallback, - ) - assert result.strip() == expected.strip() - - -def test_get_query_tail_ordering_no_retrieval_query(): - properties = ["name", "age"] - fallback = "HELLO" - - expected = "RETURN node {.name, .age} as node, score" - result = get_query_tail( - return_properties=properties, - fallback_return=fallback, - ) + result, _ = get_search_query(SearchType.HYBRID, return_properties=properties) assert result.strip() == expected.strip() From b7ef345e263777b87913970f67aa4028eb26cf82 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 6 May 2024 18:00:38 +0200 Subject: [PATCH 02/38] Ruff --- src/neo4j_genai/neo4j_queries.py | 33 +++++++++++---- src/neo4j_genai/retrievers/filters.py | 59 ++++++++++++++++----------- tests/unit/retrievers/test_filters.py | 20 ++++++--- tests/unit/retrievers/test_hybrid.py | 4 +- tests/unit/retrievers/test_vector.py | 24 ++++++++--- 5 files changed, 97 insertions(+), 43 deletions(-) diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index e3cf4149..5897fa5a 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -54,7 +54,12 @@ def _get_hybrid_query() -> str: ) -def _get_filtered_vector_query(filters: dict[str, Any], node_label: str, embedding_node_property: str, embedding_dimension: int) -> tuple[str, dict[str, Any]]: +def _get_filtered_vector_query( + filters: dict[str, Any], + node_label: str, + embedding_node_property: str, + embedding_dimension: int, +) -> tuple[str, dict[str, Any]]: where_filters, query_params = construct_metadata_filter(filters, node_alias="node") base_query = BASE_VECTOR_EXACT_QUERY.format( node_label=node_label, @@ -64,15 +69,25 @@ def _get_filtered_vector_query(filters: dict[str, Any], node_label: str, embeddi embedding_node_property=embedding_node_property, ) query_params["embedding_dimension"] = embedding_dimension - return f"""{base_query} + return ( + f"""{base_query} AND ({where_filters}) {vector_query} - """, query_params + """, + query_params, + ) -def _get_vector_query(filters: dict[str, Any], node_label: str, embedding_node_property: str, embedding_dimension: int) -> tuple[str, dict[str, Any]]: +def _get_vector_query( + filters: dict[str, Any], + node_label: str, + embedding_node_property: str, + embedding_dimension: int, +) -> tuple[str, dict[str, Any]]: if filters: - return _get_filtered_vector_query(filters, node_label, embedding_node_property, embedding_dimension) + return _get_filtered_vector_query( + filters, node_label, embedding_node_property, embedding_dimension + ) return VECTOR_INDEX_QUERY, {} @@ -91,10 +106,14 @@ def get_search_query( query = _get_hybrid_query() params = {} elif search_type == SearchType.VECTOR: - query, params = _get_vector_query(filters, node_label, embedding_node_property, embedding_dimension) + query, params = _get_vector_query( + filters, node_label, embedding_node_property, embedding_dimension + ) else: raise ValueError(f"Search type is not supported: {search_type}") - query_tail = _get_query_tail(retrieval_query, return_properties, fallback_return="RETURN node, score") + query_tail = _get_query_tail( + retrieval_query, return_properties, fallback_return="RETURN node, score" + ) return " ".join([query, query_tail]), params diff --git a/src/neo4j_genai/retrievers/filters.py b/src/neo4j_genai/retrievers/filters.py index 358a92fc..0919c237 100644 --- a/src/neo4j_genai/retrievers/filters.py +++ b/src/neo4j_genai/retrievers/filters.py @@ -12,12 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Filters format: -{"property_name": "property_value"} - - -""" from typing import Any, Type from collections import Counter @@ -34,6 +28,7 @@ class Operator: - cleaned_value: a method to make sure the provided parameter values are consistent with the operator (e.g. LIKE operator only works with string values) """ + CYPHER_OPERATOR = None def __init__(self, node_alias=DEFAULT_NODE_ALIAS): @@ -96,7 +91,6 @@ def cleaned_value(self, value): class ILikeOperator(LikeOperator): - def lhs(self, field): return f"toLower({self.node_alias}.`{field}`)" @@ -139,9 +133,7 @@ def cleaned_value(self, value): LOGICAL_OPERATORS = {OPERATOR_AND, OPERATOR_OR} SUPPORTED_OPERATORS = ( - set(COMPARISONS_TO_NATIVE) - .union(LOGICAL_OPERATORS) - .union({OPERATOR_BETWEEN}) + set(COMPARISONS_TO_NATIVE).union(LOGICAL_OPERATORS).union({OPERATOR_BETWEEN}) ) @@ -174,7 +166,13 @@ def add(self, key, value): return param_name -def _single_condition_cypher(field: str, native_operator_class: Type[Operator], value: Any, param_store: ParameterStore, node_alias: str) -> str: +def _single_condition_cypher( + field: str, + native_operator_class: Type[Operator], + value: Any, + param_store: ParameterStore, + node_alias: str, +) -> str: """Return Cypher for field operator value NB: the param_store argument is mutable, it will be updated in this function """ @@ -185,8 +183,10 @@ def _single_condition_cypher(field: str, native_operator_class: Type[Operator], def _handle_field_filter( - field: str, value: Any, param_store: ParameterStore, - node_alias: str = DEFAULT_NODE_ALIAS + field: str, + value: Any, + param_store: ParameterStore, + node_alias: str = DEFAULT_NODE_ALIAS, ) -> str: """Create a filter for a specific field. @@ -254,10 +254,14 @@ def _handle_field_filter( return query_snippet # all the other operators are handled through their own classes: native_op_class = COMPARISONS_TO_NATIVE[operator] - return _single_condition_cypher(field, native_op_class, filter_value, param_store, node_alias) + return _single_condition_cypher( + field, native_op_class, filter_value, param_store, node_alias + ) -def _construct_metadata_filter(filter: dict[str, Any], param_store: ParameterStore, node_alias: str) -> str: +def _construct_metadata_filter( + filter: dict[str, Any], param_store: ParameterStore, node_alias: str +) -> str: """Construct a metadata filter. This is a recursive function parsing the filter dict Args: @@ -275,19 +279,21 @@ def _construct_metadata_filter(filter: dict[str, Any], param_store: ParameterSto raise ValueError() # if we have more than one entry, this is an implicit "AND" filter if len(filter) > 1: - return _construct_metadata_filter({OPERATOR_AND: [{k: v} for k, v in filter.items()]}, param_store, node_alias) + return _construct_metadata_filter( + {OPERATOR_AND: [{k: v} for k, v in filter.items()]}, param_store, node_alias + ) # The only operators allowed at the top level are $AND and $OR # First check if an operator or a field key, value = list(filter.items())[0] if not key.startswith("$"): # it's not an operator, must be a field - return _handle_field_filter(key, filter[key], param_store, node_alias=node_alias) + return _handle_field_filter( + key, filter[key], param_store, node_alias=node_alias + ) # Here we handle the $and and $or operators if not isinstance(value, list): - raise ValueError( - f"Expected a list, but got {type(value)} for value: {value}" - ) + raise ValueError(f"Expected a list, but got {type(value)} for value: {value}") if key.lower() == OPERATOR_AND: cypher_operator = " AND " elif key.lower() == OPERATOR_OR: @@ -295,12 +301,17 @@ def _construct_metadata_filter(filter: dict[str, Any], param_store: ParameterSto else: raise ValueError(f"Unsupported filter {filter}") query = cypher_operator.join( - [f"({ _construct_metadata_filter(el, param_store, node_alias)})" for el in value] + [ + f"({ _construct_metadata_filter(el, param_store, node_alias)})" + for el in value + ] ) return query -def construct_metadata_filter(filter: dict[str, Any], node_alias: str = DEFAULT_NODE_ALIAS) -> tuple[str, dict]: +def construct_metadata_filter( + filter: dict[str, Any], node_alias: str = DEFAULT_NODE_ALIAS +) -> tuple[str, dict]: """Construct the cypher filter snippet based on a filter dict Args: @@ -312,4 +323,6 @@ def construct_metadata_filter(filter: dict[str, Any], node_alias: str = DEFAULT_ contains the query parameters """ param_store = ParameterStore() - return _construct_metadata_filter(filter, param_store, node_alias=node_alias), param_store.params + return _construct_metadata_filter( + filter, param_store, node_alias=node_alias + ), param_store.params diff --git a/tests/unit/retrievers/test_filters.py b/tests/unit/retrievers/test_filters.py index 536f9491..fd562118 100644 --- a/tests/unit/retrievers/test_filters.py +++ b/tests/unit/retrievers/test_filters.py @@ -137,9 +137,17 @@ def test_filter_or_condition(): def test_filter_and_or_combined(): - filters = {"$and": [{"$or": [{"field_1": "string_value"}, {"field_2": True}]}, {"field_3": 11}]} - query, params = construct_metadata_filter(filters) - assert query == "((node.`field_1` = $param_0) OR (node.`field_2` = $param_1)) AND (node.`field_3` = $param_2)" + filters = { + "$and": [ + {"$or": [{"field_1": "string_value"}, {"field_2": True}]}, + {"field_3": 11}, + ] + } + query, params = construct_metadata_filter(filters) + assert ( + query + == "((node.`field_1` = $param_0) OR (node.`field_2` = $param_1)) AND (node.`field_3` = $param_2)" + ) assert params == {"param_0": "string_value", "param_1": True, "param_2": 11} @@ -147,16 +155,16 @@ def test_filter_and_or_combined(): def test_field_name_with_dollar_sign(): filters = {"$field": "value"} with pytest.raises(ValueError): - construct_metadata_filter(filters) + construct_metadata_filter(filters) def test_and_no_list(): filters = {"$and": {}} with pytest.raises(ValueError): - construct_metadata_filter(filters) + construct_metadata_filter(filters) def test_unsupported_operator(): filters = {"field": {"$unsupported": "value"}} with pytest.raises(ValueError): - construct_metadata_filter(filters) + construct_metadata_filter(filters) diff --git a/tests/unit/retrievers/test_hybrid.py b/tests/unit/retrievers/test_hybrid.py index 093364a6..79486835 100644 --- a/tests/unit/retrievers/test_hybrid.py +++ b/tests/unit/retrievers/test_hybrid.py @@ -206,7 +206,9 @@ def test_hybrid_cypher_retrieval_query_with_params(_verify_version_mock, driver) None, None, ] - search_query, _ = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) + search_query, _ = get_search_query( + SearchType.HYBRID, retrieval_query=retrieval_query + ) records = retriever.search( query_text=query_text, diff --git a/tests/unit/retrievers/test_vector.py b/tests/unit/retrievers/test_vector.py index 9be9da60..c3fd1ade 100644 --- a/tests/unit/retrievers/test_vector.py +++ b/tests/unit/retrievers/test_vector.py @@ -36,7 +36,9 @@ def test_vector_cypher_retriever_initialization(driver): @patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_vector_happy_path(_verify_version_mock, _fetch_index_infos, driver): +def test_similarity_search_vector_happy_path( + _verify_version_mock, _fetch_index_infos, driver +): index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] @@ -64,7 +66,9 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, _fetch_index_ @patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_text_happy_path(_verify_version_mock, _fetch_index_infos, driver): +def test_similarity_search_text_happy_path( + _verify_version_mock, _fetch_index_infos, driver +): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -96,7 +100,9 @@ def test_similarity_search_text_happy_path(_verify_version_mock, _fetch_index_in @patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_text_return_properties(_verify_version_mock, _fetch_index_infos, driver): +def test_similarity_search_text_return_properties( + _verify_version_mock, _fetch_index_infos, driver +): embed_query_vector = [1.0 for _ in range(3)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -180,7 +186,9 @@ def test_vector_cypher_retriever_search_both_text_and_vector(vector_cypher_retri @patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_vector_bad_results(_verify_version_mock, _fetch_index_infos, driver): +def test_similarity_search_vector_bad_results( + _verify_version_mock, _fetch_index_infos, driver +): index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] @@ -226,7 +234,9 @@ def test_retrieval_query_happy_path(_verify_version_mock, _fetch_index_infos, dr None, None, ] - search_query, _ = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) + search_query, _ = get_search_query( + SearchType.VECTOR, retrieval_query=retrieval_query + ) records = retriever.search( query_text=query_text, @@ -271,7 +281,9 @@ def test_retrieval_query_with_params(_verify_version_mock, _fetch_index_infos, d None, None, ] - search_query, _ = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) + search_query, _ = get_search_query( + SearchType.VECTOR, retrieval_query=retrieval_query + ) records = retriever.search( query_text=query_text, From 0c918cd53e5bcc3217d712f9be1402870ab387bb Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 6 May 2024 18:04:56 +0200 Subject: [PATCH 03/38] Back to the normal dimension size in e2e tests --- tests/e2e/conftest.py | 6 +++--- tests/e2e/test_vector_e2e.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 442a51a0..3176003a 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -56,7 +56,7 @@ def setup_neo4j(driver): vector_index_name, label="Document", property="propertyKey", - dimensions=10, + dimensions=1536, similarity_fn="euclidean", ) @@ -66,7 +66,7 @@ def setup_neo4j(driver): ) # Insert 10 vectors and authors - vector = [random.random() for _ in range(10)] + vector = [random.random() for _ in range(1536)] def random_str(n: int) -> str: return "".join([random.choice(string.ascii_letters) for _ in range(n)]) @@ -88,6 +88,6 @@ def random_str(n: int) -> str: "id": str(uuid.uuid4()), "i": i, "vector": vector, - "authorName": random_str(10), + "authorName": random_str(1536), } driver.execute_query(insert_query, parameters) diff --git a/tests/e2e/test_vector_e2e.py b/tests/e2e/test_vector_e2e.py index baeae191..608dd4d0 100644 --- a/tests/e2e/test_vector_e2e.py +++ b/tests/e2e/test_vector_e2e.py @@ -113,7 +113,7 @@ def test_vector_retriever_filters(driver): top_k = 2 results = retriever.search( - query_vector=[1.0 for _ in range(10)], + query_vector=[1.0 for _ in range(1536)], filters={"int_property": {"$gt": 2}}, top_k=top_k, ) From a6cf57711c9f2e5b42b3aa7750ebe6085872fd4c Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 6 May 2024 18:21:49 +0200 Subject: [PATCH 04/38] Improved docstrings + include an example --- examples/vector_search_with_filters.py | 72 ++++++++++++++++++++++++++ src/neo4j_genai/neo4j_queries.py | 59 +++++++++++++++++++-- src/neo4j_genai/retrievers/filters.py | 37 +++++++++---- 3 files changed, 155 insertions(+), 13 deletions(-) create mode 100644 examples/vector_search_with_filters.py diff --git a/examples/vector_search_with_filters.py b/examples/vector_search_with_filters.py new file mode 100644 index 00000000..bf5fa444 --- /dev/null +++ b/examples/vector_search_with_filters.py @@ -0,0 +1,72 @@ +from neo4j import GraphDatabase +from neo4j_genai import VectorRetriever + +import random +import string +from neo4j_genai.embedder import Embedder +from neo4j_genai.indexes import create_vector_index + + +URI = "neo4j://localhost:7687" +AUTH = ("neo4j", "password") + +INDEX_NAME = "embedding-name" +DIMENSION = 1536 + +# Connect to Neo4j database +driver = GraphDatabase.driver(URI, auth=AUTH) + + +# Create Embedder object +class CustomEmbedder(Embedder): + def embed_query(self, text: str) -> list[float]: + return [random.random() for _ in range(DIMENSION)] + + +# Generate random strings +def random_str(n: int) -> str: + return "".join([random.choice(string.ascii_letters) for _ in range(n)]) + + +embedder = CustomEmbedder() + +# Creating the index +create_vector_index( + driver, + INDEX_NAME, + label="Document", + property="propertyKey", + dimensions=DIMENSION, + similarity_fn="euclidean", +) + +# Initialize the retriever +retriever = VectorRetriever(driver, INDEX_NAME, embedder) + +# Upsert the query +vector = [random.random() for _ in range(DIMENSION)] +insert_query = ( + "MERGE (doc:Document {id: $id})" + "ON CREATE SET doc.int_property = $id, " + " doc.short_text_property = toString($id)" + "WITH doc " + "CALL db.create.setNodeVectorProperty(doc, 'propertyKey', $vector)" + "WITH doc " + "MERGE (author:Author {name: $authorName})" + "MERGE (doc)-[:AUTHORED_BY]->(author)" + "RETURN doc, author" +) +parameters = { + "id": random.randint(0, 10000), + "vector": vector, + "authorName": random_str(10), +} +driver.execute_query(insert_query, parameters) + +# Perform the search +query_text = "Find me a book about Fremen" +print( + retriever.search( + query_text=query_text, top_k=1, filters={"int_property": {"$gt": 100}} + ) +) diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index 5897fa5a..52bbf332 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -60,6 +60,18 @@ def _get_filtered_vector_query( embedding_node_property: str, embedding_dimension: int, ) -> tuple[str, dict[str, Any]]: + """Build Cypher query for vector search with filters + Uses exact KNN. + + Args: + filters (dict[str, Any]): filters used to pre-filter the nodes before vector search + node_label (str): node label we want to search for + embedding_node_property (str): the name of the property holding the embeddings + embedding_dimension (int): the dimension of the embeddings + + Returns: + tuple[str, dict[str, Any]]: query and parameters + """ where_filters, query_params = construct_metadata_filter(filters, node_alias="node") base_query = BASE_VECTOR_EXACT_QUERY.format( node_label=node_label, @@ -71,19 +83,31 @@ def _get_filtered_vector_query( query_params["embedding_dimension"] = embedding_dimension return ( f"""{base_query} - AND ({where_filters}) - {vector_query} + AND ({where_filters}) + {vector_query} """, query_params, ) def _get_vector_query( - filters: dict[str, Any], + filters: Optional[dict[str, Any]], node_label: str, embedding_node_property: str, embedding_dimension: int, ) -> tuple[str, dict[str, Any]]: + """Build the vector query with or without filters + + Args: + filters (dict[str, Any]): filters used to pre-filter the nodes before vector search + node_label (str): node label we want to search for + embedding_node_property (str): the name of the property holding the embeddings + embedding_dimension (int): the dimension of the embeddings + + Returns: + tuple[str, dict[str, Any]]: query and parameters + + """ if filters: return _get_filtered_vector_query( filters, node_label, embedding_node_property, embedding_dimension @@ -100,6 +124,23 @@ def get_search_query( embedding_dimension: Optional[int] = None, filters: Optional[dict[str, Any]] = None, ) -> tuple[str, dict[str, Any]]: + """Build the search query, including pre-filtering if needed, and return clause. + + Args + search_type: Search type we want to search for: + return_properties (list[str]): list of property names to return. + It can't be provided together with retrieval_query. + retrieval_query (str): the query to use to retrieve the search results + It can't be provided together with return_properties. + node_label (str): node label we want to search for + embedding_node_property (str): the name of the property holding the embeddings + embedding_dimension (int): the dimension of the embeddings + filters (dict[str, Any]): filters used to pre-filter the nodes before vector search + + Returns: + tuple[str, dict[str, Any]]: query and parameters + + """ if search_type == SearchType.HYBRID: if filters: raise Exception("Filters is not supported with Hybrid Search") @@ -122,6 +163,18 @@ def _get_query_tail( return_properties: Optional[list[str]] = None, fallback_return: Optional[str] = None, ) -> str: + """Build the RETURN statement after the search is performed + + Args + return_properties (list[str]): list of property names to return. + It can't be provided together with retrieval_query. + retrieval_query (str): the query to use to retrieve the search results + It can't be provided together with return_properties. + fallback_return (str): the fallback return statement to use to retrieve the search results + + Returns: + str: the RETURN statement + """ if retrieval_query: return retrieval_query if return_properties: diff --git a/src/neo4j_genai/retrievers/filters.py b/src/neo4j_genai/retrievers/filters.py index 0919c237..fc052fa7 100644 --- a/src/neo4j_genai/retrievers/filters.py +++ b/src/neo4j_genai/retrievers/filters.py @@ -148,8 +148,11 @@ def __init__(self): self.params = {} def _get_params_name(self, key="param"): - """NB: the counter parameter is there in purpose, will be modified in the function - to remember the count of each parameter + """Find parameter name so that param names are unique. + This function adds a suffix to the key corresponding to the number + of times the key have been used in the query. + E.g. + node.age >= $param_0 AND node.age <= $param_1 :param p: :param counter: @@ -161,6 +164,9 @@ def _get_params_name(self, key="param"): return param_name def add(self, key, value): + """This function adds a new parameter to the param dict. + It returns the name of the parameter to be used as a placeholder + in the cypher query, e.g. $param_0""" param_name = self._get_params_name() self.params[param_name] = value return param_name @@ -173,10 +179,21 @@ def _single_condition_cypher( param_store: ParameterStore, node_alias: str, ) -> str: - """Return Cypher for field operator value + """Return Cypher for field operator value. + + Args: + field: the name of the field being filtered + native_operator_class: the operator class that will be used to generate + the Cypher query + value: filtered value + param_store: ParameterStore objet that will be updated in this function + node_alias: name of the node being filtered in the Cypher query + Returns: + str: the Cypher condition, e.g. node.`property` = $param_0 + NB: the param_store argument is mutable, it will be updated in this function """ - native_op = native_operator_class() + native_op = native_operator_class(node_alias=node_alias) param_name = param_store.add(field, native_op.cleaned_value(value)) query_snippet = f"{native_op.lhs(field)} {native_op.CYPHER_OPERATOR} ${param_name}" return query_snippet @@ -196,11 +213,11 @@ def _handle_field_filter( If provided as is then this will be an equality filter If provided as a dictionary then this will be a filter, the key will be the operator and the value will be the value to filter by - param_store: - node_alias: + param_store: ParameterStore objet that will be updated in this function + node_alias: name of the node being filtered in the Cypher query Returns - - Cypher filter snippet* + str: Cypher filter snippet NB: the param_store argument is mutable, it will be updated in this function """ @@ -266,11 +283,11 @@ def _construct_metadata_filter( Args: filter: A dictionary representing the filter condition. - param_store: A ParamStore object that will deal with parameter naming and saving along the process - node_alias: a string used as alias for the node the filters will be applied to (must come from earlier in the query) + param_store: ParameterStore objet that will be updated in this function + node_alias: name of the node being filtered in the Cypher query Returns: - str + str: the Cypher WHERE clause NB: the param_store argument is mutable, it will be updated in this function """ From 25688fe84f69468dade9ae15da08991f22b012dd Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 7 May 2024 09:20:33 +0200 Subject: [PATCH 05/38] Re-add tests for the _get_query_tail function (deleted by mistake) --- tests/unit/test_neo4j_queries.py | 49 +++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_neo4j_queries.py b/tests/unit/test_neo4j_queries.py index d20185b2..0d420c51 100644 --- a/tests/unit/test_neo4j_queries.py +++ b/tests/unit/test_neo4j_queries.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from neo4j_genai.neo4j_queries import get_search_query +from neo4j_genai.neo4j_queries import get_search_query, _get_query_tail from neo4j_genai.types import SearchType @@ -106,3 +106,50 @@ def test_hybrid_search_with_properties(): ) result, _ = get_search_query(SearchType.HYBRID, return_properties=properties) assert result.strip() == expected.strip() + + +def test_get_query_tail_with_retrieval_query(): + retrieval_query = "MATCH (n) RETURN n LIMIT 10" + expected = retrieval_query + result = _get_query_tail(retrieval_query=retrieval_query) + assert result.strip() == expected.strip() + + +def test_get_query_tail_with_properties(): + properties = ["name", "age"] + expected = "RETURN node {.name, .age} as node, score" + result = _get_query_tail(return_properties=properties) + assert result.strip() == expected.strip() + + +def test_get_query_tail_with_fallback(): + fallback = "HELLO" + expected = fallback + result = _get_query_tail(fallback_return=fallback) + assert result.strip() == expected.strip() + + +def test_get_query_tail_ordering_all(): + retrieval_query = "MATCH (n) RETURN n LIMIT 10" + properties = ["name", "age"] + fallback = "HELLO" + + expected = retrieval_query + result = _get_query_tail( + retrieval_query=retrieval_query, + return_properties=properties, + fallback_return=fallback, + ) + assert result.strip() == expected.strip() + + +def test_get_query_tail_ordering_no_retrieval_query(): + properties = ["name", "age"] + fallback = "HELLO" + + expected = "RETURN node {.name, .age} as node, score" + result = _get_query_tail( + return_properties=properties, + fallback_return=fallback, + ) + assert result.strip() == expected.strip() From 77c06794aadaad935fd952c1c8372425c3a0ad4b Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 7 May 2024 13:26:35 +0200 Subject: [PATCH 06/38] Update docstrings, move filters file, rename function --- src/neo4j_genai/{retrievers => }/filters.py | 34 ++++++++-------- src/neo4j_genai/neo4j_queries.py | 4 +- tests/unit/retrievers/test_filters.py | 44 ++++++++++----------- 3 files changed, 42 insertions(+), 40 deletions(-) rename src/neo4j_genai/{retrievers => }/filters.py (92%) diff --git a/src/neo4j_genai/retrievers/filters.py b/src/neo4j_genai/filters.py similarity index 92% rename from src/neo4j_genai/retrievers/filters.py rename to src/neo4j_genai/filters.py index fc052fa7..16699cc2 100644 --- a/src/neo4j_genai/retrievers/filters.py +++ b/src/neo4j_genai/filters.py @@ -154,9 +154,10 @@ def _get_params_name(self, key="param"): E.g. node.age >= $param_0 AND node.age <= $param_1 - :param p: - :param counter: - :return: + Args: + key (str): The prefix for the parameter name + Returns: + The full unique parameter name """ # key = slugify(key.replace(".", "_"), separator="_") param_name = f"{key}_{self._counter[key]}" @@ -182,14 +183,14 @@ def _single_condition_cypher( """Return Cypher for field operator value. Args: - field: the name of the field being filtered - native_operator_class: the operator class that will be used to generate + field: The name of the field being filtered + native_operator_class: The operator class that will be used to generate the Cypher query value: filtered value param_store: ParameterStore objet that will be updated in this function - node_alias: name of the node being filtered in the Cypher query + node_alias: Name of the node being filtered in the Cypher query Returns: - str: the Cypher condition, e.g. node.`property` = $param_0 + str: The Cypher condition, e.g. node.`property` = $param_0 NB: the param_store argument is mutable, it will be updated in this function """ @@ -208,13 +209,13 @@ def _handle_field_filter( """Create a filter for a specific field. Args: - field: name of field - value: value to filter + field: Name of field + value: Value to filter If provided as is then this will be an equality filter If provided as a dictionary then this will be a filter, the key will be the operator and the value will be the value to filter by param_store: ParameterStore objet that will be updated in this function - node_alias: name of the node being filtered in the Cypher query + node_alias: Name of the node being filtered in the Cypher query Returns str: Cypher filter snippet @@ -284,16 +285,16 @@ def _construct_metadata_filter( Args: filter: A dictionary representing the filter condition. param_store: ParameterStore objet that will be updated in this function - node_alias: name of the node being filtered in the Cypher query + node_alias: Name of the node being filtered in the Cypher query Returns: - str: the Cypher WHERE clause + str: The Cypher WHERE clause NB: the param_store argument is mutable, it will be updated in this function """ if not isinstance(filter, dict): - raise ValueError() + raise ValueError(f"Filter must be a dictionary, received {type(filter)}") # if we have more than one entry, this is an implicit "AND" filter if len(filter) > 1: return _construct_metadata_filter( @@ -326,14 +327,15 @@ def _construct_metadata_filter( return query -def construct_metadata_filter( +def get_metadata_filter( filter: dict[str, Any], node_alias: str = DEFAULT_NODE_ALIAS ) -> tuple[str, dict]: """Construct the cypher filter snippet based on a filter dict Args: - filter: a dict of filters - node_alias: the node the filters must be applied on + filter (dict): The filters to be converted to Cypher + node_alias (str): The alias of node the filters must be applied on + in the Cypher query Return: A tuple of str, dict where the string is the cypher query and the dict diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index 52bbf332..014ebb4a 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -15,7 +15,7 @@ from typing import Optional, Any from neo4j_genai.types import SearchType -from neo4j_genai.retrievers.filters import construct_metadata_filter +from neo4j_genai.filters import get_metadata_filter VECTOR_INDEX_QUERY = ( @@ -72,7 +72,7 @@ def _get_filtered_vector_query( Returns: tuple[str, dict[str, Any]]: query and parameters """ - where_filters, query_params = construct_metadata_filter(filters, node_alias="node") + where_filters, query_params = get_metadata_filter(filters, node_alias="node") base_query = BASE_VECTOR_EXACT_QUERY.format( node_label=node_label, embedding_node_property=embedding_node_property, diff --git a/tests/unit/retrievers/test_filters.py b/tests/unit/retrievers/test_filters.py index fd562118..b6eb0e63 100644 --- a/tests/unit/retrievers/test_filters.py +++ b/tests/unit/retrievers/test_filters.py @@ -14,124 +14,124 @@ # limitations under the License. import pytest -from neo4j_genai.retrievers.filters import construct_metadata_filter +from neo4j_genai.filters import get_metadata_filter def test_filter_single_field_string(): filters = {"field": "string_value"} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` = $param_0" assert params == {"param_0": "string_value"} def test_filter_single_field_int(): filters = {"field": 28} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` = $param_0" assert params == {"param_0": 28} def test_filter_single_field_bool(): filters = {"field": False} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` = $param_0" assert params == {"param_0": False} def test_filter_explicit_eq_operator(): filters = {"field": {"$eq": "string_value"}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` = $param_0" assert params == {"param_0": "string_value"} def test_filter_neq_operator(): filters = {"field": {"$ne": "string_value"}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` <> $param_0" assert params == {"param_0": "string_value"} def test_filter_lt_operator(): filters = {"field": {"$lt": 1}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` < $param_0" assert params == {"param_0": 1} def test_filter_gt_operator(): filters = {"field": {"$gt": 1}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` > $param_0" assert params == {"param_0": 1} def test_filter_lte_operator(): filters = {"field": {"$lte": 1}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` <= $param_0" assert params == {"param_0": 1} def test_filter_gte_operator(): filters = {"field": {"$gte": 1}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` >= $param_0" assert params == {"param_0": 1} def test_filter_in_operator(): filters = {"field": {"$in": ["a", "b"]}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` IN $param_0" assert params == {"param_0": ["a", "b"]} def test_filter_not_in_operator(): filters = {"field": {"$nin": ["a", "b"]}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` NOT IN $param_0" assert params == {"param_0": ["a", "b"]} def test_filter_like_operator(): filters = {"field": {"$like": "some_value"}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` CONTAINS $param_0" assert params == {"param_0": "some_value"} def test_filter_ilike_operator(): filters = {"field": {"$ilike": "Some Value"}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "toLower(node.`field`) CONTAINS $param_0" assert params == {"param_0": "some value"} def test_filter_between_operator(): filters = {"field": {"$between": [0, 1]}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "$param_0 <= node.`field` <= $param_1" assert params == {"param_0": 0, "param_1": 1} def test_filter_implicit_and_condition(): filters = {"field_1": "string_value", "field_2": True} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "(node.`field_1` = $param_0) AND (node.`field_2` = $param_1)" assert params == {"param_0": "string_value", "param_1": True} def test_filter_explicit_and_condition(): filters = {"$and": [{"field_1": "string_value"}, {"field_2": True}]} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "(node.`field_1` = $param_0) AND (node.`field_2` = $param_1)" assert params == {"param_0": "string_value", "param_1": True} def test_filter_or_condition(): filters = {"$or": [{"field_1": "string_value"}, {"field_2": True}]} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "(node.`field_1` = $param_0) OR (node.`field_2` = $param_1)" assert params == {"param_0": "string_value", "param_1": True} @@ -143,7 +143,7 @@ def test_filter_and_or_combined(): {"field_3": 11}, ] } - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert ( query == "((node.`field_1` = $param_0) OR (node.`field_2` = $param_1)) AND (node.`field_3` = $param_2)" @@ -155,16 +155,16 @@ def test_filter_and_or_combined(): def test_field_name_with_dollar_sign(): filters = {"$field": "value"} with pytest.raises(ValueError): - construct_metadata_filter(filters) + get_metadata_filter(filters) def test_and_no_list(): filters = {"$and": {}} with pytest.raises(ValueError): - construct_metadata_filter(filters) + get_metadata_filter(filters) def test_unsupported_operator(): filters = {"field": {"$unsupported": "value"}} with pytest.raises(ValueError): - construct_metadata_filter(filters) + get_metadata_filter(filters) From f2a830ea44e8251bb25f421c4ede53faba61d41c Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 6 May 2024 17:34:50 +0200 Subject: [PATCH 07/38] Support for filters in Vector and VectorCypher retrievers --- src/neo4j_genai/neo4j_queries.py | 112 ++++++--- src/neo4j_genai/retrievers/base.py | 16 ++ src/neo4j_genai/retrievers/filters.py | 315 ++++++++++++++++++++++++++ src/neo4j_genai/retrievers/hybrid.py | 4 +- src/neo4j_genai/retrievers/vector.py | 34 ++- src/neo4j_genai/types.py | 2 +- tests/e2e/conftest.py | 7 +- tests/e2e/test_vector_e2e.py | 21 ++ tests/unit/retrievers/test_filters.py | 162 +++++++++++++ tests/unit/retrievers/test_hybrid.py | 8 +- tests/unit/retrievers/test_vector.py | 45 ++-- tests/unit/test_neo4j_queries.py | 71 ++---- 12 files changed, 673 insertions(+), 124 deletions(-) create mode 100644 src/neo4j_genai/retrievers/filters.py create mode 100644 tests/unit/retrievers/test_filters.py diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index b9ab366a..e3cf4149 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -12,47 +12,93 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Any from neo4j_genai.types import SearchType +from neo4j_genai.retrievers.filters import construct_metadata_filter + + +VECTOR_INDEX_QUERY = ( + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " + "YIELD node, score" +) + +VECTOR_EXACT_QUERY = ( + "WITH node, " + "vector.similarity.cosine(node.`{embedding_node_property}`, $query_vector) AS score " + "ORDER BY score DESC LIMIT $top_k" +) + +BASE_VECTOR_EXACT_QUERY = ( + "MATCH (node:`{node_label}`) " + "WHERE node.`{embedding_node_property}` IS NOT NULL " + "AND size(node.`{embedding_node_property}`) = toInteger($embedding_dimension)" +) + +FULL_TEXT_SEARCH_QUERY = ( + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " + "YIELD node, score" +) + + +def _get_hybrid_query() -> str: + return ( + f"CALL {{ {VECTOR_INDEX_QUERY} " + f"RETURN node, score " + f"UNION " + f"{FULL_TEXT_SEARCH_QUERY} " + f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS max " + f"UNWIND nodes AS n " + f"RETURN n.node AS node, (n.score / max) AS score }} " + f"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k" + ) + + +def _get_filtered_vector_query(filters: dict[str, Any], node_label: str, embedding_node_property: str, embedding_dimension: int) -> tuple[str, dict[str, Any]]: + where_filters, query_params = construct_metadata_filter(filters, node_alias="node") + base_query = BASE_VECTOR_EXACT_QUERY.format( + node_label=node_label, + embedding_node_property=embedding_node_property, + ) + vector_query = VECTOR_EXACT_QUERY.format( + embedding_node_property=embedding_node_property, + ) + query_params["embedding_dimension"] = embedding_dimension + return f"""{base_query} + AND ({where_filters}) + {vector_query} + """, query_params + + +def _get_vector_query(filters: dict[str, Any], node_label: str, embedding_node_property: str, embedding_dimension: int) -> tuple[str, dict[str, Any]]: + if filters: + return _get_filtered_vector_query(filters, node_label, embedding_node_property, embedding_dimension) + return VECTOR_INDEX_QUERY, {} def get_search_query( search_type: SearchType, return_properties: Optional[list[str]] = None, retrieval_query: Optional[str] = None, -): - query_map = { - SearchType.VECTOR: "".join( - [ - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) ", - "YIELD node, score ", - get_query_tail(retrieval_query, return_properties), - ] - ), - SearchType.HYBRID: "".join( - [ - "CALL { ", - "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) ", - "YIELD node, score ", - "RETURN node, score UNION ", - "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) ", - "YIELD node, score ", - "WITH collect({node:node, score:score}) AS nodes, max(score) AS max ", - "UNWIND nodes AS n ", - "RETURN n.node AS node, (n.score / max) AS score ", - "} ", - "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k ", - get_query_tail( - retrieval_query, return_properties, "RETURN node, score" - ), - ] - ), - } - return query_map[search_type] - - -def get_query_tail( + node_label: Optional[str] = None, + embedding_node_property: Optional[str] = None, + embedding_dimension: Optional[int] = None, + filters: Optional[dict[str, Any]] = None, +) -> tuple[str, dict[str, Any]]: + if search_type == SearchType.HYBRID: + if filters: + raise Exception("Filters is not supported with Hybrid Search") + query = _get_hybrid_query() + params = {} + elif search_type == SearchType.VECTOR: + query, params = _get_vector_query(filters, node_label, embedding_node_property, embedding_dimension) + else: + raise ValueError(f"Search type is not supported: {search_type}") + query_tail = _get_query_tail(retrieval_query, return_properties, fallback_return="RETURN node, score") + return " ".join([query, query_tail]), params + + +def _get_query_tail( retrieval_query: Optional[str] = None, return_properties: Optional[list[str]] = None, fallback_return: Optional[str] = None, diff --git a/src/neo4j_genai/retrievers/base.py b/src/neo4j_genai/retrievers/base.py index dc483eb6..24257429 100644 --- a/src/neo4j_genai/retrievers/base.py +++ b/src/neo4j_genai/retrievers/base.py @@ -57,3 +57,19 @@ def _verify_version(self) -> None: @abstractmethod def search(self, *args, **kwargs) -> Any: pass + + def _fetch_index_infos(self): + """Fetch the node label and embedding property from the index definition""" + query = """SHOW VECTOR INDEXES +YIELD name, labelsOrTypes, properties, options +WHERE name = $index_name +RETURN labelsOrTypes as labels, properties, options.indexConfig.`vector.dimensions` as dimensions + """ + result = self.driver.execute_query(query, {"index_name": self.index_name}) + try: + result = result.records[0] + except IndexError: + raise Exception(f"No index with name {self.index_name} found") + self._node_label = result["labels"][0] + self._embedding_node_property = result["properties"][0] + self._embedding_dimension = result["dimensions"] diff --git a/src/neo4j_genai/retrievers/filters.py b/src/neo4j_genai/retrievers/filters.py new file mode 100644 index 00000000..358a92fc --- /dev/null +++ b/src/neo4j_genai/retrievers/filters.py @@ -0,0 +1,315 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Filters format: +{"property_name": "property_value"} + + +""" +from typing import Any, Type +from collections import Counter + + +DEFAULT_NODE_ALIAS = "node" + + +class Operator: + """Operator classes are helper classes to build the Cypher queries + from a filter like {"field_name": "field_value"} + They implement two important methods: + - lhs: (left hand side): the node + property to be filtered on + + optional operations on it (see ILikeOperator for instance) + - cleaned_value: a method to make sure the provided parameter values are + consistent with the operator (e.g. LIKE operator only works with string values) + """ + CYPHER_OPERATOR = None + + def __init__(self, node_alias=DEFAULT_NODE_ALIAS): + self.node_alias = node_alias + + def lhs(self, field): + return f"{self.node_alias}.`{field}`" + + def cleaned_value(self, value): + return value + + +class EqOperator(Operator): + CYPHER_OPERATOR = "=" + + +class NeqOperator(Operator): + CYPHER_OPERATOR = "<>" + + +class LtOperator(Operator): + CYPHER_OPERATOR = "<" + + +class GtOperator(Operator): + CYPHER_OPERATOR = ">" + + +class LteOperator(Operator): + CYPHER_OPERATOR = "<=" + + +class GteOperator(Operator): + CYPHER_OPERATOR = ">=" + + +class InOperator(Operator): + CYPHER_OPERATOR = "IN" + + def cleaned_value(self, value): + for val in value: + if not isinstance(val, (str, int, float)): + raise NotImplementedError( + f"Unsupported type: {type(val)} for value: {val}" + ) + return value + + +class NinOperator(InOperator): + CYPHER_OPERATOR = "NOT IN" + + +class LikeOperator(Operator): + CYPHER_OPERATOR = "CONTAINS" + + def cleaned_value(self, value): + if not isinstance(value, str): + raise ValueError(f"Expected string value, got {type(value)}: {value}") + return value.rstrip("%") + + +class ILikeOperator(LikeOperator): + + def lhs(self, field): + return f"toLower({self.node_alias}.`{field}`)" + + def cleaned_value(self, value): + value = super().cleaned_value(value) + return value.lower() + + +OPERATOR_PREFIX = "$" + +OPERATOR_EQ = "$eq" +OPERATOR_NE = "$ne" +OPERATOR_LT = "$lt" +OPERATOR_LTE = "$lte" +OPERATOR_GT = "$gt" +OPERATOR_GTE = "$gte" +OPERATOR_BETWEEN = "$between" +OPERATOR_IN = "$in" +OPERATOR_NIN = "$nin" +OPERATOR_LIKE = "$like" +OPERATOR_ILIKE = "$ilike" + +OPERATOR_AND = "$and" +OPERATOR_OR = "$or" + +COMPARISONS_TO_NATIVE = { + OPERATOR_EQ: EqOperator, + OPERATOR_NE: NeqOperator, + OPERATOR_LT: LtOperator, + OPERATOR_LTE: LteOperator, + OPERATOR_GT: GtOperator, + OPERATOR_GTE: GteOperator, + OPERATOR_IN: InOperator, + OPERATOR_NIN: NinOperator, + OPERATOR_LIKE: LikeOperator, + OPERATOR_ILIKE: ILikeOperator, +} + + +LOGICAL_OPERATORS = {OPERATOR_AND, OPERATOR_OR} + +SUPPORTED_OPERATORS = ( + set(COMPARISONS_TO_NATIVE) + .union(LOGICAL_OPERATORS) + .union({OPERATOR_BETWEEN}) +) + + +class ParameterStore: + """ + Store parameters for a given query. + Determine the parameter name depending on a parameter counter + """ + + def __init__(self): + self._counter = Counter() + self.params = {} + + def _get_params_name(self, key="param"): + """NB: the counter parameter is there in purpose, will be modified in the function + to remember the count of each parameter + + :param p: + :param counter: + :return: + """ + # key = slugify(key.replace(".", "_"), separator="_") + param_name = f"{key}_{self._counter[key]}" + self._counter[key] += 1 + return param_name + + def add(self, key, value): + param_name = self._get_params_name() + self.params[param_name] = value + return param_name + + +def _single_condition_cypher(field: str, native_operator_class: Type[Operator], value: Any, param_store: ParameterStore, node_alias: str) -> str: + """Return Cypher for field operator value + NB: the param_store argument is mutable, it will be updated in this function + """ + native_op = native_operator_class() + param_name = param_store.add(field, native_op.cleaned_value(value)) + query_snippet = f"{native_op.lhs(field)} {native_op.CYPHER_OPERATOR} ${param_name}" + return query_snippet + + +def _handle_field_filter( + field: str, value: Any, param_store: ParameterStore, + node_alias: str = DEFAULT_NODE_ALIAS +) -> str: + """Create a filter for a specific field. + + Args: + field: name of field + value: value to filter + If provided as is then this will be an equality filter + If provided as a dictionary then this will be a filter, the key + will be the operator and the value will be the value to filter by + param_store: + node_alias: + + Returns + - Cypher filter snippet* + + NB: the param_store argument is mutable, it will be updated in this function + """ + # first, perform some sanity checks + if not isinstance(field, str): + raise ValueError( + f"Field should be a string but got: {type(field)} with value: {field}" + ) + + if field.startswith(OPERATOR_PREFIX): + raise ValueError( + f"Invalid filter condition. Expected a field but got an operator: " + f"{field}" + ) + + # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters + if not field.isidentifier(): + raise ValueError(f"Invalid field name: {field}. Expected a valid identifier.") + + if isinstance(value, dict): + # This is a filter specification e.g. {"$gte": 0} + if len(value) != 1: + raise ValueError( + "Invalid filter condition. Expected a value which " + "is a dictionary with a single key that corresponds to an operator " + f"but got a dictionary with {len(value)} keys. The first few " + f"keys are: {list(value.keys())[:3]}" + ) + operator, filter_value = list(value.items())[0] + operator = operator.lower() + # Verify that that operator is an operator + if operator not in SUPPORTED_OPERATORS: + raise ValueError( + f"Invalid operator: {operator}. " + f"Expected one of {SUPPORTED_OPERATORS}" + ) + else: # if value is not dict, then we assume an equality operator + operator = OPERATOR_EQ + filter_value = value + + # now everything is set, we can start and build the query + # special case for the BETWEEN operator that requires + # two tests (lower_bound <= value <= higher_bound) + if operator == OPERATOR_BETWEEN: + low, high = filter_value + param_name_low = param_store.add(field, low) + param_name_high = param_store.add(field, high) + query_snippet = ( + f"${param_name_low} <= {DEFAULT_NODE_ALIAS}.`{field}` <= ${param_name_high}" + ) + return query_snippet + # all the other operators are handled through their own classes: + native_op_class = COMPARISONS_TO_NATIVE[operator] + return _single_condition_cypher(field, native_op_class, filter_value, param_store, node_alias) + + +def _construct_metadata_filter(filter: dict[str, Any], param_store: ParameterStore, node_alias: str) -> str: + """Construct a metadata filter. This is a recursive function parsing the filter dict + + Args: + filter: A dictionary representing the filter condition. + param_store: A ParamStore object that will deal with parameter naming and saving along the process + node_alias: a string used as alias for the node the filters will be applied to (must come from earlier in the query) + + Returns: + str + + NB: the param_store argument is mutable, it will be updated in this function + """ + + if not isinstance(filter, dict): + raise ValueError() + # if we have more than one entry, this is an implicit "AND" filter + if len(filter) > 1: + return _construct_metadata_filter({OPERATOR_AND: [{k: v} for k, v in filter.items()]}, param_store, node_alias) + # The only operators allowed at the top level are $AND and $OR + # First check if an operator or a field + key, value = list(filter.items())[0] + if not key.startswith("$"): + # it's not an operator, must be a field + return _handle_field_filter(key, filter[key], param_store, node_alias=node_alias) + + # Here we handle the $and and $or operators + if not isinstance(value, list): + raise ValueError( + f"Expected a list, but got {type(value)} for value: {value}" + ) + if key.lower() == OPERATOR_AND: + cypher_operator = " AND " + elif key.lower() == OPERATOR_OR: + cypher_operator = " OR " + else: + raise ValueError(f"Unsupported filter {filter}") + query = cypher_operator.join( + [f"({ _construct_metadata_filter(el, param_store, node_alias)})" for el in value] + ) + return query + + +def construct_metadata_filter(filter: dict[str, Any], node_alias: str = DEFAULT_NODE_ALIAS) -> tuple[str, dict]: + """Construct the cypher filter snippet based on a filter dict + + Args: + filter: a dict of filters + node_alias: the node the filters must be applied on + + Return: + A tuple of str, dict where the string is the cypher query and the dict + contains the query parameters + """ + param_store = ParameterStore() + return _construct_metadata_filter(filter, param_store, node_alias=node_alias), param_store.params diff --git a/src/neo4j_genai/retrievers/hybrid.py b/src/neo4j_genai/retrievers/hybrid.py index 0690555a..0f6255b0 100644 --- a/src/neo4j_genai/retrievers/hybrid.py +++ b/src/neo4j_genai/retrievers/hybrid.py @@ -84,7 +84,7 @@ def search( query_vector = self.embedder.embed_query(query_text) parameters["query_vector"] = query_vector - search_query = get_search_query(SearchType.HYBRID, self.return_properties) + search_query, _ = get_search_query(SearchType.HYBRID, self.return_properties) logger.debug("HybridRetriever Cypher parameters: %s", parameters) logger.debug("HybridRetriever Cypher query: %s", search_query) @@ -160,7 +160,7 @@ def search( parameters[key] = value del parameters["query_params"] - search_query = get_search_query( + search_query, _ = get_search_query( SearchType.HYBRID, retrieval_query=self.retrieval_query ) diff --git a/src/neo4j_genai/retrievers/vector.py b/src/neo4j_genai/retrievers/vector.py index 954cd04e..617ccb48 100644 --- a/src/neo4j_genai/retrievers/vector.py +++ b/src/neo4j_genai/retrievers/vector.py @@ -48,12 +48,17 @@ def __init__( self.index_name = index_name self.return_properties = return_properties self.embedder = embedder + self._node_label = None + self._embedding_node_property = None + self._embedding_dimension = None + self._fetch_index_infos() def search( self, query_vector: Optional[list[float]] = None, query_text: Optional[str] = None, top_k: int = 5, + filters: Optional[dict[str, Any]] = None, ) -> list[VectorSearchRecord]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -75,7 +80,7 @@ def search( """ try: validated_data = VectorSearchModel( - index_name=self.index_name, + vector_index_name=self.index_name, top_k=top_k, query_vector=query_vector, query_text=query_text, @@ -93,7 +98,15 @@ def search( parameters["query_vector"] = query_vector del parameters["query_text"] - search_query = get_search_query(SearchType.VECTOR, self.return_properties) + search_query, search_params = get_search_query( + SearchType.VECTOR, + self.return_properties, + node_label=self._node_label, + embedding_node_property=self._embedding_node_property, + embedding_dimension=self._embedding_dimension, + filters=filters, + ) + parameters.update(search_params) logger.debug("VectorRetriever Cypher parameters: %s", parameters) logger.debug("VectorRetriever Cypher query: %s", search_query) @@ -129,6 +142,10 @@ def __init__( self.index_name = index_name self.retrieval_query = retrieval_query self.embedder = embedder + self._node_label = None + self._node_embedding_property = None + self._embedding_dimension = None + self._fetch_index_infos() def search( self, @@ -136,6 +153,7 @@ def search( query_text: Optional[str] = None, top_k: int = 5, query_params: Optional[dict[str, Any]] = None, + filters: Optional[dict[str, Any]] = None, ) -> list[Record]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -158,7 +176,7 @@ def search( """ try: validated_data = VectorCypherSearchModel( - index_name=self.index_name, + vector_index_name=self.index_name, top_k=top_k, query_vector=query_vector, query_text=query_text, @@ -181,9 +199,15 @@ def search( parameters[key] = value del parameters["query_params"] - search_query = get_search_query( - SearchType.VECTOR, retrieval_query=self.retrieval_query + search_query, search_params = get_search_query( + SearchType.VECTOR, + retrieval_query=self.retrieval_query, + node_label=self._node_label, + embedding_node_property=self._node_embedding_property, + embedding_dimension=self._embedding_dimension, + filters=filters, ) + parameters.update(search_params) logger.debug("VectorCypherRetriever Cypher parameters: %s", parameters) logger.debug("VectorCypherRetriever Cypher query: %s", search_query) diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index 67a31175..285c00e9 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -54,7 +54,7 @@ def check_node_properties_not_empty(cls, v): class VectorSearchModel(BaseModel): - index_name: str + vector_index_name: str top_k: PositiveInt = 5 query_vector: Optional[list[float]] = None query_text: Optional[str] = None diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 64cd6504..442a51a0 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -56,7 +56,7 @@ def setup_neo4j(driver): vector_index_name, label="Document", property="propertyKey", - dimensions=1536, + dimensions=10, similarity_fn="euclidean", ) @@ -66,7 +66,7 @@ def setup_neo4j(driver): ) # Insert 10 vectors and authors - vector = [random.random() for _ in range(1536)] + vector = [random.random() for _ in range(10)] def random_str(n: int) -> str: return "".join([random.choice(string.ascii_letters) for _ in range(n)]) @@ -74,6 +74,8 @@ def random_str(n: int) -> str: for i in range(10): insert_query = ( "MERGE (doc:Document {id: $id})" + "ON CREATE SET doc.int_property = $i, " + " doc.short_text_property = toString($i)" "WITH doc " "CALL db.create.setNodeVectorProperty(doc, 'propertyKey', $vector)" "WITH doc " @@ -84,6 +86,7 @@ def random_str(n: int) -> str: parameters = { "id": str(uuid.uuid4()), + "i": i, "vector": vector, "authorName": random_str(10), } diff --git a/tests/e2e/test_vector_e2e.py b/tests/e2e/test_vector_e2e.py index 9bf3f5a4..baeae191 100644 --- a/tests/e2e/test_vector_e2e.py +++ b/tests/e2e/test_vector_e2e.py @@ -102,3 +102,24 @@ def test_vector_retriever_return_properties(driver): assert len(results) == 5 for result in results: assert isinstance(result, VectorSearchRecord) + + +@pytest.mark.usefixtures("setup_neo4j") +def test_vector_retriever_filters(driver): + retriever = VectorRetriever( + driver, + "vector-index-name", + ) + + top_k = 2 + results = retriever.search( + query_vector=[1.0 for _ in range(10)], + filters={"int_property": {"$gt": 2}}, + top_k=top_k, + ) + + assert isinstance(results, list) + assert len(results) == 2 + for result in results: + assert isinstance(result, VectorSearchRecord) + assert result.node["int_property"] > 2 diff --git a/tests/unit/retrievers/test_filters.py b/tests/unit/retrievers/test_filters.py new file mode 100644 index 00000000..536f9491 --- /dev/null +++ b/tests/unit/retrievers/test_filters.py @@ -0,0 +1,162 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from neo4j_genai.retrievers.filters import construct_metadata_filter + + +def test_filter_single_field_string(): + filters = {"field": "string_value"} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` = $param_0" + assert params == {"param_0": "string_value"} + + +def test_filter_single_field_int(): + filters = {"field": 28} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` = $param_0" + assert params == {"param_0": 28} + + +def test_filter_single_field_bool(): + filters = {"field": False} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` = $param_0" + assert params == {"param_0": False} + + +def test_filter_explicit_eq_operator(): + filters = {"field": {"$eq": "string_value"}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` = $param_0" + assert params == {"param_0": "string_value"} + + +def test_filter_neq_operator(): + filters = {"field": {"$ne": "string_value"}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` <> $param_0" + assert params == {"param_0": "string_value"} + + +def test_filter_lt_operator(): + filters = {"field": {"$lt": 1}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` < $param_0" + assert params == {"param_0": 1} + + +def test_filter_gt_operator(): + filters = {"field": {"$gt": 1}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` > $param_0" + assert params == {"param_0": 1} + + +def test_filter_lte_operator(): + filters = {"field": {"$lte": 1}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` <= $param_0" + assert params == {"param_0": 1} + + +def test_filter_gte_operator(): + filters = {"field": {"$gte": 1}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` >= $param_0" + assert params == {"param_0": 1} + + +def test_filter_in_operator(): + filters = {"field": {"$in": ["a", "b"]}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` IN $param_0" + assert params == {"param_0": ["a", "b"]} + + +def test_filter_not_in_operator(): + filters = {"field": {"$nin": ["a", "b"]}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` NOT IN $param_0" + assert params == {"param_0": ["a", "b"]} + + +def test_filter_like_operator(): + filters = {"field": {"$like": "some_value"}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` CONTAINS $param_0" + assert params == {"param_0": "some_value"} + + +def test_filter_ilike_operator(): + filters = {"field": {"$ilike": "Some Value"}} + query, params = construct_metadata_filter(filters) + assert query == "toLower(node.`field`) CONTAINS $param_0" + assert params == {"param_0": "some value"} + + +def test_filter_between_operator(): + filters = {"field": {"$between": [0, 1]}} + query, params = construct_metadata_filter(filters) + assert query == "$param_0 <= node.`field` <= $param_1" + assert params == {"param_0": 0, "param_1": 1} + + +def test_filter_implicit_and_condition(): + filters = {"field_1": "string_value", "field_2": True} + query, params = construct_metadata_filter(filters) + assert query == "(node.`field_1` = $param_0) AND (node.`field_2` = $param_1)" + assert params == {"param_0": "string_value", "param_1": True} + + +def test_filter_explicit_and_condition(): + filters = {"$and": [{"field_1": "string_value"}, {"field_2": True}]} + query, params = construct_metadata_filter(filters) + assert query == "(node.`field_1` = $param_0) AND (node.`field_2` = $param_1)" + assert params == {"param_0": "string_value", "param_1": True} + + +def test_filter_or_condition(): + filters = {"$or": [{"field_1": "string_value"}, {"field_2": True}]} + query, params = construct_metadata_filter(filters) + assert query == "(node.`field_1` = $param_0) OR (node.`field_2` = $param_1)" + assert params == {"param_0": "string_value", "param_1": True} + + +def test_filter_and_or_combined(): + filters = {"$and": [{"$or": [{"field_1": "string_value"}, {"field_2": True}]}, {"field_3": 11}]} + query, params = construct_metadata_filter(filters) + assert query == "((node.`field_1` = $param_0) OR (node.`field_2` = $param_1)) AND (node.`field_3` = $param_2)" + assert params == {"param_0": "string_value", "param_1": True, "param_2": 11} + + +# now testing bad filters +def test_field_name_with_dollar_sign(): + filters = {"$field": "value"} + with pytest.raises(ValueError): + construct_metadata_filter(filters) + + +def test_and_no_list(): + filters = {"$and": {}} + with pytest.raises(ValueError): + construct_metadata_filter(filters) + + +def test_unsupported_operator(): + filters = {"field": {"$unsupported": "value"}} + with pytest.raises(ValueError): + construct_metadata_filter(filters) diff --git a/tests/unit/retrievers/test_hybrid.py b/tests/unit/retrievers/test_hybrid.py index b55e3c54..093364a6 100644 --- a/tests/unit/retrievers/test_hybrid.py +++ b/tests/unit/retrievers/test_hybrid.py @@ -60,7 +60,7 @@ def test_hybrid_search_text_happy_path(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.HYBRID) + search_query, _ = get_search_query(SearchType.HYBRID) records = retriever.search(query_text=query_text, top_k=top_k) @@ -98,7 +98,7 @@ def test_hybrid_search_favors_query_vector_over_embedding_vector( None, None, ] - search_query = get_search_query(SearchType.HYBRID) + search_query, _ = get_search_query(SearchType.HYBRID) retriever.search(query_text=query_text, query_vector=query_vector, top_k=top_k) @@ -161,7 +161,7 @@ def test_hybrid_retriever_return_properties(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.HYBRID, return_properties) + search_query, _ = get_search_query(SearchType.HYBRID, return_properties) records = retriever.search(query_text=query_text, top_k=top_k) @@ -206,7 +206,7 @@ def test_hybrid_cypher_retrieval_query_with_params(_verify_version_mock, driver) None, None, ] - search_query = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) + search_query, _ = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) records = retriever.search( query_text=query_text, diff --git a/tests/unit/retrievers/test_vector.py b/tests/unit/retrievers/test_vector.py index 69c1f615..9be9da60 100644 --- a/tests/unit/retrievers/test_vector.py +++ b/tests/unit/retrievers/test_vector.py @@ -34,8 +34,9 @@ def test_vector_cypher_retriever_initialization(driver): mock_verify.assert_called_once() +@patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_vector_happy_path(_verify_version_mock, driver): +def test_similarity_search_vector_happy_path(_verify_version_mock, _fetch_index_infos, driver): index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] @@ -46,14 +47,14 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR) + search_query, _ = get_search_query(SearchType.VECTOR) records = retriever.search(query_vector=query_vector, top_k=top_k) retriever.driver.execute_query.assert_called_once_with( search_query, { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": query_vector, }, @@ -61,8 +62,9 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver): assert records == [VectorSearchRecord(node="dummy-node", score=1.0)] +@patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_text_happy_path(_verify_version_mock, driver): +def test_similarity_search_text_happy_path(_verify_version_mock, _fetch_index_infos, driver): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -75,7 +77,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR) + search_query, _ = get_search_query(SearchType.VECTOR) records = retriever.search(query_text=query_text, top_k=top_k) @@ -83,7 +85,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): driver.execute_query.assert_called_once_with( search_query, { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": embed_query_vector, }, @@ -92,8 +94,9 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): assert records == [VectorSearchRecord(node="dummy-node", score=1.0)] +@patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_text_return_properties(_verify_version_mock, driver): +def test_similarity_search_text_return_properties(_verify_version_mock, _fetch_index_infos, driver): embed_query_vector = [1.0 for _ in range(3)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -111,7 +114,7 @@ def test_similarity_search_text_return_properties(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR, return_properties) + search_query, _ = get_search_query(SearchType.VECTOR, return_properties) records = retriever.search(query_text=query_text, top_k=top_k) @@ -119,7 +122,7 @@ def test_similarity_search_text_return_properties(_verify_version_mock, driver): driver.execute_query.assert_called_once_with( search_query.rstrip(), { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": embed_query_vector, }, @@ -175,8 +178,9 @@ def test_vector_cypher_retriever_search_both_text_and_vector(vector_cypher_retri ) +@patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_vector_bad_results(_verify_version_mock, driver): +def test_similarity_search_vector_bad_results(_verify_version_mock, _fetch_index_infos, driver): index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] @@ -187,7 +191,7 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR) + search_query, _ = get_search_query(SearchType.VECTOR) with pytest.raises(ValueError): retriever.search(query_vector=query_vector, top_k=top_k) @@ -195,15 +199,16 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver): retriever.driver.execute_query.assert_called_once_with( search_query, { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": query_vector, }, ) +@patch("neo4j_genai.VectorCypherRetriever._fetch_index_infos") @patch("neo4j_genai.VectorCypherRetriever._verify_version") -def test_retrieval_query_happy_path(_verify_version_mock, driver): +def test_retrieval_query_happy_path(_verify_version_mock, _fetch_index_infos, driver): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -221,7 +226,7 @@ def test_retrieval_query_happy_path(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) + search_query, _ = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) records = retriever.search( query_text=query_text, @@ -232,7 +237,7 @@ def test_retrieval_query_happy_path(_verify_version_mock, driver): driver.execute_query.assert_called_once_with( search_query, { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": embed_query_vector, }, @@ -240,8 +245,9 @@ def test_retrieval_query_happy_path(_verify_version_mock, driver): assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}] +@patch("neo4j_genai.VectorCypherRetriever._fetch_index_infos") @patch("neo4j_genai.VectorCypherRetriever._verify_version") -def test_retrieval_query_with_params(_verify_version_mock, driver): +def test_retrieval_query_with_params(_verify_version_mock, _fetch_index_infos, driver): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -265,7 +271,7 @@ def test_retrieval_query_with_params(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) + search_query, _ = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) records = retriever.search( query_text=query_text, @@ -278,7 +284,7 @@ def test_retrieval_query_with_params(_verify_version_mock, driver): driver.execute_query.assert_called_once_with( search_query, { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": embed_query_vector, "param": "dummy-param", @@ -288,8 +294,9 @@ def test_retrieval_query_with_params(_verify_version_mock, driver): assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}] +@patch("neo4j_genai.VectorCypherRetriever._fetch_index_infos") @patch("neo4j_genai.VectorCypherRetriever._verify_version") -def test_retrieval_query_cypher_error(_verify_version_mock, driver): +def test_retrieval_query_cypher_error(_verify_version_mock, _fetch_index_infos, driver): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector diff --git a/tests/unit/test_neo4j_queries.py b/tests/unit/test_neo4j_queries.py index 3ce7c774..d20185b2 100644 --- a/tests/unit/test_neo4j_queries.py +++ b/tests/unit/test_neo4j_queries.py @@ -13,17 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from neo4j_genai.neo4j_queries import get_search_query, get_query_tail +from neo4j_genai.neo4j_queries import get_search_query from neo4j_genai.types import SearchType def test_vector_search_basic(): expected = ( - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " - "YIELD node, score" + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node, score" ) - result = get_search_query(SearchType.VECTOR) + result, params = get_search_query(SearchType.VECTOR) assert result.strip() == expected.strip() + assert params == {} def test_hybrid_search_basic(): @@ -41,28 +43,28 @@ def test_hybrid_search_basic(): "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " "RETURN node, score" ) - result = get_search_query(SearchType.HYBRID) + result, _ = get_search_query(SearchType.HYBRID) assert result.strip() == expected.strip() def test_vector_search_with_properties(): properties = ["name", "age"] expected = ( - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " "YIELD node, score " "RETURN node {.name, .age} as node, score" ) - result = get_search_query(SearchType.VECTOR, return_properties=properties) + result, _ = get_search_query(SearchType.VECTOR, return_properties=properties) assert result.strip() == expected.strip() def test_vector_search_with_retrieval_query(): retrieval_query = "MATCH (n) RETURN n LIMIT 10" expected = ( - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " "YIELD node, score " + retrieval_query ) - result = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) + result, _ = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) assert result.strip() == expected.strip() @@ -82,7 +84,7 @@ def test_hybrid_search_with_retrieval_query(): "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + retrieval_query ) - result = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) + result, _ = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) assert result.strip() == expected.strip() @@ -102,52 +104,5 @@ def test_hybrid_search_with_properties(): "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " "RETURN node {.name, .age} as node, score" ) - result = get_search_query(SearchType.HYBRID, return_properties=properties) - assert result.strip() == expected.strip() - - -def test_get_query_tail_with_retrieval_query(): - retrieval_query = "MATCH (n) RETURN n LIMIT 10" - expected = retrieval_query - result = get_query_tail(retrieval_query=retrieval_query) - assert result.strip() == expected.strip() - - -def test_get_query_tail_with_properties(): - properties = ["name", "age"] - expected = "RETURN node {.name, .age} as node, score" - result = get_query_tail(return_properties=properties) - assert result.strip() == expected.strip() - - -def test_get_query_tail_with_fallback(): - fallback = "HELLO" - expected = fallback - result = get_query_tail(fallback_return=fallback) - assert result.strip() == expected.strip() - - -def test_get_query_tail_ordering_all(): - retrieval_query = "MATCH (n) RETURN n LIMIT 10" - properties = ["name", "age"] - fallback = "HELLO" - - expected = retrieval_query - result = get_query_tail( - retrieval_query=retrieval_query, - return_properties=properties, - fallback_return=fallback, - ) - assert result.strip() == expected.strip() - - -def test_get_query_tail_ordering_no_retrieval_query(): - properties = ["name", "age"] - fallback = "HELLO" - - expected = "RETURN node {.name, .age} as node, score" - result = get_query_tail( - return_properties=properties, - fallback_return=fallback, - ) + result, _ = get_search_query(SearchType.HYBRID, return_properties=properties) assert result.strip() == expected.strip() From fb342a5f2e3acfa4164db17f144a8c405b3cb093 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 6 May 2024 18:00:38 +0200 Subject: [PATCH 08/38] Ruff --- src/neo4j_genai/neo4j_queries.py | 33 +++++++++++---- src/neo4j_genai/retrievers/filters.py | 59 ++++++++++++++++----------- tests/unit/retrievers/test_filters.py | 20 ++++++--- tests/unit/retrievers/test_hybrid.py | 4 +- tests/unit/retrievers/test_vector.py | 24 ++++++++--- 5 files changed, 97 insertions(+), 43 deletions(-) diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index e3cf4149..5897fa5a 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -54,7 +54,12 @@ def _get_hybrid_query() -> str: ) -def _get_filtered_vector_query(filters: dict[str, Any], node_label: str, embedding_node_property: str, embedding_dimension: int) -> tuple[str, dict[str, Any]]: +def _get_filtered_vector_query( + filters: dict[str, Any], + node_label: str, + embedding_node_property: str, + embedding_dimension: int, +) -> tuple[str, dict[str, Any]]: where_filters, query_params = construct_metadata_filter(filters, node_alias="node") base_query = BASE_VECTOR_EXACT_QUERY.format( node_label=node_label, @@ -64,15 +69,25 @@ def _get_filtered_vector_query(filters: dict[str, Any], node_label: str, embeddi embedding_node_property=embedding_node_property, ) query_params["embedding_dimension"] = embedding_dimension - return f"""{base_query} + return ( + f"""{base_query} AND ({where_filters}) {vector_query} - """, query_params + """, + query_params, + ) -def _get_vector_query(filters: dict[str, Any], node_label: str, embedding_node_property: str, embedding_dimension: int) -> tuple[str, dict[str, Any]]: +def _get_vector_query( + filters: dict[str, Any], + node_label: str, + embedding_node_property: str, + embedding_dimension: int, +) -> tuple[str, dict[str, Any]]: if filters: - return _get_filtered_vector_query(filters, node_label, embedding_node_property, embedding_dimension) + return _get_filtered_vector_query( + filters, node_label, embedding_node_property, embedding_dimension + ) return VECTOR_INDEX_QUERY, {} @@ -91,10 +106,14 @@ def get_search_query( query = _get_hybrid_query() params = {} elif search_type == SearchType.VECTOR: - query, params = _get_vector_query(filters, node_label, embedding_node_property, embedding_dimension) + query, params = _get_vector_query( + filters, node_label, embedding_node_property, embedding_dimension + ) else: raise ValueError(f"Search type is not supported: {search_type}") - query_tail = _get_query_tail(retrieval_query, return_properties, fallback_return="RETURN node, score") + query_tail = _get_query_tail( + retrieval_query, return_properties, fallback_return="RETURN node, score" + ) return " ".join([query, query_tail]), params diff --git a/src/neo4j_genai/retrievers/filters.py b/src/neo4j_genai/retrievers/filters.py index 358a92fc..0919c237 100644 --- a/src/neo4j_genai/retrievers/filters.py +++ b/src/neo4j_genai/retrievers/filters.py @@ -12,12 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Filters format: -{"property_name": "property_value"} - - -""" from typing import Any, Type from collections import Counter @@ -34,6 +28,7 @@ class Operator: - cleaned_value: a method to make sure the provided parameter values are consistent with the operator (e.g. LIKE operator only works with string values) """ + CYPHER_OPERATOR = None def __init__(self, node_alias=DEFAULT_NODE_ALIAS): @@ -96,7 +91,6 @@ def cleaned_value(self, value): class ILikeOperator(LikeOperator): - def lhs(self, field): return f"toLower({self.node_alias}.`{field}`)" @@ -139,9 +133,7 @@ def cleaned_value(self, value): LOGICAL_OPERATORS = {OPERATOR_AND, OPERATOR_OR} SUPPORTED_OPERATORS = ( - set(COMPARISONS_TO_NATIVE) - .union(LOGICAL_OPERATORS) - .union({OPERATOR_BETWEEN}) + set(COMPARISONS_TO_NATIVE).union(LOGICAL_OPERATORS).union({OPERATOR_BETWEEN}) ) @@ -174,7 +166,13 @@ def add(self, key, value): return param_name -def _single_condition_cypher(field: str, native_operator_class: Type[Operator], value: Any, param_store: ParameterStore, node_alias: str) -> str: +def _single_condition_cypher( + field: str, + native_operator_class: Type[Operator], + value: Any, + param_store: ParameterStore, + node_alias: str, +) -> str: """Return Cypher for field operator value NB: the param_store argument is mutable, it will be updated in this function """ @@ -185,8 +183,10 @@ def _single_condition_cypher(field: str, native_operator_class: Type[Operator], def _handle_field_filter( - field: str, value: Any, param_store: ParameterStore, - node_alias: str = DEFAULT_NODE_ALIAS + field: str, + value: Any, + param_store: ParameterStore, + node_alias: str = DEFAULT_NODE_ALIAS, ) -> str: """Create a filter for a specific field. @@ -254,10 +254,14 @@ def _handle_field_filter( return query_snippet # all the other operators are handled through their own classes: native_op_class = COMPARISONS_TO_NATIVE[operator] - return _single_condition_cypher(field, native_op_class, filter_value, param_store, node_alias) + return _single_condition_cypher( + field, native_op_class, filter_value, param_store, node_alias + ) -def _construct_metadata_filter(filter: dict[str, Any], param_store: ParameterStore, node_alias: str) -> str: +def _construct_metadata_filter( + filter: dict[str, Any], param_store: ParameterStore, node_alias: str +) -> str: """Construct a metadata filter. This is a recursive function parsing the filter dict Args: @@ -275,19 +279,21 @@ def _construct_metadata_filter(filter: dict[str, Any], param_store: ParameterSto raise ValueError() # if we have more than one entry, this is an implicit "AND" filter if len(filter) > 1: - return _construct_metadata_filter({OPERATOR_AND: [{k: v} for k, v in filter.items()]}, param_store, node_alias) + return _construct_metadata_filter( + {OPERATOR_AND: [{k: v} for k, v in filter.items()]}, param_store, node_alias + ) # The only operators allowed at the top level are $AND and $OR # First check if an operator or a field key, value = list(filter.items())[0] if not key.startswith("$"): # it's not an operator, must be a field - return _handle_field_filter(key, filter[key], param_store, node_alias=node_alias) + return _handle_field_filter( + key, filter[key], param_store, node_alias=node_alias + ) # Here we handle the $and and $or operators if not isinstance(value, list): - raise ValueError( - f"Expected a list, but got {type(value)} for value: {value}" - ) + raise ValueError(f"Expected a list, but got {type(value)} for value: {value}") if key.lower() == OPERATOR_AND: cypher_operator = " AND " elif key.lower() == OPERATOR_OR: @@ -295,12 +301,17 @@ def _construct_metadata_filter(filter: dict[str, Any], param_store: ParameterSto else: raise ValueError(f"Unsupported filter {filter}") query = cypher_operator.join( - [f"({ _construct_metadata_filter(el, param_store, node_alias)})" for el in value] + [ + f"({ _construct_metadata_filter(el, param_store, node_alias)})" + for el in value + ] ) return query -def construct_metadata_filter(filter: dict[str, Any], node_alias: str = DEFAULT_NODE_ALIAS) -> tuple[str, dict]: +def construct_metadata_filter( + filter: dict[str, Any], node_alias: str = DEFAULT_NODE_ALIAS +) -> tuple[str, dict]: """Construct the cypher filter snippet based on a filter dict Args: @@ -312,4 +323,6 @@ def construct_metadata_filter(filter: dict[str, Any], node_alias: str = DEFAULT_ contains the query parameters """ param_store = ParameterStore() - return _construct_metadata_filter(filter, param_store, node_alias=node_alias), param_store.params + return _construct_metadata_filter( + filter, param_store, node_alias=node_alias + ), param_store.params diff --git a/tests/unit/retrievers/test_filters.py b/tests/unit/retrievers/test_filters.py index 536f9491..fd562118 100644 --- a/tests/unit/retrievers/test_filters.py +++ b/tests/unit/retrievers/test_filters.py @@ -137,9 +137,17 @@ def test_filter_or_condition(): def test_filter_and_or_combined(): - filters = {"$and": [{"$or": [{"field_1": "string_value"}, {"field_2": True}]}, {"field_3": 11}]} - query, params = construct_metadata_filter(filters) - assert query == "((node.`field_1` = $param_0) OR (node.`field_2` = $param_1)) AND (node.`field_3` = $param_2)" + filters = { + "$and": [ + {"$or": [{"field_1": "string_value"}, {"field_2": True}]}, + {"field_3": 11}, + ] + } + query, params = construct_metadata_filter(filters) + assert ( + query + == "((node.`field_1` = $param_0) OR (node.`field_2` = $param_1)) AND (node.`field_3` = $param_2)" + ) assert params == {"param_0": "string_value", "param_1": True, "param_2": 11} @@ -147,16 +155,16 @@ def test_filter_and_or_combined(): def test_field_name_with_dollar_sign(): filters = {"$field": "value"} with pytest.raises(ValueError): - construct_metadata_filter(filters) + construct_metadata_filter(filters) def test_and_no_list(): filters = {"$and": {}} with pytest.raises(ValueError): - construct_metadata_filter(filters) + construct_metadata_filter(filters) def test_unsupported_operator(): filters = {"field": {"$unsupported": "value"}} with pytest.raises(ValueError): - construct_metadata_filter(filters) + construct_metadata_filter(filters) diff --git a/tests/unit/retrievers/test_hybrid.py b/tests/unit/retrievers/test_hybrid.py index 093364a6..79486835 100644 --- a/tests/unit/retrievers/test_hybrid.py +++ b/tests/unit/retrievers/test_hybrid.py @@ -206,7 +206,9 @@ def test_hybrid_cypher_retrieval_query_with_params(_verify_version_mock, driver) None, None, ] - search_query, _ = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) + search_query, _ = get_search_query( + SearchType.HYBRID, retrieval_query=retrieval_query + ) records = retriever.search( query_text=query_text, diff --git a/tests/unit/retrievers/test_vector.py b/tests/unit/retrievers/test_vector.py index 9be9da60..c3fd1ade 100644 --- a/tests/unit/retrievers/test_vector.py +++ b/tests/unit/retrievers/test_vector.py @@ -36,7 +36,9 @@ def test_vector_cypher_retriever_initialization(driver): @patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_vector_happy_path(_verify_version_mock, _fetch_index_infos, driver): +def test_similarity_search_vector_happy_path( + _verify_version_mock, _fetch_index_infos, driver +): index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] @@ -64,7 +66,9 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, _fetch_index_ @patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_text_happy_path(_verify_version_mock, _fetch_index_infos, driver): +def test_similarity_search_text_happy_path( + _verify_version_mock, _fetch_index_infos, driver +): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -96,7 +100,9 @@ def test_similarity_search_text_happy_path(_verify_version_mock, _fetch_index_in @patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_text_return_properties(_verify_version_mock, _fetch_index_infos, driver): +def test_similarity_search_text_return_properties( + _verify_version_mock, _fetch_index_infos, driver +): embed_query_vector = [1.0 for _ in range(3)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -180,7 +186,9 @@ def test_vector_cypher_retriever_search_both_text_and_vector(vector_cypher_retri @patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_vector_bad_results(_verify_version_mock, _fetch_index_infos, driver): +def test_similarity_search_vector_bad_results( + _verify_version_mock, _fetch_index_infos, driver +): index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] @@ -226,7 +234,9 @@ def test_retrieval_query_happy_path(_verify_version_mock, _fetch_index_infos, dr None, None, ] - search_query, _ = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) + search_query, _ = get_search_query( + SearchType.VECTOR, retrieval_query=retrieval_query + ) records = retriever.search( query_text=query_text, @@ -271,7 +281,9 @@ def test_retrieval_query_with_params(_verify_version_mock, _fetch_index_infos, d None, None, ] - search_query, _ = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) + search_query, _ = get_search_query( + SearchType.VECTOR, retrieval_query=retrieval_query + ) records = retriever.search( query_text=query_text, From f3c0daba5c877cdaa2740525123bed4de3764a5b Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 6 May 2024 18:04:56 +0200 Subject: [PATCH 09/38] Back to the normal dimension size in e2e tests --- tests/e2e/conftest.py | 6 +++--- tests/e2e/test_vector_e2e.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 442a51a0..3176003a 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -56,7 +56,7 @@ def setup_neo4j(driver): vector_index_name, label="Document", property="propertyKey", - dimensions=10, + dimensions=1536, similarity_fn="euclidean", ) @@ -66,7 +66,7 @@ def setup_neo4j(driver): ) # Insert 10 vectors and authors - vector = [random.random() for _ in range(10)] + vector = [random.random() for _ in range(1536)] def random_str(n: int) -> str: return "".join([random.choice(string.ascii_letters) for _ in range(n)]) @@ -88,6 +88,6 @@ def random_str(n: int) -> str: "id": str(uuid.uuid4()), "i": i, "vector": vector, - "authorName": random_str(10), + "authorName": random_str(1536), } driver.execute_query(insert_query, parameters) diff --git a/tests/e2e/test_vector_e2e.py b/tests/e2e/test_vector_e2e.py index baeae191..608dd4d0 100644 --- a/tests/e2e/test_vector_e2e.py +++ b/tests/e2e/test_vector_e2e.py @@ -113,7 +113,7 @@ def test_vector_retriever_filters(driver): top_k = 2 results = retriever.search( - query_vector=[1.0 for _ in range(10)], + query_vector=[1.0 for _ in range(1536)], filters={"int_property": {"$gt": 2}}, top_k=top_k, ) From bea058dee42b5a9fbf562c45906d81d56f63d568 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 6 May 2024 18:21:49 +0200 Subject: [PATCH 10/38] Improved docstrings + include an example --- examples/vector_search_with_filters.py | 72 ++++++++++++++++++++++++++ src/neo4j_genai/neo4j_queries.py | 59 +++++++++++++++++++-- src/neo4j_genai/retrievers/filters.py | 37 +++++++++---- 3 files changed, 155 insertions(+), 13 deletions(-) create mode 100644 examples/vector_search_with_filters.py diff --git a/examples/vector_search_with_filters.py b/examples/vector_search_with_filters.py new file mode 100644 index 00000000..bf5fa444 --- /dev/null +++ b/examples/vector_search_with_filters.py @@ -0,0 +1,72 @@ +from neo4j import GraphDatabase +from neo4j_genai import VectorRetriever + +import random +import string +from neo4j_genai.embedder import Embedder +from neo4j_genai.indexes import create_vector_index + + +URI = "neo4j://localhost:7687" +AUTH = ("neo4j", "password") + +INDEX_NAME = "embedding-name" +DIMENSION = 1536 + +# Connect to Neo4j database +driver = GraphDatabase.driver(URI, auth=AUTH) + + +# Create Embedder object +class CustomEmbedder(Embedder): + def embed_query(self, text: str) -> list[float]: + return [random.random() for _ in range(DIMENSION)] + + +# Generate random strings +def random_str(n: int) -> str: + return "".join([random.choice(string.ascii_letters) for _ in range(n)]) + + +embedder = CustomEmbedder() + +# Creating the index +create_vector_index( + driver, + INDEX_NAME, + label="Document", + property="propertyKey", + dimensions=DIMENSION, + similarity_fn="euclidean", +) + +# Initialize the retriever +retriever = VectorRetriever(driver, INDEX_NAME, embedder) + +# Upsert the query +vector = [random.random() for _ in range(DIMENSION)] +insert_query = ( + "MERGE (doc:Document {id: $id})" + "ON CREATE SET doc.int_property = $id, " + " doc.short_text_property = toString($id)" + "WITH doc " + "CALL db.create.setNodeVectorProperty(doc, 'propertyKey', $vector)" + "WITH doc " + "MERGE (author:Author {name: $authorName})" + "MERGE (doc)-[:AUTHORED_BY]->(author)" + "RETURN doc, author" +) +parameters = { + "id": random.randint(0, 10000), + "vector": vector, + "authorName": random_str(10), +} +driver.execute_query(insert_query, parameters) + +# Perform the search +query_text = "Find me a book about Fremen" +print( + retriever.search( + query_text=query_text, top_k=1, filters={"int_property": {"$gt": 100}} + ) +) diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index 5897fa5a..52bbf332 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -60,6 +60,18 @@ def _get_filtered_vector_query( embedding_node_property: str, embedding_dimension: int, ) -> tuple[str, dict[str, Any]]: + """Build Cypher query for vector search with filters + Uses exact KNN. + + Args: + filters (dict[str, Any]): filters used to pre-filter the nodes before vector search + node_label (str): node label we want to search for + embedding_node_property (str): the name of the property holding the embeddings + embedding_dimension (int): the dimension of the embeddings + + Returns: + tuple[str, dict[str, Any]]: query and parameters + """ where_filters, query_params = construct_metadata_filter(filters, node_alias="node") base_query = BASE_VECTOR_EXACT_QUERY.format( node_label=node_label, @@ -71,19 +83,31 @@ def _get_filtered_vector_query( query_params["embedding_dimension"] = embedding_dimension return ( f"""{base_query} - AND ({where_filters}) - {vector_query} + AND ({where_filters}) + {vector_query} """, query_params, ) def _get_vector_query( - filters: dict[str, Any], + filters: Optional[dict[str, Any]], node_label: str, embedding_node_property: str, embedding_dimension: int, ) -> tuple[str, dict[str, Any]]: + """Build the vector query with or without filters + + Args: + filters (dict[str, Any]): filters used to pre-filter the nodes before vector search + node_label (str): node label we want to search for + embedding_node_property (str): the name of the property holding the embeddings + embedding_dimension (int): the dimension of the embeddings + + Returns: + tuple[str, dict[str, Any]]: query and parameters + + """ if filters: return _get_filtered_vector_query( filters, node_label, embedding_node_property, embedding_dimension @@ -100,6 +124,23 @@ def get_search_query( embedding_dimension: Optional[int] = None, filters: Optional[dict[str, Any]] = None, ) -> tuple[str, dict[str, Any]]: + """Build the search query, including pre-filtering if needed, and return clause. + + Args + search_type: Search type we want to search for: + return_properties (list[str]): list of property names to return. + It can't be provided together with retrieval_query. + retrieval_query (str): the query to use to retrieve the search results + It can't be provided together with return_properties. + node_label (str): node label we want to search for + embedding_node_property (str): the name of the property holding the embeddings + embedding_dimension (int): the dimension of the embeddings + filters (dict[str, Any]): filters used to pre-filter the nodes before vector search + + Returns: + tuple[str, dict[str, Any]]: query and parameters + + """ if search_type == SearchType.HYBRID: if filters: raise Exception("Filters is not supported with Hybrid Search") @@ -122,6 +163,18 @@ def _get_query_tail( return_properties: Optional[list[str]] = None, fallback_return: Optional[str] = None, ) -> str: + """Build the RETURN statement after the search is performed + + Args + return_properties (list[str]): list of property names to return. + It can't be provided together with retrieval_query. + retrieval_query (str): the query to use to retrieve the search results + It can't be provided together with return_properties. + fallback_return (str): the fallback return statement to use to retrieve the search results + + Returns: + str: the RETURN statement + """ if retrieval_query: return retrieval_query if return_properties: diff --git a/src/neo4j_genai/retrievers/filters.py b/src/neo4j_genai/retrievers/filters.py index 0919c237..fc052fa7 100644 --- a/src/neo4j_genai/retrievers/filters.py +++ b/src/neo4j_genai/retrievers/filters.py @@ -148,8 +148,11 @@ def __init__(self): self.params = {} def _get_params_name(self, key="param"): - """NB: the counter parameter is there in purpose, will be modified in the function - to remember the count of each parameter + """Find parameter name so that param names are unique. + This function adds a suffix to the key corresponding to the number + of times the key have been used in the query. + E.g. + node.age >= $param_0 AND node.age <= $param_1 :param p: :param counter: @@ -161,6 +164,9 @@ def _get_params_name(self, key="param"): return param_name def add(self, key, value): + """This function adds a new parameter to the param dict. + It returns the name of the parameter to be used as a placeholder + in the cypher query, e.g. $param_0""" param_name = self._get_params_name() self.params[param_name] = value return param_name @@ -173,10 +179,21 @@ def _single_condition_cypher( param_store: ParameterStore, node_alias: str, ) -> str: - """Return Cypher for field operator value + """Return Cypher for field operator value. + + Args: + field: the name of the field being filtered + native_operator_class: the operator class that will be used to generate + the Cypher query + value: filtered value + param_store: ParameterStore objet that will be updated in this function + node_alias: name of the node being filtered in the Cypher query + Returns: + str: the Cypher condition, e.g. node.`property` = $param_0 + NB: the param_store argument is mutable, it will be updated in this function """ - native_op = native_operator_class() + native_op = native_operator_class(node_alias=node_alias) param_name = param_store.add(field, native_op.cleaned_value(value)) query_snippet = f"{native_op.lhs(field)} {native_op.CYPHER_OPERATOR} ${param_name}" return query_snippet @@ -196,11 +213,11 @@ def _handle_field_filter( If provided as is then this will be an equality filter If provided as a dictionary then this will be a filter, the key will be the operator and the value will be the value to filter by - param_store: - node_alias: + param_store: ParameterStore objet that will be updated in this function + node_alias: name of the node being filtered in the Cypher query Returns - - Cypher filter snippet* + str: Cypher filter snippet NB: the param_store argument is mutable, it will be updated in this function """ @@ -266,11 +283,11 @@ def _construct_metadata_filter( Args: filter: A dictionary representing the filter condition. - param_store: A ParamStore object that will deal with parameter naming and saving along the process - node_alias: a string used as alias for the node the filters will be applied to (must come from earlier in the query) + param_store: ParameterStore objet that will be updated in this function + node_alias: name of the node being filtered in the Cypher query Returns: - str + str: the Cypher WHERE clause NB: the param_store argument is mutable, it will be updated in this function """ From 90e319151d63cee6a97ab21d242e3a930d591892 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 7 May 2024 09:20:33 +0200 Subject: [PATCH 11/38] Re-add tests for the _get_query_tail function (deleted by mistake) --- tests/unit/test_neo4j_queries.py | 49 +++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_neo4j_queries.py b/tests/unit/test_neo4j_queries.py index d20185b2..0d420c51 100644 --- a/tests/unit/test_neo4j_queries.py +++ b/tests/unit/test_neo4j_queries.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from neo4j_genai.neo4j_queries import get_search_query +from neo4j_genai.neo4j_queries import get_search_query, _get_query_tail from neo4j_genai.types import SearchType @@ -106,3 +106,50 @@ def test_hybrid_search_with_properties(): ) result, _ = get_search_query(SearchType.HYBRID, return_properties=properties) assert result.strip() == expected.strip() + + +def test_get_query_tail_with_retrieval_query(): + retrieval_query = "MATCH (n) RETURN n LIMIT 10" + expected = retrieval_query + result = _get_query_tail(retrieval_query=retrieval_query) + assert result.strip() == expected.strip() + + +def test_get_query_tail_with_properties(): + properties = ["name", "age"] + expected = "RETURN node {.name, .age} as node, score" + result = _get_query_tail(return_properties=properties) + assert result.strip() == expected.strip() + + +def test_get_query_tail_with_fallback(): + fallback = "HELLO" + expected = fallback + result = _get_query_tail(fallback_return=fallback) + assert result.strip() == expected.strip() + + +def test_get_query_tail_ordering_all(): + retrieval_query = "MATCH (n) RETURN n LIMIT 10" + properties = ["name", "age"] + fallback = "HELLO" + + expected = retrieval_query + result = _get_query_tail( + retrieval_query=retrieval_query, + return_properties=properties, + fallback_return=fallback, + ) + assert result.strip() == expected.strip() + + +def test_get_query_tail_ordering_no_retrieval_query(): + properties = ["name", "age"] + fallback = "HELLO" + + expected = "RETURN node {.name, .age} as node, score" + result = _get_query_tail( + return_properties=properties, + fallback_return=fallback, + ) + assert result.strip() == expected.strip() From d52da87f663e10edb7930b1e52b6fceece4bc048 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 7 May 2024 13:26:35 +0200 Subject: [PATCH 12/38] Update docstrings, move filters file, rename function --- src/neo4j_genai/{retrievers => }/filters.py | 34 ++++++++-------- src/neo4j_genai/neo4j_queries.py | 4 +- tests/unit/retrievers/test_filters.py | 44 ++++++++++----------- 3 files changed, 42 insertions(+), 40 deletions(-) rename src/neo4j_genai/{retrievers => }/filters.py (92%) diff --git a/src/neo4j_genai/retrievers/filters.py b/src/neo4j_genai/filters.py similarity index 92% rename from src/neo4j_genai/retrievers/filters.py rename to src/neo4j_genai/filters.py index fc052fa7..16699cc2 100644 --- a/src/neo4j_genai/retrievers/filters.py +++ b/src/neo4j_genai/filters.py @@ -154,9 +154,10 @@ def _get_params_name(self, key="param"): E.g. node.age >= $param_0 AND node.age <= $param_1 - :param p: - :param counter: - :return: + Args: + key (str): The prefix for the parameter name + Returns: + The full unique parameter name """ # key = slugify(key.replace(".", "_"), separator="_") param_name = f"{key}_{self._counter[key]}" @@ -182,14 +183,14 @@ def _single_condition_cypher( """Return Cypher for field operator value. Args: - field: the name of the field being filtered - native_operator_class: the operator class that will be used to generate + field: The name of the field being filtered + native_operator_class: The operator class that will be used to generate the Cypher query value: filtered value param_store: ParameterStore objet that will be updated in this function - node_alias: name of the node being filtered in the Cypher query + node_alias: Name of the node being filtered in the Cypher query Returns: - str: the Cypher condition, e.g. node.`property` = $param_0 + str: The Cypher condition, e.g. node.`property` = $param_0 NB: the param_store argument is mutable, it will be updated in this function """ @@ -208,13 +209,13 @@ def _handle_field_filter( """Create a filter for a specific field. Args: - field: name of field - value: value to filter + field: Name of field + value: Value to filter If provided as is then this will be an equality filter If provided as a dictionary then this will be a filter, the key will be the operator and the value will be the value to filter by param_store: ParameterStore objet that will be updated in this function - node_alias: name of the node being filtered in the Cypher query + node_alias: Name of the node being filtered in the Cypher query Returns str: Cypher filter snippet @@ -284,16 +285,16 @@ def _construct_metadata_filter( Args: filter: A dictionary representing the filter condition. param_store: ParameterStore objet that will be updated in this function - node_alias: name of the node being filtered in the Cypher query + node_alias: Name of the node being filtered in the Cypher query Returns: - str: the Cypher WHERE clause + str: The Cypher WHERE clause NB: the param_store argument is mutable, it will be updated in this function """ if not isinstance(filter, dict): - raise ValueError() + raise ValueError(f"Filter must be a dictionary, received {type(filter)}") # if we have more than one entry, this is an implicit "AND" filter if len(filter) > 1: return _construct_metadata_filter( @@ -326,14 +327,15 @@ def _construct_metadata_filter( return query -def construct_metadata_filter( +def get_metadata_filter( filter: dict[str, Any], node_alias: str = DEFAULT_NODE_ALIAS ) -> tuple[str, dict]: """Construct the cypher filter snippet based on a filter dict Args: - filter: a dict of filters - node_alias: the node the filters must be applied on + filter (dict): The filters to be converted to Cypher + node_alias (str): The alias of node the filters must be applied on + in the Cypher query Return: A tuple of str, dict where the string is the cypher query and the dict diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index 52bbf332..014ebb4a 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -15,7 +15,7 @@ from typing import Optional, Any from neo4j_genai.types import SearchType -from neo4j_genai.retrievers.filters import construct_metadata_filter +from neo4j_genai.filters import get_metadata_filter VECTOR_INDEX_QUERY = ( @@ -72,7 +72,7 @@ def _get_filtered_vector_query( Returns: tuple[str, dict[str, Any]]: query and parameters """ - where_filters, query_params = construct_metadata_filter(filters, node_alias="node") + where_filters, query_params = get_metadata_filter(filters, node_alias="node") base_query = BASE_VECTOR_EXACT_QUERY.format( node_label=node_label, embedding_node_property=embedding_node_property, diff --git a/tests/unit/retrievers/test_filters.py b/tests/unit/retrievers/test_filters.py index fd562118..b6eb0e63 100644 --- a/tests/unit/retrievers/test_filters.py +++ b/tests/unit/retrievers/test_filters.py @@ -14,124 +14,124 @@ # limitations under the License. import pytest -from neo4j_genai.retrievers.filters import construct_metadata_filter +from neo4j_genai.filters import get_metadata_filter def test_filter_single_field_string(): filters = {"field": "string_value"} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` = $param_0" assert params == {"param_0": "string_value"} def test_filter_single_field_int(): filters = {"field": 28} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` = $param_0" assert params == {"param_0": 28} def test_filter_single_field_bool(): filters = {"field": False} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` = $param_0" assert params == {"param_0": False} def test_filter_explicit_eq_operator(): filters = {"field": {"$eq": "string_value"}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` = $param_0" assert params == {"param_0": "string_value"} def test_filter_neq_operator(): filters = {"field": {"$ne": "string_value"}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` <> $param_0" assert params == {"param_0": "string_value"} def test_filter_lt_operator(): filters = {"field": {"$lt": 1}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` < $param_0" assert params == {"param_0": 1} def test_filter_gt_operator(): filters = {"field": {"$gt": 1}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` > $param_0" assert params == {"param_0": 1} def test_filter_lte_operator(): filters = {"field": {"$lte": 1}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` <= $param_0" assert params == {"param_0": 1} def test_filter_gte_operator(): filters = {"field": {"$gte": 1}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` >= $param_0" assert params == {"param_0": 1} def test_filter_in_operator(): filters = {"field": {"$in": ["a", "b"]}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` IN $param_0" assert params == {"param_0": ["a", "b"]} def test_filter_not_in_operator(): filters = {"field": {"$nin": ["a", "b"]}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` NOT IN $param_0" assert params == {"param_0": ["a", "b"]} def test_filter_like_operator(): filters = {"field": {"$like": "some_value"}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` CONTAINS $param_0" assert params == {"param_0": "some_value"} def test_filter_ilike_operator(): filters = {"field": {"$ilike": "Some Value"}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "toLower(node.`field`) CONTAINS $param_0" assert params == {"param_0": "some value"} def test_filter_between_operator(): filters = {"field": {"$between": [0, 1]}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "$param_0 <= node.`field` <= $param_1" assert params == {"param_0": 0, "param_1": 1} def test_filter_implicit_and_condition(): filters = {"field_1": "string_value", "field_2": True} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "(node.`field_1` = $param_0) AND (node.`field_2` = $param_1)" assert params == {"param_0": "string_value", "param_1": True} def test_filter_explicit_and_condition(): filters = {"$and": [{"field_1": "string_value"}, {"field_2": True}]} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "(node.`field_1` = $param_0) AND (node.`field_2` = $param_1)" assert params == {"param_0": "string_value", "param_1": True} def test_filter_or_condition(): filters = {"$or": [{"field_1": "string_value"}, {"field_2": True}]} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "(node.`field_1` = $param_0) OR (node.`field_2` = $param_1)" assert params == {"param_0": "string_value", "param_1": True} @@ -143,7 +143,7 @@ def test_filter_and_or_combined(): {"field_3": 11}, ] } - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert ( query == "((node.`field_1` = $param_0) OR (node.`field_2` = $param_1)) AND (node.`field_3` = $param_2)" @@ -155,16 +155,16 @@ def test_filter_and_or_combined(): def test_field_name_with_dollar_sign(): filters = {"$field": "value"} with pytest.raises(ValueError): - construct_metadata_filter(filters) + get_metadata_filter(filters) def test_and_no_list(): filters = {"$and": {}} with pytest.raises(ValueError): - construct_metadata_filter(filters) + get_metadata_filter(filters) def test_unsupported_operator(): filters = {"field": {"$unsupported": "value"}} with pytest.raises(ValueError): - construct_metadata_filter(filters) + get_metadata_filter(filters) From 0475f3868dc6b3d4e1767cface32ee21fb1002cb Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 7 May 2024 15:13:52 +0200 Subject: [PATCH 13/38] More unit tests --- src/neo4j_genai/filters.py | 10 +- tests/unit/retrievers/test_filters.py | 332 +++++++++++++++++++++++++- 2 files changed, 334 insertions(+), 8 deletions(-) diff --git a/src/neo4j_genai/filters.py b/src/neo4j_genai/filters.py index 16699cc2..2f301ed6 100644 --- a/src/neo4j_genai/filters.py +++ b/src/neo4j_genai/filters.py @@ -71,9 +71,7 @@ class InOperator(Operator): def cleaned_value(self, value): for val in value: if not isinstance(val, (str, int, float)): - raise NotImplementedError( - f"Unsupported type: {type(val)} for value: {val}" - ) + raise ValueError(f"Unsupported type: {type(val)} for value: {val}") return value @@ -178,7 +176,7 @@ def _single_condition_cypher( native_operator_class: Type[Operator], value: Any, param_store: ParameterStore, - node_alias: str, + node_alias: str = DEFAULT_NODE_ALIAS, ) -> str: """Return Cypher for field operator value. @@ -263,6 +261,10 @@ def _handle_field_filter( # special case for the BETWEEN operator that requires # two tests (lower_bound <= value <= higher_bound) if operator == OPERATOR_BETWEEN: + if len(filter_value) != 2: + raise ValueError( + f"Expected lower and upper bounds in a list, got {filter_value}" + ) low, high = filter_value param_name_low = param_store.add(field, low) param_name_high = param_store.add(field, high) diff --git a/tests/unit/retrievers/test_filters.py b/tests/unit/retrievers/test_filters.py index b6eb0e63..0124eefe 100644 --- a/tests/unit/retrievers/test_filters.py +++ b/tests/unit/retrievers/test_filters.py @@ -12,9 +12,324 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import patch + import pytest -from neo4j_genai.filters import get_metadata_filter +from neo4j_genai.filters import ( + get_metadata_filter, + _single_condition_cypher, + _handle_field_filter, + _construct_metadata_filter, + EqOperator, + NeqOperator, + LtOperator, + GtOperator, + LteOperator, + GteOperator, + InOperator, + NinOperator, + LikeOperator, + ILikeOperator, + ParameterStore, +) + + +@pytest.fixture(scope="function") +def param_store_empty(): + return ParameterStore() + + +def test_param_store(): + ps = ParameterStore() + assert ps.params == {} + ps.add("", 1) + assert ps.params == {"param_0": 1} + ps.add("", "some value") + assert ps.params == {"param_0": 1, "param_1": "some value"} + + +def test_single_condition_cypher_eq(param_store_empty): + generated = _single_condition_cypher( + "field", EqOperator, "value", param_store=param_store_empty + ) + assert generated == "node.`field` = $param_0" + assert param_store_empty.params == {"param_0": "value"} + + +def test_single_condition_cypher_eq_node_alias(param_store_empty): + generated = _single_condition_cypher( + "field", EqOperator, "value", node_alias="n", param_store=param_store_empty + ) + assert generated == "n.`field` = $param_0" + assert param_store_empty.params == {"param_0": "value"} + + +def test_single_condition_cypher_neq(param_store_empty): + generated = _single_condition_cypher( + "field", NeqOperator, "value", param_store=param_store_empty + ) + assert generated == "node.`field` <> $param_0" + assert param_store_empty.params == {"param_0": "value"} + + +def test_single_condition_cypher_lt(param_store_empty): + generated = _single_condition_cypher( + "field", LtOperator, 10, param_store=param_store_empty + ) + assert generated == "node.`field` < $param_0" + assert param_store_empty.params == {"param_0": 10} + + +def test_single_condition_cypher_gt(param_store_empty): + generated = _single_condition_cypher( + "field", GtOperator, 10, param_store=param_store_empty + ) + assert generated == "node.`field` > $param_0" + assert param_store_empty.params == {"param_0": 10} + + +def test_single_condition_cypher_lte(param_store_empty): + generated = _single_condition_cypher( + "field", LteOperator, 10, param_store=param_store_empty + ) + assert generated == "node.`field` <= $param_0" + assert param_store_empty.params == {"param_0": 10} + + +def test_single_condition_cypher_gte(param_store_empty): + generated = _single_condition_cypher( + "field", GteOperator, 10, param_store=param_store_empty + ) + assert generated == "node.`field` >= $param_0" + assert param_store_empty.params == {"param_0": 10} + + +def test_single_condition_cypher_in_int(param_store_empty): + generated = _single_condition_cypher( + "field", InOperator, [1, 2, 3], param_store=param_store_empty + ) + assert generated == "node.`field` IN $param_0" + assert param_store_empty.params == {"param_0": [1, 2, 3]} + + +def test_single_condition_cypher_in_str(param_store_empty): + generated = _single_condition_cypher( + "field", InOperator, ["a", "b", "c"], param_store=param_store_empty + ) + assert generated == "node.`field` IN $param_0" + assert param_store_empty.params == {"param_0": ["a", "b", "c"]} + + +def test_single_condition_cypher_in_invalid_type(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _single_condition_cypher( + "field", + InOperator, + [ + {"my_tuple"}, + ], + param_store=param_store_empty, + ) + assert "Unsupported type: " in str(excinfo) + + +def test_single_condition_cypher_nin(param_store_empty): + generated = _single_condition_cypher( + "field", NinOperator, ["a", "b", "c"], param_store=param_store_empty + ) + assert generated == "node.`field` NOT IN $param_0" + assert param_store_empty.params == {"param_0": ["a", "b", "c"]} + + +def test_single_condition_cypher_like(param_store_empty): + generated = _single_condition_cypher( + "field", LikeOperator, "value", param_store=param_store_empty + ) + assert generated == "node.`field` CONTAINS $param_0" + assert param_store_empty.params == {"param_0": "value"} + + +def test_single_condition_cypher_ilike(param_store_empty): + generated = _single_condition_cypher( + "field", ILikeOperator, "My Value", param_store=param_store_empty + ) + assert generated == "toLower(node.`field`) CONTAINS $param_0" + assert param_store_empty.params == {"param_0": "my value"} + + +def test_single_condition_cypher_like_not_a_string(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _single_condition_cypher( + "field", ILikeOperator, 1, param_store=param_store_empty + ) + assert "Expected string value, got " in str(excinfo) + + +def test_handle_field_filter_not_a_string(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _handle_field_filter(1, "value", param_store=param_store_empty) + assert "Field should be a string but got: with value: 1" in str( + excinfo + ) + + +def test_handle_field_filter_field_start_with_dollar_sign(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _handle_field_filter("$field_name", "value", param_store=param_store_empty) + assert ( + "Invalid filter condition. Expected a field but got an operator: $field_name" + in str(excinfo) + ) + + +def test_handle_field_filter_bad_field_name(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _handle_field_filter("bad+field?name", "value", param_store=param_store_empty) + assert "Invalid field name: bad+field?name. Expected a valid identifier." in str( + excinfo + ) + + +def test_handle_field_filter_bad_value(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _handle_field_filter( + "field", + value={"operator1": "value1", "operator2": "value2"}, + param_store=param_store_empty, + ) + assert "Invalid filter condition" in str(excinfo) + + +def test_handle_field_filter_bad_operator_name(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _handle_field_filter( + "field", value={"$invalid": "value"}, param_store=param_store_empty + ) + assert "Invalid operator: $invalid" in str(excinfo) + + +def test_handle_field_filter_operator_between(param_store_empty): + generated = _handle_field_filter( + "field", value={"$between": [0, 1]}, param_store=param_store_empty + ) + assert generated == "$param_0 <= node.`field` <= $param_1" + assert param_store_empty.params == {"param_0": 0, "param_1": 1} + + +def test_handle_field_filter_operator_between_not_enough_parameters(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _handle_field_filter( + "field", + value={ + "$between": [ + 0, + ] + }, + param_store=param_store_empty, + ) + assert "Expected lower and upper bounds in a list, got [0]" in str(excinfo) + + +@patch("neo4j_genai.filters._single_condition_cypher", return_value="condition") +def test_handle_field_filter_implicit_eq( + _single_condition_cypher_mocked, param_store_empty +): + generated = _handle_field_filter( + "field", value="some_value", param_store=param_store_empty + ) + _single_condition_cypher_mocked.assert_called_once_with( + "field", EqOperator, "some_value", param_store_empty, "node" + ) + assert generated == "condition" + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_eq(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter( + "field", value={"$eq": "some_value"}, param_store=param_store_empty + ) + _single_condition_cypher_mocked.assert_called_once_with( + "field", EqOperator, "some_value", param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_neq(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter( + "field", value={"$ne": "some_value"}, param_store=param_store_empty + ) + _single_condition_cypher_mocked.assert_called_once_with( + "field", NeqOperator, "some_value", param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_lt(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter("field", value={"$lt": 1}, param_store=param_store_empty) + _single_condition_cypher_mocked.assert_called_once_with( + "field", LtOperator, 1, param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_gt(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter("field", value={"$gt": 1}, param_store=param_store_empty) + _single_condition_cypher_mocked.assert_called_once_with( + "field", GtOperator, 1, param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_lte(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter("field", value={"$lte": 1}, param_store=param_store_empty) + _single_condition_cypher_mocked.assert_called_once_with( + "field", LteOperator, 1, param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_gte(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter("field", value={"$gte": 1}, param_store=param_store_empty) + _single_condition_cypher_mocked.assert_called_once_with( + "field", GteOperator, 1, param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_in(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter("field", value={"$in": [1, 2]}, param_store=param_store_empty) + _single_condition_cypher_mocked.assert_called_once_with( + "field", InOperator, [1, 2], param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_nin(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter("field", value={"$nin": [1, 2]}, param_store=param_store_empty) + _single_condition_cypher_mocked.assert_called_once_with( + "field", NinOperator, [1, 2], param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_like(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter( + "field", value={"$like": "value"}, param_store=param_store_empty + ) + _single_condition_cypher_mocked.assert_called_once_with( + "field", LikeOperator, "value", param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_ilike(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter( + "field", value={"$ilike": "value"}, param_store=param_store_empty + ) + _single_condition_cypher_mocked.assert_called_once_with( + "field", ILikeOperator, "value", param_store_empty, "node" + ) def test_filter_single_field_string(): @@ -129,6 +444,15 @@ def test_filter_explicit_and_condition(): assert params == {"param_0": "string_value", "param_1": True} +def test_filter_explicit_and_condition_with_operator(): + filters = { + "$and": [{"field_1": {"$ne": "string_value"}}, {"field_2": {"$in": [1, 2]}}] + } + query, params = get_metadata_filter(filters) + assert query == "(node.`field_1` <> $param_0) AND (node.`field_2` IN $param_1)" + assert params == {"param_0": "string_value", "param_1": [1, 2]} + + def test_filter_or_condition(): filters = {"$or": [{"field_1": "string_value"}, {"field_2": True}]} query, params = get_metadata_filter(filters) @@ -144,9 +468,9 @@ def test_filter_and_or_combined(): ] } query, params = get_metadata_filter(filters) - assert ( - query - == "((node.`field_1` = $param_0) OR (node.`field_2` = $param_1)) AND (node.`field_3` = $param_2)" + assert query == ( + "((node.`field_1` = $param_0) OR (node.`field_2` = $param_1)) " + "AND (node.`field_3` = $param_2)" ) assert params == {"param_0": "string_value", "param_1": True, "param_2": 11} From 24d29d4607384fad3118e2b0008661f7733cd2f1 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 7 May 2024 15:47:17 +0200 Subject: [PATCH 14/38] More unit tests for filters --- src/neo4j_genai/filters.py | 4 +- tests/unit/{retrievers => }/test_filters.py | 95 ++++++++++++++++----- 2 files changed, 74 insertions(+), 25 deletions(-) rename tests/unit/{retrievers => }/test_filters.py (81%) diff --git a/src/neo4j_genai/filters.py b/src/neo4j_genai/filters.py index 2f301ed6..ebb062d0 100644 --- a/src/neo4j_genai/filters.py +++ b/src/neo4j_genai/filters.py @@ -296,7 +296,7 @@ def _construct_metadata_filter( """ if not isinstance(filter, dict): - raise ValueError(f"Filter must be a dictionary, received {type(filter)}") + raise ValueError(f"Filter must be a dictionary, got {type(filter)}") # if we have more than one entry, this is an implicit "AND" filter if len(filter) > 1: return _construct_metadata_filter( @@ -319,7 +319,7 @@ def _construct_metadata_filter( elif key.lower() == OPERATOR_OR: cypher_operator = " OR " else: - raise ValueError(f"Unsupported filter {filter}") + raise ValueError(f"Unsupported operator: {key}") query = cypher_operator.join( [ f"({ _construct_metadata_filter(el, param_store, node_alias)})" diff --git a/tests/unit/retrievers/test_filters.py b/tests/unit/test_filters.py similarity index 81% rename from tests/unit/retrievers/test_filters.py rename to tests/unit/test_filters.py index 0124eefe..2fba20b9 100644 --- a/tests/unit/retrievers/test_filters.py +++ b/tests/unit/test_filters.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import patch +from unittest.mock import patch, call import pytest @@ -332,119 +332,168 @@ def test_handle_field_filter_ilike(_single_condition_cypher_mocked, param_store_ ) -def test_filter_single_field_string(): +@patch("neo4j_genai.filters._handle_field_filter") +def test_construct_metadata_filter_filter_is_not_a_dict(_handle_field_filter_mock, param_store_empty): + with pytest.raises(ValueError) as excinfo: + _construct_metadata_filter([], param_store_empty, node_alias="n") + assert "Filter must be a dictionary, got " in str(excinfo) + + +@patch("neo4j_genai.filters._handle_field_filter") +def test_construct_metadata_filter_no_operator(_handle_field_filter_mock, param_store_empty): + _construct_metadata_filter({"field": "value"}, param_store_empty, node_alias="n") + _handle_field_filter_mock.assert_called_once_with( + "field", "value", param_store_empty, node_alias="n" + ) + + +@patch("neo4j_genai.filters._construct_metadata_filter") +def test_construct_metadata_filter_implicit_and(_construct_metadata_filter_mock, param_store_empty): + _construct_metadata_filter({"field_1": "value_1", "field_2": "value_2"}, param_store_empty, node_alias="n") + _construct_metadata_filter_mock.assert_has_calls([ + call({"$and": [{"field_1": "value_1"}, {"field_2": "value_2"}]}, param_store_empty, "n"), + ]) + + +@patch("neo4j_genai.filters._construct_metadata_filter", side_effect=["filter1", "filter2"]) +def test_construct_metadata_filter_explicit_and(_construct_metadata_filter_mock, param_store_empty): + generated = _construct_metadata_filter({"$and": [{"field_1": "value_1"}, {"field_2": "value_2"}]}, param_store_empty, node_alias="n") + _construct_metadata_filter_mock.assert_has_calls([ + call({"field_1": "value_1"}, param_store_empty, "n"), + call({"field_2": "value_2"}, param_store_empty, "n") + ]) + assert generated == "(filter1) AND (filter2)" + + +@patch("neo4j_genai.filters._construct_metadata_filter", side_effect=["filter1", "filter2"]) +def test_construct_metadata_filter_or(_construct_metadata_filter_mock, param_store_empty): + generated = _construct_metadata_filter({"$or": [{"field_1": "value_1"}, {"field_2": "value_2"}]}, param_store_empty, node_alias="n") + _construct_metadata_filter_mock.assert_has_calls([ + call({"field_1": "value_1"}, param_store_empty, "n"), + call({"field_2": "value_2"}, param_store_empty, "n") + ]) + assert generated == "(filter1) OR (filter2)" + + +def test_construct_metadata_filter_invalid_operator(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _construct_metadata_filter({"$invalid": [{}, {}]}, param_store_empty, node_alias="n") + assert "Unsupported operator: $invalid" in str(excinfo) + + +def test_get_metadata_filter_single_field_string(): filters = {"field": "string_value"} query, params = get_metadata_filter(filters) assert query == "node.`field` = $param_0" assert params == {"param_0": "string_value"} -def test_filter_single_field_int(): +def test_get_metadata_filter_single_field_int(): filters = {"field": 28} query, params = get_metadata_filter(filters) assert query == "node.`field` = $param_0" assert params == {"param_0": 28} -def test_filter_single_field_bool(): +def test_get_metadata_filter_single_field_bool(): filters = {"field": False} query, params = get_metadata_filter(filters) assert query == "node.`field` = $param_0" assert params == {"param_0": False} -def test_filter_explicit_eq_operator(): +def test_get_metadata_filter_explicit_eq_operator(): filters = {"field": {"$eq": "string_value"}} query, params = get_metadata_filter(filters) assert query == "node.`field` = $param_0" assert params == {"param_0": "string_value"} -def test_filter_neq_operator(): +def test_get_metadata_filter_neq_operator(): filters = {"field": {"$ne": "string_value"}} query, params = get_metadata_filter(filters) assert query == "node.`field` <> $param_0" assert params == {"param_0": "string_value"} -def test_filter_lt_operator(): +def test_get_metadata_filter_lt_operator(): filters = {"field": {"$lt": 1}} query, params = get_metadata_filter(filters) assert query == "node.`field` < $param_0" assert params == {"param_0": 1} -def test_filter_gt_operator(): +def test_get_metadata_filter_gt_operator(): filters = {"field": {"$gt": 1}} query, params = get_metadata_filter(filters) assert query == "node.`field` > $param_0" assert params == {"param_0": 1} -def test_filter_lte_operator(): +def test_get_metadata_filter_lte_operator(): filters = {"field": {"$lte": 1}} query, params = get_metadata_filter(filters) assert query == "node.`field` <= $param_0" assert params == {"param_0": 1} -def test_filter_gte_operator(): +def test_get_metadata_filter_gte_operator(): filters = {"field": {"$gte": 1}} query, params = get_metadata_filter(filters) assert query == "node.`field` >= $param_0" assert params == {"param_0": 1} -def test_filter_in_operator(): +def test_get_metadata_filter_in_operator(): filters = {"field": {"$in": ["a", "b"]}} query, params = get_metadata_filter(filters) assert query == "node.`field` IN $param_0" assert params == {"param_0": ["a", "b"]} -def test_filter_not_in_operator(): +def test_get_metadata_filter_not_in_operator(): filters = {"field": {"$nin": ["a", "b"]}} query, params = get_metadata_filter(filters) assert query == "node.`field` NOT IN $param_0" assert params == {"param_0": ["a", "b"]} -def test_filter_like_operator(): +def test_get_metadata_filter_like_operator(): filters = {"field": {"$like": "some_value"}} query, params = get_metadata_filter(filters) assert query == "node.`field` CONTAINS $param_0" assert params == {"param_0": "some_value"} -def test_filter_ilike_operator(): +def test_get_metadata_filter_ilike_operator(): filters = {"field": {"$ilike": "Some Value"}} query, params = get_metadata_filter(filters) assert query == "toLower(node.`field`) CONTAINS $param_0" assert params == {"param_0": "some value"} -def test_filter_between_operator(): +def test_get_metadata_filter_between_operator(): filters = {"field": {"$between": [0, 1]}} query, params = get_metadata_filter(filters) assert query == "$param_0 <= node.`field` <= $param_1" assert params == {"param_0": 0, "param_1": 1} -def test_filter_implicit_and_condition(): +def test_get_metadata_filter_implicit_and_condition(): filters = {"field_1": "string_value", "field_2": True} query, params = get_metadata_filter(filters) assert query == "(node.`field_1` = $param_0) AND (node.`field_2` = $param_1)" assert params == {"param_0": "string_value", "param_1": True} -def test_filter_explicit_and_condition(): +def test_get_metadata_filter_explicit_and_condition(): filters = {"$and": [{"field_1": "string_value"}, {"field_2": True}]} query, params = get_metadata_filter(filters) assert query == "(node.`field_1` = $param_0) AND (node.`field_2` = $param_1)" assert params == {"param_0": "string_value", "param_1": True} -def test_filter_explicit_and_condition_with_operator(): +def test_get_metadata_filter_explicit_and_condition_with_operator(): filters = { "$and": [{"field_1": {"$ne": "string_value"}}, {"field_2": {"$in": [1, 2]}}] } @@ -453,14 +502,14 @@ def test_filter_explicit_and_condition_with_operator(): assert params == {"param_0": "string_value", "param_1": [1, 2]} -def test_filter_or_condition(): +def test_get_metadata_filter_or_condition(): filters = {"$or": [{"field_1": "string_value"}, {"field_2": True}]} query, params = get_metadata_filter(filters) assert query == "(node.`field_1` = $param_0) OR (node.`field_2` = $param_1)" assert params == {"param_0": "string_value", "param_1": True} -def test_filter_and_or_combined(): +def test_get_metadata_filter_and_or_combined(): filters = { "$and": [ {"$or": [{"field_1": "string_value"}, {"field_2": True}]}, @@ -476,19 +525,19 @@ def test_filter_and_or_combined(): # now testing bad filters -def test_field_name_with_dollar_sign(): +def test_get_metadata_filter_field_name_with_dollar_sign(): filters = {"$field": "value"} with pytest.raises(ValueError): get_metadata_filter(filters) -def test_and_no_list(): +def test_get_metadata_filter_and_no_list(): filters = {"$and": {}} with pytest.raises(ValueError): get_metadata_filter(filters) -def test_unsupported_operator(): +def test_get_metadata_filter_unsupported_operator(): filters = {"field": {"$unsupported": "value"}} with pytest.raises(ValueError): get_metadata_filter(filters) From 0a4711efc8b4478b6ccc38ea8badcdaf3da60b15 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 7 May 2024 15:57:28 +0200 Subject: [PATCH 15/38] Increase test coverage for queries --- src/neo4j_genai/neo4j_queries.py | 10 +--- tests/unit/test_filters.py | 84 +++++++++++++++++++++++--------- tests/unit/test_neo4j_queries.py | 50 +++++++++++++++++++ 3 files changed, 112 insertions(+), 32 deletions(-) diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index 014ebb4a..e974a047 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -81,13 +81,7 @@ def _get_filtered_vector_query( embedding_node_property=embedding_node_property, ) query_params["embedding_dimension"] = embedding_dimension - return ( - f"""{base_query} - AND ({where_filters}) - {vector_query} - """, - query_params, - ) + return f"{base_query} AND ({where_filters}) {vector_query}", query_params def _get_vector_query( @@ -155,7 +149,7 @@ def get_search_query( query_tail = _get_query_tail( retrieval_query, return_properties, fallback_return="RETURN node, score" ) - return " ".join([query, query_tail]), params + return f"{query} {query_tail}", params def _get_query_tail( diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py index 2fba20b9..8dff3570 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -333,14 +333,18 @@ def test_handle_field_filter_ilike(_single_condition_cypher_mocked, param_store_ @patch("neo4j_genai.filters._handle_field_filter") -def test_construct_metadata_filter_filter_is_not_a_dict(_handle_field_filter_mock, param_store_empty): +def test_construct_metadata_filter_filter_is_not_a_dict( + _handle_field_filter_mock, param_store_empty +): with pytest.raises(ValueError) as excinfo: _construct_metadata_filter([], param_store_empty, node_alias="n") assert "Filter must be a dictionary, got " in str(excinfo) @patch("neo4j_genai.filters._handle_field_filter") -def test_construct_metadata_filter_no_operator(_handle_field_filter_mock, param_store_empty): +def test_construct_metadata_filter_no_operator( + _handle_field_filter_mock, param_store_empty +): _construct_metadata_filter({"field": "value"}, param_store_empty, node_alias="n") _handle_field_filter_mock.assert_called_once_with( "field", "value", param_store_empty, node_alias="n" @@ -348,36 +352,68 @@ def test_construct_metadata_filter_no_operator(_handle_field_filter_mock, param_ @patch("neo4j_genai.filters._construct_metadata_filter") -def test_construct_metadata_filter_implicit_and(_construct_metadata_filter_mock, param_store_empty): - _construct_metadata_filter({"field_1": "value_1", "field_2": "value_2"}, param_store_empty, node_alias="n") - _construct_metadata_filter_mock.assert_has_calls([ - call({"$and": [{"field_1": "value_1"}, {"field_2": "value_2"}]}, param_store_empty, "n"), - ]) - - -@patch("neo4j_genai.filters._construct_metadata_filter", side_effect=["filter1", "filter2"]) -def test_construct_metadata_filter_explicit_and(_construct_metadata_filter_mock, param_store_empty): - generated = _construct_metadata_filter({"$and": [{"field_1": "value_1"}, {"field_2": "value_2"}]}, param_store_empty, node_alias="n") - _construct_metadata_filter_mock.assert_has_calls([ - call({"field_1": "value_1"}, param_store_empty, "n"), - call({"field_2": "value_2"}, param_store_empty, "n") - ]) +def test_construct_metadata_filter_implicit_and( + _construct_metadata_filter_mock, param_store_empty +): + _construct_metadata_filter( + {"field_1": "value_1", "field_2": "value_2"}, param_store_empty, node_alias="n" + ) + _construct_metadata_filter_mock.assert_has_calls( + [ + call( + {"$and": [{"field_1": "value_1"}, {"field_2": "value_2"}]}, + param_store_empty, + "n", + ), + ] + ) + + +@patch( + "neo4j_genai.filters._construct_metadata_filter", side_effect=["filter1", "filter2"] +) +def test_construct_metadata_filter_explicit_and( + _construct_metadata_filter_mock, param_store_empty +): + generated = _construct_metadata_filter( + {"$and": [{"field_1": "value_1"}, {"field_2": "value_2"}]}, + param_store_empty, + node_alias="n", + ) + _construct_metadata_filter_mock.assert_has_calls( + [ + call({"field_1": "value_1"}, param_store_empty, "n"), + call({"field_2": "value_2"}, param_store_empty, "n"), + ] + ) assert generated == "(filter1) AND (filter2)" -@patch("neo4j_genai.filters._construct_metadata_filter", side_effect=["filter1", "filter2"]) -def test_construct_metadata_filter_or(_construct_metadata_filter_mock, param_store_empty): - generated = _construct_metadata_filter({"$or": [{"field_1": "value_1"}, {"field_2": "value_2"}]}, param_store_empty, node_alias="n") - _construct_metadata_filter_mock.assert_has_calls([ - call({"field_1": "value_1"}, param_store_empty, "n"), - call({"field_2": "value_2"}, param_store_empty, "n") - ]) +@patch( + "neo4j_genai.filters._construct_metadata_filter", side_effect=["filter1", "filter2"] +) +def test_construct_metadata_filter_or( + _construct_metadata_filter_mock, param_store_empty +): + generated = _construct_metadata_filter( + {"$or": [{"field_1": "value_1"}, {"field_2": "value_2"}]}, + param_store_empty, + node_alias="n", + ) + _construct_metadata_filter_mock.assert_has_calls( + [ + call({"field_1": "value_1"}, param_store_empty, "n"), + call({"field_2": "value_2"}, param_store_empty, "n"), + ] + ) assert generated == "(filter1) OR (filter2)" def test_construct_metadata_filter_invalid_operator(param_store_empty): with pytest.raises(ValueError) as excinfo: - _construct_metadata_filter({"$invalid": [{}, {}]}, param_store_empty, node_alias="n") + _construct_metadata_filter( + {"$invalid": [{}, {}]}, param_store_empty, node_alias="n" + ) assert "Unsupported operator: $invalid" in str(excinfo) diff --git a/tests/unit/test_neo4j_queries.py b/tests/unit/test_neo4j_queries.py index 0d420c51..0ef2b68e 100644 --- a/tests/unit/test_neo4j_queries.py +++ b/tests/unit/test_neo4j_queries.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import patch from neo4j_genai.neo4j_queries import get_search_query, _get_query_tail from neo4j_genai.types import SearchType @@ -68,6 +69,55 @@ def test_vector_search_with_retrieval_query(): assert result.strip() == expected.strip() +@patch("neo4j_genai.neo4j_queries.get_metadata_filter", return_value=["True", {}]) +def test_vector_search_with_filters(_mock): + expected = ( + "MATCH (node:`Label`) " + "WHERE node.`vector` IS NOT NULL " + "AND size(node.`vector`) = toInteger($embedding_dimension)" + " AND (True) " + "WITH node, " + "vector.similarity.cosine(node.`vector`, $query_vector) AS score " + "ORDER BY score DESC LIMIT $top_k" + " RETURN node, score" + ) + result, params = get_search_query( + SearchType.VECTOR, + node_label="Label", + embedding_node_property="vector", + embedding_dimension=1, + filters={"field": "value"}, + ) + assert result.strip() == expected.strip() + assert params == {"embedding_dimension": 1} + + +@patch( + "neo4j_genai.neo4j_queries.get_metadata_filter", + return_value=["True", {"param": "value"}], +) +def test_vector_search_with_params_from_filters(_mock): + expected = ( + "MATCH (node:`Label`) " + "WHERE node.`vector` IS NOT NULL " + "AND size(node.`vector`) = toInteger($embedding_dimension)" + " AND (True) " + "WITH node, " + "vector.similarity.cosine(node.`vector`, $query_vector) AS score " + "ORDER BY score DESC LIMIT $top_k" + " RETURN node, score" + ) + result, params = get_search_query( + SearchType.VECTOR, + node_label="Label", + embedding_node_property="vector", + embedding_dimension=1, + filters={"field": "value"}, + ) + assert result.strip() == expected.strip() + assert params == {"embedding_dimension": 1, "param": "value"} + + def test_hybrid_search_with_retrieval_query(): retrieval_query = "MATCH (n) RETURN n LIMIT 10" expected = ( From 0fd442f1c7f246ab1b24fd1036fdeb816d4ba59e Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 7 May 2024 19:25:48 +0200 Subject: [PATCH 16/38] Simplification, formatting --- src/neo4j_genai/filters.py | 19 ++++++++----------- src/neo4j_genai/retrievers/base.py | 12 +++++++----- tests/unit/test_filters.py | 19 +++++++++---------- 3 files changed, 24 insertions(+), 26 deletions(-) diff --git a/src/neo4j_genai/filters.py b/src/neo4j_genai/filters.py index ebb062d0..100b9bd0 100644 --- a/src/neo4j_genai/filters.py +++ b/src/neo4j_genai/filters.py @@ -35,7 +35,8 @@ def __init__(self, node_alias=DEFAULT_NODE_ALIAS): self.node_alias = node_alias def lhs(self, field): - return f"{self.node_alias}.`{field}`" + escaped_field = field.replace("`", "``") + return f"{self.node_alias}.`{escaped_field}`" def cleaned_value(self, value): return value @@ -145,7 +146,7 @@ def __init__(self): self._counter = Counter() self.params = {} - def _get_params_name(self, key="param"): + def _get_params_name(self): """Find parameter name so that param names are unique. This function adds a suffix to the key corresponding to the number of times the key have been used in the query. @@ -157,12 +158,12 @@ def _get_params_name(self, key="param"): Returns: The full unique parameter name """ - # key = slugify(key.replace(".", "_"), separator="_") + key = "param" param_name = f"{key}_{self._counter[key]}" self._counter[key] += 1 return param_name - def add(self, key, value): + def add(self, value): """This function adds a new parameter to the param dict. It returns the name of the parameter to be used as a placeholder in the cypher query, e.g. $param_0""" @@ -193,7 +194,7 @@ def _single_condition_cypher( NB: the param_store argument is mutable, it will be updated in this function """ native_op = native_operator_class(node_alias=node_alias) - param_name = param_store.add(field, native_op.cleaned_value(value)) + param_name = param_store.add(native_op.cleaned_value(value)) query_snippet = f"{native_op.lhs(field)} {native_op.CYPHER_OPERATOR} ${param_name}" return query_snippet @@ -232,10 +233,6 @@ def _handle_field_filter( f"{field}" ) - # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters - if not field.isidentifier(): - raise ValueError(f"Invalid field name: {field}. Expected a valid identifier.") - if isinstance(value, dict): # This is a filter specification e.g. {"$gte": 0} if len(value) != 1: @@ -266,8 +263,8 @@ def _handle_field_filter( f"Expected lower and upper bounds in a list, got {filter_value}" ) low, high = filter_value - param_name_low = param_store.add(field, low) - param_name_high = param_store.add(field, high) + param_name_low = param_store.add(low) + param_name_high = param_store.add(high) query_snippet = ( f"${param_name_low} <= {DEFAULT_NODE_ALIAS}.`{field}` <= ${param_name_high}" ) diff --git a/src/neo4j_genai/retrievers/base.py b/src/neo4j_genai/retrievers/base.py index 24257429..acad184b 100644 --- a/src/neo4j_genai/retrievers/base.py +++ b/src/neo4j_genai/retrievers/base.py @@ -60,11 +60,13 @@ def search(self, *args, **kwargs) -> Any: def _fetch_index_infos(self): """Fetch the node label and embedding property from the index definition""" - query = """SHOW VECTOR INDEXES -YIELD name, labelsOrTypes, properties, options -WHERE name = $index_name -RETURN labelsOrTypes as labels, properties, options.indexConfig.`vector.dimensions` as dimensions - """ + query = ( + "SHOW VECTOR INDEXES " + "YIELD name, labelsOrTypes, properties, options " + "WHERE name = $index_name " + "RETURN labelsOrTypes as labels, properties, " + "options.indexConfig.`vector.dimensions` as dimensions" + ) result = self.driver.execute_query(query, {"index_name": self.index_name}) try: result = result.records[0] diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py index 8dff3570..9979846e 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -43,9 +43,9 @@ def param_store_empty(): def test_param_store(): ps = ParameterStore() assert ps.params == {} - ps.add("", 1) + ps.add(1) assert ps.params == {"param_0": 1} - ps.add("", "some value") + ps.add("some value") assert ps.params == {"param_0": 1, "param_1": "some value"} @@ -166,6 +166,13 @@ def test_single_condition_cypher_like_not_a_string(param_store_empty): assert "Expected string value, got " in str(excinfo) +def test_single_condition_cypher_escaped_field_name(param_store_empty): + generated = _single_condition_cypher( + "na`me", EqOperator, "value", param_store=param_store_empty + ) + assert generated == "node.`na``me` = $param_0" + + def test_handle_field_filter_not_a_string(param_store_empty): with pytest.raises(ValueError) as excinfo: _handle_field_filter(1, "value", param_store=param_store_empty) @@ -183,14 +190,6 @@ def test_handle_field_filter_field_start_with_dollar_sign(param_store_empty): ) -def test_handle_field_filter_bad_field_name(param_store_empty): - with pytest.raises(ValueError) as excinfo: - _handle_field_filter("bad+field?name", "value", param_store=param_store_empty) - assert "Invalid field name: bad+field?name. Expected a valid identifier." in str( - excinfo - ) - - def test_handle_field_filter_bad_value(param_store_empty): with pytest.raises(ValueError) as excinfo: _handle_field_filter( From 87e9ded0030d04b497d1dd51195222cf6eeba8f4 Mon Sep 17 00:00:00 2001 From: willtai Date: Mon, 13 May 2024 10:22:14 +0100 Subject: [PATCH 17/38] Add try catch for create_index and rename imports of neo4j (#30) --- src/neo4j_genai/indexes.py | 83 +++++++++++++++++++--------- src/neo4j_genai/retrievers/base.py | 4 +- src/neo4j_genai/retrievers/hybrid.py | 14 ++--- src/neo4j_genai/retrievers/vector.py | 10 ++-- src/neo4j_genai/types.py | 4 +- tests/e2e/conftest.py | 10 +++- tests/e2e/test_hybrid_e2e.py | 12 ++-- tests/unit/conftest.py | 4 +- tests/unit/test_indexes.py | 34 ++++++++++-- 9 files changed, 118 insertions(+), 57 deletions(-) diff --git a/src/neo4j_genai/indexes.py b/src/neo4j_genai/indexes.py index 132cc144..32ace578 100644 --- a/src/neo4j_genai/indexes.py +++ b/src/neo4j_genai/indexes.py @@ -13,13 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from neo4j import Driver +import neo4j from pydantic import ValidationError from .types import VectorIndexModel, FulltextIndexModel +import logging + + +logger = logging.getLogger(__name__) def create_vector_index( - driver: Driver, + driver: neo4j.Driver, name: str, label: str, property: str, @@ -32,8 +36,11 @@ def create_vector_index( See Cypher manual on [Create vector index](https://neo4j.com/docs/cypher-manual/current/indexes/semantic-indexes/vector-indexes/#indexes-vector-create) + Important: This operation will fail if an index with the same name already exists. + Ensure that the index name provided is unique within the database context. + Args: - driver (Driver): Neo4j Python driver instance. + driver (neo4j.Driver): Neo4j Python driver instance. name (str): The unique name of the index. label (str): The node label to be indexed. property (str): The property key of a node which contains embedding values. @@ -43,6 +50,7 @@ def create_vector_index( Raises: ValueError: If validation of the input arguments fail. + neo4j.exceptions.ClientError: If creation of vector index fails. """ try: VectorIndexModel( @@ -58,17 +66,23 @@ def create_vector_index( except ValidationError as e: raise ValueError(f"Error for inputs to create_vector_index {str(e)}") - query = ( - f"CREATE VECTOR INDEX $name FOR (n:{label}) ON n.{property} OPTIONS " - "{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }" - ) - driver.execute_query( - query, {"name": name, "dimensions": dimensions, "similarity_fn": similarity_fn} - ) + try: + query = ( + f"CREATE VECTOR INDEX $name FOR (n:{label}) ON n.{property} OPTIONS " + "{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }" + ) + logger.info(f"Creating vector index named '{name}'") + driver.execute_query( + query, + {"name": name, "dimensions": dimensions, "similarity_fn": similarity_fn}, + ) + except neo4j.exceptions.ClientError as e: + logger.error(f"Neo4j vector index creation failed {e}") + raise def create_fulltext_index( - driver: Driver, name: str, label: str, node_properties: list[str] + driver: neo4j.Driver, name: str, label: str, node_properties: list[str] ) -> None: """ This method constructs a Cypher query and executes it @@ -76,14 +90,18 @@ def create_fulltext_index( See Cypher manual on [Create fulltext index](https://neo4j.com/docs/cypher-manual/current/indexes/semantic-indexes/full-text-indexes/#create-full-text-indexes) + Important: This operation will fail if an index with the same name already exists. + Ensure that the index name provided is unique within the database context. + Args: - driver (Driver): Neo4j Python driver instance. + driver (neo4j.Driver): Neo4j Python driver instance. name (str): The unique name of the index. label (str): The node label to be indexed. node_properties (list[str]): The node properties to create the fulltext index on. Raises: ValueError: If validation of the input arguments fail. + neo4j.exceptions.ClientError: If creation of fulltext index fails. """ try: FulltextIndexModel( @@ -97,26 +115,39 @@ def create_fulltext_index( except ValidationError as e: raise ValueError(f"Error for inputs to create_fulltext_index {str(e)}") - query = ( - "CREATE FULLTEXT INDEX $name " - f"FOR (n:`{label}`) ON EACH " - f"[{', '.join(['n.`' + prop + '`' for prop in node_properties])}]" - ) - driver.execute_query(query, {"name": name}) + try: + query = ( + "CREATE FULLTEXT INDEX $name " + f"FOR (n:`{label}`) ON EACH " + f"[{', '.join(['n.`' + prop + '`' for prop in node_properties])}]" + ) + logger.info(f"Creating fulltext index named '{name}'") + driver.execute_query(query, {"name": name}) + except neo4j.exceptions.ClientError as e: + logger.error(f"Neo4j fulltext index creation failed {e}") + raise -def drop_index(driver: Driver, name: str) -> None: +def drop_index_if_exists(driver: neo4j.Driver, name: str) -> None: """ This method constructs a Cypher query and executes it - to drop a vector index in Neo4j. + to drop an index in Neo4j, if the index exists. See Cypher manual on [Drop vector indexes](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-drop) Args: - driver (Driver): Neo4j Python driver instance. + driver (neo4j.Driver): Neo4j Python driver instance. name (str): The name of the index to delete. + + Raises: + neo4j.exceptions.ClientError: If dropping of index fails. """ - query = "DROP INDEX $name IF EXISTS" - parameters = { - "name": name, - } - driver.execute_query(query, parameters) + try: + query = "DROP INDEX $name IF EXISTS" + parameters = { + "name": name, + } + logger.info(f"Dropping index named '{name}'") + driver.execute_query(query, parameters) + except neo4j.exceptions.ClientError as e: + logger.error(f"Dropping Neo4j index failed {e}") + raise diff --git a/src/neo4j_genai/retrievers/base.py b/src/neo4j_genai/retrievers/base.py index dc483eb6..c3ca671d 100644 --- a/src/neo4j_genai/retrievers/base.py +++ b/src/neo4j_genai/retrievers/base.py @@ -15,7 +15,7 @@ from abc import ABC, abstractmethod from typing import Any -from neo4j import Driver +import neo4j class Retriever(ABC): @@ -23,7 +23,7 @@ class Retriever(ABC): Abstract class for Neo4j retrievers """ - def __init__(self, driver: Driver): + def __init__(self, driver: neo4j.Driver): self.driver = driver self._verify_version() diff --git a/src/neo4j_genai/retrievers/hybrid.py b/src/neo4j_genai/retrievers/hybrid.py index 0690555a..a2163b34 100644 --- a/src/neo4j_genai/retrievers/hybrid.py +++ b/src/neo4j_genai/retrievers/hybrid.py @@ -14,7 +14,7 @@ # limitations under the License. from typing import Optional, Any -from neo4j import Record, Driver +import neo4j from pydantic import ValidationError from neo4j_genai.embedder import Embedder @@ -29,7 +29,7 @@ class HybridRetriever(Retriever): def __init__( self, - driver: Driver, + driver: neo4j.Driver, vector_index_name: str, fulltext_index_name: str, embedder: Optional[Embedder] = None, @@ -46,7 +46,7 @@ def search( query_text: str, query_vector: Optional[list[float]] = None, top_k: int = 5, - ) -> list[Record]: + ) -> list[neo4j.Record]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. Both query_vector and query_text can be provided. If query_vector is provided, then it will be preferred over the embedded query_text @@ -63,7 +63,7 @@ def search( ValueError: If validation of the input arguments fail. ValueError: If no embedder is provided. Returns: - list[Record]: The results of the search query + list[neo4j.Record]: The results of the search query """ try: validated_data = HybridSearchModel( @@ -96,7 +96,7 @@ def search( class HybridCypherRetriever(Retriever): def __init__( self, - driver: Driver, + driver: neo4j.Driver, vector_index_name: str, fulltext_index_name: str, retrieval_query: str, @@ -114,7 +114,7 @@ def search( query_vector: Optional[list[float]] = None, top_k: int = 5, query_params: Optional[dict[str, Any]] = None, - ) -> list[Record]: + ) -> list[neo4j.Record]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. Both query_vector and query_text can be provided. If query_vector is provided, then it will be preferred over the embedded query_text @@ -132,7 +132,7 @@ def search( ValueError: If validation of the input arguments fail. ValueError: If no embedder is provided. Returns: - list[Record]: The results of the search query + list[neo4j.Record]: The results of the search query """ try: validated_data = HybridCypherSearchModel( diff --git a/src/neo4j_genai/retrievers/vector.py b/src/neo4j_genai/retrievers/vector.py index 954cd04e..af3a6068 100644 --- a/src/neo4j_genai/retrievers/vector.py +++ b/src/neo4j_genai/retrievers/vector.py @@ -14,7 +14,7 @@ # limitations under the License. from typing import Optional, Any -from neo4j import Driver, Record +import neo4j from neo4j_genai.retrievers.base import Retriever from pydantic import ValidationError @@ -39,7 +39,7 @@ class VectorRetriever(Retriever): def __init__( self, - driver: Driver, + driver: neo4j.Driver, index_name: str, embedder: Optional[Embedder] = None, return_properties: Optional[list[str]] = None, @@ -120,7 +120,7 @@ class VectorCypherRetriever(Retriever): def __init__( self, - driver: Driver, + driver: neo4j.Driver, index_name: str, retrieval_query: str, embedder: Optional[Embedder] = None, @@ -136,7 +136,7 @@ def search( query_text: Optional[str] = None, top_k: int = 5, query_params: Optional[dict[str, Any]] = None, - ) -> list[Record]: + ) -> list[neo4j.Record]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -154,7 +154,7 @@ def search( ValueError: If no embedder is provided. Returns: - list[Record]: The results of the search query + list[neo4j.Record]: The results of the search query """ try: validated_data = VectorCypherSearchModel( diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index 67a31175..fc9a3e5a 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -15,7 +15,7 @@ from enum import Enum from typing import Any, Literal, Optional from pydantic import BaseModel, PositiveInt, model_validator, field_validator -from neo4j import Driver +import neo4j class VectorSearchRecord(BaseModel): @@ -28,7 +28,7 @@ class IndexModel(BaseModel): @field_validator("driver") def check_driver_is_valid(cls, v): - if not isinstance(v, Driver): + if not isinstance(v, neo4j.Driver): raise ValueError("driver must be an instance of neo4j.Driver") return v diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 64cd6504..0ce0c2c8 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -19,7 +19,11 @@ import pytest from neo4j import GraphDatabase from neo4j_genai.embedder import Embedder -from neo4j_genai.indexes import drop_index, create_vector_index, create_fulltext_index +from neo4j_genai.indexes import ( + drop_index_if_exists, + create_vector_index, + create_fulltext_index, +) @pytest.fixture(scope="module") @@ -47,8 +51,8 @@ def setup_neo4j(driver): # Delete data and drop indexes to prevent data leakage driver.execute_query("MATCH (n) DETACH DELETE n") - drop_index(driver, vector_index_name) - drop_index(driver, fulltext_index_name) + drop_index_if_exists(driver, vector_index_name) + drop_index_if_exists(driver, fulltext_index_name) # Create a vector index create_vector_index( diff --git a/tests/e2e/test_hybrid_e2e.py b/tests/e2e/test_hybrid_e2e.py index f8f54466..3ba48c62 100644 --- a/tests/e2e/test_hybrid_e2e.py +++ b/tests/e2e/test_hybrid_e2e.py @@ -16,7 +16,7 @@ import pytest -from neo4j import Record +import neo4j from neo4j_genai import ( HybridRetriever, @@ -36,7 +36,7 @@ def test_hybrid_retriever_search_text(driver, custom_embedder): assert isinstance(results, list) assert len(results) == 5 for result in results: - assert isinstance(result, Record) + assert isinstance(result, neo4j.Record) @pytest.mark.usefixtures("setup_neo4j") @@ -58,7 +58,7 @@ def test_hybrid_cypher_retriever_search_text(driver, custom_embedder): assert isinstance(results, list) assert len(results) == 5 for record in results: - assert isinstance(record, Record) + assert isinstance(record, neo4j.Record) assert "author.name" in record.keys() @@ -80,7 +80,7 @@ def test_hybrid_retriever_search_vector(driver): assert isinstance(results, list) assert len(results) == 5 for result in results: - assert isinstance(result, Record) + assert isinstance(result, neo4j.Record) @pytest.mark.usefixtures("setup_neo4j") @@ -105,7 +105,7 @@ def test_hybrid_cypher_retriever_search_vector(driver): assert isinstance(results, list) assert len(results) == 5 for record in results: - assert isinstance(record, Record) + assert isinstance(record, neo4j.Record) assert "author.name" in record.keys() @@ -129,4 +129,4 @@ def test_hybrid_retriever_return_properties(driver): assert isinstance(results, list) assert len(results) == 5 for result in results: - assert isinstance(result, Record) + assert isinstance(result, neo4j.Record) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index b22e58fc..75e0419f 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -14,14 +14,14 @@ # limitations under the License. import pytest +import neo4j from neo4j_genai import VectorRetriever, VectorCypherRetriever, HybridRetriever -from neo4j import Driver from unittest.mock import MagicMock, patch @pytest.fixture(scope="function") def driver(): - return MagicMock(spec=Driver) + return MagicMock(spec=neo4j.Driver) @pytest.fixture(scope="function") diff --git a/tests/unit/test_indexes.py b/tests/unit/test_indexes.py index 84122684..c5509da9 100644 --- a/tests/unit/test_indexes.py +++ b/tests/unit/test_indexes.py @@ -12,12 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import neo4j.exceptions import pytest from neo4j_genai.indexes import ( create_vector_index, - drop_index, + drop_index_if_exists, create_fulltext_index, ) @@ -68,16 +68,33 @@ def test_create_vector_index_validation_error_dimensions(driver): assert "Error for inputs to create_vector_index" in str(excinfo) +def test_create_vector_index_raises_error_with_neo4j_client_error(driver): + driver.execute_query.side_effect = neo4j.exceptions.ClientError + with pytest.raises(neo4j.exceptions.ClientError): + create_vector_index(driver, "my-index", "People", "name", 2048, "cosine") + + def test_create_vector_index_validation_error_similarity_fn(driver): with pytest.raises(ValueError) as excinfo: create_vector_index(driver, "my-index", "People", "name", 1536, "algebra") assert "Error for inputs to create_vector_index" in str(excinfo) -def test_drop_index(driver): +def test_drop_index_if_exists(driver): drop_query = "DROP INDEX $name IF EXISTS" - drop_index(driver, "my-index") + drop_index_if_exists(driver, "my-index") + + driver.execute_query.assert_called_once_with( + drop_query, + {"name": "my-index"}, + ) + + +def test_drop_index_if_exists_raises_error_with_neo4j_client_error(driver): + drop_query = "DROP INDEX $name IF EXISTS" + + drop_index_if_exists(driver, "my-index") driver.execute_query.assert_called_once_with( drop_query, @@ -102,6 +119,15 @@ def test_create_fulltext_index_happy_path(driver): ) +def test_create_fulltext_index_raises_error_with_neo4j_client_error(driver): + label = "node-label" + text_node_properties = ["property-1", "property-2"] + driver.execute_query.side_effect = neo4j.exceptions.ClientError + + with pytest.raises(neo4j.exceptions.ClientError): + create_fulltext_index(driver, "my-index", label, text_node_properties) + + def test_create_fulltext_index_empty_node_properties(driver): label = "node-label" node_properties = [] From 3a9de9cfd0cf4f63c9e97170e081e30bcd0f8035 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 6 May 2024 17:34:50 +0200 Subject: [PATCH 18/38] Support for filters in Vector and VectorCypher retrievers --- src/neo4j_genai/neo4j_queries.py | 112 ++++++--- src/neo4j_genai/retrievers/base.py | 16 ++ src/neo4j_genai/retrievers/filters.py | 315 ++++++++++++++++++++++++++ src/neo4j_genai/retrievers/hybrid.py | 4 +- src/neo4j_genai/retrievers/vector.py | 34 ++- src/neo4j_genai/types.py | 2 +- tests/e2e/conftest.py | 7 +- tests/e2e/test_vector_e2e.py | 21 ++ tests/unit/retrievers/test_filters.py | 162 +++++++++++++ tests/unit/retrievers/test_hybrid.py | 8 +- tests/unit/retrievers/test_vector.py | 45 ++-- tests/unit/test_neo4j_queries.py | 71 ++---- 12 files changed, 673 insertions(+), 124 deletions(-) create mode 100644 src/neo4j_genai/retrievers/filters.py create mode 100644 tests/unit/retrievers/test_filters.py diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index b9ab366a..e3cf4149 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -12,47 +12,93 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Any from neo4j_genai.types import SearchType +from neo4j_genai.retrievers.filters import construct_metadata_filter + + +VECTOR_INDEX_QUERY = ( + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " + "YIELD node, score" +) + +VECTOR_EXACT_QUERY = ( + "WITH node, " + "vector.similarity.cosine(node.`{embedding_node_property}`, $query_vector) AS score " + "ORDER BY score DESC LIMIT $top_k" +) + +BASE_VECTOR_EXACT_QUERY = ( + "MATCH (node:`{node_label}`) " + "WHERE node.`{embedding_node_property}` IS NOT NULL " + "AND size(node.`{embedding_node_property}`) = toInteger($embedding_dimension)" +) + +FULL_TEXT_SEARCH_QUERY = ( + "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) " + "YIELD node, score" +) + + +def _get_hybrid_query() -> str: + return ( + f"CALL {{ {VECTOR_INDEX_QUERY} " + f"RETURN node, score " + f"UNION " + f"{FULL_TEXT_SEARCH_QUERY} " + f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS max " + f"UNWIND nodes AS n " + f"RETURN n.node AS node, (n.score / max) AS score }} " + f"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k" + ) + + +def _get_filtered_vector_query(filters: dict[str, Any], node_label: str, embedding_node_property: str, embedding_dimension: int) -> tuple[str, dict[str, Any]]: + where_filters, query_params = construct_metadata_filter(filters, node_alias="node") + base_query = BASE_VECTOR_EXACT_QUERY.format( + node_label=node_label, + embedding_node_property=embedding_node_property, + ) + vector_query = VECTOR_EXACT_QUERY.format( + embedding_node_property=embedding_node_property, + ) + query_params["embedding_dimension"] = embedding_dimension + return f"""{base_query} + AND ({where_filters}) + {vector_query} + """, query_params + + +def _get_vector_query(filters: dict[str, Any], node_label: str, embedding_node_property: str, embedding_dimension: int) -> tuple[str, dict[str, Any]]: + if filters: + return _get_filtered_vector_query(filters, node_label, embedding_node_property, embedding_dimension) + return VECTOR_INDEX_QUERY, {} def get_search_query( search_type: SearchType, return_properties: Optional[list[str]] = None, retrieval_query: Optional[str] = None, -): - query_map = { - SearchType.VECTOR: "".join( - [ - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) ", - "YIELD node, score ", - get_query_tail(retrieval_query, return_properties), - ] - ), - SearchType.HYBRID: "".join( - [ - "CALL { ", - "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) ", - "YIELD node, score ", - "RETURN node, score UNION ", - "CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) ", - "YIELD node, score ", - "WITH collect({node:node, score:score}) AS nodes, max(score) AS max ", - "UNWIND nodes AS n ", - "RETURN n.node AS node, (n.score / max) AS score ", - "} ", - "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k ", - get_query_tail( - retrieval_query, return_properties, "RETURN node, score" - ), - ] - ), - } - return query_map[search_type] - - -def get_query_tail( + node_label: Optional[str] = None, + embedding_node_property: Optional[str] = None, + embedding_dimension: Optional[int] = None, + filters: Optional[dict[str, Any]] = None, +) -> tuple[str, dict[str, Any]]: + if search_type == SearchType.HYBRID: + if filters: + raise Exception("Filters is not supported with Hybrid Search") + query = _get_hybrid_query() + params = {} + elif search_type == SearchType.VECTOR: + query, params = _get_vector_query(filters, node_label, embedding_node_property, embedding_dimension) + else: + raise ValueError(f"Search type is not supported: {search_type}") + query_tail = _get_query_tail(retrieval_query, return_properties, fallback_return="RETURN node, score") + return " ".join([query, query_tail]), params + + +def _get_query_tail( retrieval_query: Optional[str] = None, return_properties: Optional[list[str]] = None, fallback_return: Optional[str] = None, diff --git a/src/neo4j_genai/retrievers/base.py b/src/neo4j_genai/retrievers/base.py index c3ca671d..e478e3a5 100644 --- a/src/neo4j_genai/retrievers/base.py +++ b/src/neo4j_genai/retrievers/base.py @@ -57,3 +57,19 @@ def _verify_version(self) -> None: @abstractmethod def search(self, *args, **kwargs) -> Any: pass + + def _fetch_index_infos(self): + """Fetch the node label and embedding property from the index definition""" + query = """SHOW VECTOR INDEXES +YIELD name, labelsOrTypes, properties, options +WHERE name = $index_name +RETURN labelsOrTypes as labels, properties, options.indexConfig.`vector.dimensions` as dimensions + """ + result = self.driver.execute_query(query, {"index_name": self.index_name}) + try: + result = result.records[0] + except IndexError: + raise Exception(f"No index with name {self.index_name} found") + self._node_label = result["labels"][0] + self._embedding_node_property = result["properties"][0] + self._embedding_dimension = result["dimensions"] diff --git a/src/neo4j_genai/retrievers/filters.py b/src/neo4j_genai/retrievers/filters.py new file mode 100644 index 00000000..358a92fc --- /dev/null +++ b/src/neo4j_genai/retrievers/filters.py @@ -0,0 +1,315 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Filters format: +{"property_name": "property_value"} + + +""" +from typing import Any, Type +from collections import Counter + + +DEFAULT_NODE_ALIAS = "node" + + +class Operator: + """Operator classes are helper classes to build the Cypher queries + from a filter like {"field_name": "field_value"} + They implement two important methods: + - lhs: (left hand side): the node + property to be filtered on + + optional operations on it (see ILikeOperator for instance) + - cleaned_value: a method to make sure the provided parameter values are + consistent with the operator (e.g. LIKE operator only works with string values) + """ + CYPHER_OPERATOR = None + + def __init__(self, node_alias=DEFAULT_NODE_ALIAS): + self.node_alias = node_alias + + def lhs(self, field): + return f"{self.node_alias}.`{field}`" + + def cleaned_value(self, value): + return value + + +class EqOperator(Operator): + CYPHER_OPERATOR = "=" + + +class NeqOperator(Operator): + CYPHER_OPERATOR = "<>" + + +class LtOperator(Operator): + CYPHER_OPERATOR = "<" + + +class GtOperator(Operator): + CYPHER_OPERATOR = ">" + + +class LteOperator(Operator): + CYPHER_OPERATOR = "<=" + + +class GteOperator(Operator): + CYPHER_OPERATOR = ">=" + + +class InOperator(Operator): + CYPHER_OPERATOR = "IN" + + def cleaned_value(self, value): + for val in value: + if not isinstance(val, (str, int, float)): + raise NotImplementedError( + f"Unsupported type: {type(val)} for value: {val}" + ) + return value + + +class NinOperator(InOperator): + CYPHER_OPERATOR = "NOT IN" + + +class LikeOperator(Operator): + CYPHER_OPERATOR = "CONTAINS" + + def cleaned_value(self, value): + if not isinstance(value, str): + raise ValueError(f"Expected string value, got {type(value)}: {value}") + return value.rstrip("%") + + +class ILikeOperator(LikeOperator): + + def lhs(self, field): + return f"toLower({self.node_alias}.`{field}`)" + + def cleaned_value(self, value): + value = super().cleaned_value(value) + return value.lower() + + +OPERATOR_PREFIX = "$" + +OPERATOR_EQ = "$eq" +OPERATOR_NE = "$ne" +OPERATOR_LT = "$lt" +OPERATOR_LTE = "$lte" +OPERATOR_GT = "$gt" +OPERATOR_GTE = "$gte" +OPERATOR_BETWEEN = "$between" +OPERATOR_IN = "$in" +OPERATOR_NIN = "$nin" +OPERATOR_LIKE = "$like" +OPERATOR_ILIKE = "$ilike" + +OPERATOR_AND = "$and" +OPERATOR_OR = "$or" + +COMPARISONS_TO_NATIVE = { + OPERATOR_EQ: EqOperator, + OPERATOR_NE: NeqOperator, + OPERATOR_LT: LtOperator, + OPERATOR_LTE: LteOperator, + OPERATOR_GT: GtOperator, + OPERATOR_GTE: GteOperator, + OPERATOR_IN: InOperator, + OPERATOR_NIN: NinOperator, + OPERATOR_LIKE: LikeOperator, + OPERATOR_ILIKE: ILikeOperator, +} + + +LOGICAL_OPERATORS = {OPERATOR_AND, OPERATOR_OR} + +SUPPORTED_OPERATORS = ( + set(COMPARISONS_TO_NATIVE) + .union(LOGICAL_OPERATORS) + .union({OPERATOR_BETWEEN}) +) + + +class ParameterStore: + """ + Store parameters for a given query. + Determine the parameter name depending on a parameter counter + """ + + def __init__(self): + self._counter = Counter() + self.params = {} + + def _get_params_name(self, key="param"): + """NB: the counter parameter is there in purpose, will be modified in the function + to remember the count of each parameter + + :param p: + :param counter: + :return: + """ + # key = slugify(key.replace(".", "_"), separator="_") + param_name = f"{key}_{self._counter[key]}" + self._counter[key] += 1 + return param_name + + def add(self, key, value): + param_name = self._get_params_name() + self.params[param_name] = value + return param_name + + +def _single_condition_cypher(field: str, native_operator_class: Type[Operator], value: Any, param_store: ParameterStore, node_alias: str) -> str: + """Return Cypher for field operator value + NB: the param_store argument is mutable, it will be updated in this function + """ + native_op = native_operator_class() + param_name = param_store.add(field, native_op.cleaned_value(value)) + query_snippet = f"{native_op.lhs(field)} {native_op.CYPHER_OPERATOR} ${param_name}" + return query_snippet + + +def _handle_field_filter( + field: str, value: Any, param_store: ParameterStore, + node_alias: str = DEFAULT_NODE_ALIAS +) -> str: + """Create a filter for a specific field. + + Args: + field: name of field + value: value to filter + If provided as is then this will be an equality filter + If provided as a dictionary then this will be a filter, the key + will be the operator and the value will be the value to filter by + param_store: + node_alias: + + Returns + - Cypher filter snippet* + + NB: the param_store argument is mutable, it will be updated in this function + """ + # first, perform some sanity checks + if not isinstance(field, str): + raise ValueError( + f"Field should be a string but got: {type(field)} with value: {field}" + ) + + if field.startswith(OPERATOR_PREFIX): + raise ValueError( + f"Invalid filter condition. Expected a field but got an operator: " + f"{field}" + ) + + # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters + if not field.isidentifier(): + raise ValueError(f"Invalid field name: {field}. Expected a valid identifier.") + + if isinstance(value, dict): + # This is a filter specification e.g. {"$gte": 0} + if len(value) != 1: + raise ValueError( + "Invalid filter condition. Expected a value which " + "is a dictionary with a single key that corresponds to an operator " + f"but got a dictionary with {len(value)} keys. The first few " + f"keys are: {list(value.keys())[:3]}" + ) + operator, filter_value = list(value.items())[0] + operator = operator.lower() + # Verify that that operator is an operator + if operator not in SUPPORTED_OPERATORS: + raise ValueError( + f"Invalid operator: {operator}. " + f"Expected one of {SUPPORTED_OPERATORS}" + ) + else: # if value is not dict, then we assume an equality operator + operator = OPERATOR_EQ + filter_value = value + + # now everything is set, we can start and build the query + # special case for the BETWEEN operator that requires + # two tests (lower_bound <= value <= higher_bound) + if operator == OPERATOR_BETWEEN: + low, high = filter_value + param_name_low = param_store.add(field, low) + param_name_high = param_store.add(field, high) + query_snippet = ( + f"${param_name_low} <= {DEFAULT_NODE_ALIAS}.`{field}` <= ${param_name_high}" + ) + return query_snippet + # all the other operators are handled through their own classes: + native_op_class = COMPARISONS_TO_NATIVE[operator] + return _single_condition_cypher(field, native_op_class, filter_value, param_store, node_alias) + + +def _construct_metadata_filter(filter: dict[str, Any], param_store: ParameterStore, node_alias: str) -> str: + """Construct a metadata filter. This is a recursive function parsing the filter dict + + Args: + filter: A dictionary representing the filter condition. + param_store: A ParamStore object that will deal with parameter naming and saving along the process + node_alias: a string used as alias for the node the filters will be applied to (must come from earlier in the query) + + Returns: + str + + NB: the param_store argument is mutable, it will be updated in this function + """ + + if not isinstance(filter, dict): + raise ValueError() + # if we have more than one entry, this is an implicit "AND" filter + if len(filter) > 1: + return _construct_metadata_filter({OPERATOR_AND: [{k: v} for k, v in filter.items()]}, param_store, node_alias) + # The only operators allowed at the top level are $AND and $OR + # First check if an operator or a field + key, value = list(filter.items())[0] + if not key.startswith("$"): + # it's not an operator, must be a field + return _handle_field_filter(key, filter[key], param_store, node_alias=node_alias) + + # Here we handle the $and and $or operators + if not isinstance(value, list): + raise ValueError( + f"Expected a list, but got {type(value)} for value: {value}" + ) + if key.lower() == OPERATOR_AND: + cypher_operator = " AND " + elif key.lower() == OPERATOR_OR: + cypher_operator = " OR " + else: + raise ValueError(f"Unsupported filter {filter}") + query = cypher_operator.join( + [f"({ _construct_metadata_filter(el, param_store, node_alias)})" for el in value] + ) + return query + + +def construct_metadata_filter(filter: dict[str, Any], node_alias: str = DEFAULT_NODE_ALIAS) -> tuple[str, dict]: + """Construct the cypher filter snippet based on a filter dict + + Args: + filter: a dict of filters + node_alias: the node the filters must be applied on + + Return: + A tuple of str, dict where the string is the cypher query and the dict + contains the query parameters + """ + param_store = ParameterStore() + return _construct_metadata_filter(filter, param_store, node_alias=node_alias), param_store.params diff --git a/src/neo4j_genai/retrievers/hybrid.py b/src/neo4j_genai/retrievers/hybrid.py index a2163b34..fea96a2d 100644 --- a/src/neo4j_genai/retrievers/hybrid.py +++ b/src/neo4j_genai/retrievers/hybrid.py @@ -84,7 +84,7 @@ def search( query_vector = self.embedder.embed_query(query_text) parameters["query_vector"] = query_vector - search_query = get_search_query(SearchType.HYBRID, self.return_properties) + search_query, _ = get_search_query(SearchType.HYBRID, self.return_properties) logger.debug("HybridRetriever Cypher parameters: %s", parameters) logger.debug("HybridRetriever Cypher query: %s", search_query) @@ -160,7 +160,7 @@ def search( parameters[key] = value del parameters["query_params"] - search_query = get_search_query( + search_query, _ = get_search_query( SearchType.HYBRID, retrieval_query=self.retrieval_query ) diff --git a/src/neo4j_genai/retrievers/vector.py b/src/neo4j_genai/retrievers/vector.py index af3a6068..32314801 100644 --- a/src/neo4j_genai/retrievers/vector.py +++ b/src/neo4j_genai/retrievers/vector.py @@ -48,12 +48,17 @@ def __init__( self.index_name = index_name self.return_properties = return_properties self.embedder = embedder + self._node_label = None + self._embedding_node_property = None + self._embedding_dimension = None + self._fetch_index_infos() def search( self, query_vector: Optional[list[float]] = None, query_text: Optional[str] = None, top_k: int = 5, + filters: Optional[dict[str, Any]] = None, ) -> list[VectorSearchRecord]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -75,7 +80,7 @@ def search( """ try: validated_data = VectorSearchModel( - index_name=self.index_name, + vector_index_name=self.index_name, top_k=top_k, query_vector=query_vector, query_text=query_text, @@ -93,7 +98,15 @@ def search( parameters["query_vector"] = query_vector del parameters["query_text"] - search_query = get_search_query(SearchType.VECTOR, self.return_properties) + search_query, search_params = get_search_query( + SearchType.VECTOR, + self.return_properties, + node_label=self._node_label, + embedding_node_property=self._embedding_node_property, + embedding_dimension=self._embedding_dimension, + filters=filters, + ) + parameters.update(search_params) logger.debug("VectorRetriever Cypher parameters: %s", parameters) logger.debug("VectorRetriever Cypher query: %s", search_query) @@ -129,6 +142,10 @@ def __init__( self.index_name = index_name self.retrieval_query = retrieval_query self.embedder = embedder + self._node_label = None + self._node_embedding_property = None + self._embedding_dimension = None + self._fetch_index_infos() def search( self, @@ -136,6 +153,7 @@ def search( query_text: Optional[str] = None, top_k: int = 5, query_params: Optional[dict[str, Any]] = None, + filters: Optional[dict[str, Any]] = None, ) -> list[neo4j.Record]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -158,7 +176,7 @@ def search( """ try: validated_data = VectorCypherSearchModel( - index_name=self.index_name, + vector_index_name=self.index_name, top_k=top_k, query_vector=query_vector, query_text=query_text, @@ -181,9 +199,15 @@ def search( parameters[key] = value del parameters["query_params"] - search_query = get_search_query( - SearchType.VECTOR, retrieval_query=self.retrieval_query + search_query, search_params = get_search_query( + SearchType.VECTOR, + retrieval_query=self.retrieval_query, + node_label=self._node_label, + embedding_node_property=self._node_embedding_property, + embedding_dimension=self._embedding_dimension, + filters=filters, ) + parameters.update(search_params) logger.debug("VectorCypherRetriever Cypher parameters: %s", parameters) logger.debug("VectorCypherRetriever Cypher query: %s", search_query) diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index fc9a3e5a..357bb44e 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -54,7 +54,7 @@ def check_node_properties_not_empty(cls, v): class VectorSearchModel(BaseModel): - index_name: str + vector_index_name: str top_k: PositiveInt = 5 query_vector: Optional[list[float]] = None query_text: Optional[str] = None diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 0ce0c2c8..4df466e9 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -60,7 +60,7 @@ def setup_neo4j(driver): vector_index_name, label="Document", property="propertyKey", - dimensions=1536, + dimensions=10, similarity_fn="euclidean", ) @@ -70,7 +70,7 @@ def setup_neo4j(driver): ) # Insert 10 vectors and authors - vector = [random.random() for _ in range(1536)] + vector = [random.random() for _ in range(10)] def random_str(n: int) -> str: return "".join([random.choice(string.ascii_letters) for _ in range(n)]) @@ -78,6 +78,8 @@ def random_str(n: int) -> str: for i in range(10): insert_query = ( "MERGE (doc:Document {id: $id})" + "ON CREATE SET doc.int_property = $i, " + " doc.short_text_property = toString($i)" "WITH doc " "CALL db.create.setNodeVectorProperty(doc, 'propertyKey', $vector)" "WITH doc " @@ -88,6 +90,7 @@ def random_str(n: int) -> str: parameters = { "id": str(uuid.uuid4()), + "i": i, "vector": vector, "authorName": random_str(10), } diff --git a/tests/e2e/test_vector_e2e.py b/tests/e2e/test_vector_e2e.py index 9bf3f5a4..baeae191 100644 --- a/tests/e2e/test_vector_e2e.py +++ b/tests/e2e/test_vector_e2e.py @@ -102,3 +102,24 @@ def test_vector_retriever_return_properties(driver): assert len(results) == 5 for result in results: assert isinstance(result, VectorSearchRecord) + + +@pytest.mark.usefixtures("setup_neo4j") +def test_vector_retriever_filters(driver): + retriever = VectorRetriever( + driver, + "vector-index-name", + ) + + top_k = 2 + results = retriever.search( + query_vector=[1.0 for _ in range(10)], + filters={"int_property": {"$gt": 2}}, + top_k=top_k, + ) + + assert isinstance(results, list) + assert len(results) == 2 + for result in results: + assert isinstance(result, VectorSearchRecord) + assert result.node["int_property"] > 2 diff --git a/tests/unit/retrievers/test_filters.py b/tests/unit/retrievers/test_filters.py new file mode 100644 index 00000000..536f9491 --- /dev/null +++ b/tests/unit/retrievers/test_filters.py @@ -0,0 +1,162 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from neo4j_genai.retrievers.filters import construct_metadata_filter + + +def test_filter_single_field_string(): + filters = {"field": "string_value"} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` = $param_0" + assert params == {"param_0": "string_value"} + + +def test_filter_single_field_int(): + filters = {"field": 28} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` = $param_0" + assert params == {"param_0": 28} + + +def test_filter_single_field_bool(): + filters = {"field": False} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` = $param_0" + assert params == {"param_0": False} + + +def test_filter_explicit_eq_operator(): + filters = {"field": {"$eq": "string_value"}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` = $param_0" + assert params == {"param_0": "string_value"} + + +def test_filter_neq_operator(): + filters = {"field": {"$ne": "string_value"}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` <> $param_0" + assert params == {"param_0": "string_value"} + + +def test_filter_lt_operator(): + filters = {"field": {"$lt": 1}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` < $param_0" + assert params == {"param_0": 1} + + +def test_filter_gt_operator(): + filters = {"field": {"$gt": 1}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` > $param_0" + assert params == {"param_0": 1} + + +def test_filter_lte_operator(): + filters = {"field": {"$lte": 1}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` <= $param_0" + assert params == {"param_0": 1} + + +def test_filter_gte_operator(): + filters = {"field": {"$gte": 1}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` >= $param_0" + assert params == {"param_0": 1} + + +def test_filter_in_operator(): + filters = {"field": {"$in": ["a", "b"]}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` IN $param_0" + assert params == {"param_0": ["a", "b"]} + + +def test_filter_not_in_operator(): + filters = {"field": {"$nin": ["a", "b"]}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` NOT IN $param_0" + assert params == {"param_0": ["a", "b"]} + + +def test_filter_like_operator(): + filters = {"field": {"$like": "some_value"}} + query, params = construct_metadata_filter(filters) + assert query == "node.`field` CONTAINS $param_0" + assert params == {"param_0": "some_value"} + + +def test_filter_ilike_operator(): + filters = {"field": {"$ilike": "Some Value"}} + query, params = construct_metadata_filter(filters) + assert query == "toLower(node.`field`) CONTAINS $param_0" + assert params == {"param_0": "some value"} + + +def test_filter_between_operator(): + filters = {"field": {"$between": [0, 1]}} + query, params = construct_metadata_filter(filters) + assert query == "$param_0 <= node.`field` <= $param_1" + assert params == {"param_0": 0, "param_1": 1} + + +def test_filter_implicit_and_condition(): + filters = {"field_1": "string_value", "field_2": True} + query, params = construct_metadata_filter(filters) + assert query == "(node.`field_1` = $param_0) AND (node.`field_2` = $param_1)" + assert params == {"param_0": "string_value", "param_1": True} + + +def test_filter_explicit_and_condition(): + filters = {"$and": [{"field_1": "string_value"}, {"field_2": True}]} + query, params = construct_metadata_filter(filters) + assert query == "(node.`field_1` = $param_0) AND (node.`field_2` = $param_1)" + assert params == {"param_0": "string_value", "param_1": True} + + +def test_filter_or_condition(): + filters = {"$or": [{"field_1": "string_value"}, {"field_2": True}]} + query, params = construct_metadata_filter(filters) + assert query == "(node.`field_1` = $param_0) OR (node.`field_2` = $param_1)" + assert params == {"param_0": "string_value", "param_1": True} + + +def test_filter_and_or_combined(): + filters = {"$and": [{"$or": [{"field_1": "string_value"}, {"field_2": True}]}, {"field_3": 11}]} + query, params = construct_metadata_filter(filters) + assert query == "((node.`field_1` = $param_0) OR (node.`field_2` = $param_1)) AND (node.`field_3` = $param_2)" + assert params == {"param_0": "string_value", "param_1": True, "param_2": 11} + + +# now testing bad filters +def test_field_name_with_dollar_sign(): + filters = {"$field": "value"} + with pytest.raises(ValueError): + construct_metadata_filter(filters) + + +def test_and_no_list(): + filters = {"$and": {}} + with pytest.raises(ValueError): + construct_metadata_filter(filters) + + +def test_unsupported_operator(): + filters = {"field": {"$unsupported": "value"}} + with pytest.raises(ValueError): + construct_metadata_filter(filters) diff --git a/tests/unit/retrievers/test_hybrid.py b/tests/unit/retrievers/test_hybrid.py index b55e3c54..093364a6 100644 --- a/tests/unit/retrievers/test_hybrid.py +++ b/tests/unit/retrievers/test_hybrid.py @@ -60,7 +60,7 @@ def test_hybrid_search_text_happy_path(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.HYBRID) + search_query, _ = get_search_query(SearchType.HYBRID) records = retriever.search(query_text=query_text, top_k=top_k) @@ -98,7 +98,7 @@ def test_hybrid_search_favors_query_vector_over_embedding_vector( None, None, ] - search_query = get_search_query(SearchType.HYBRID) + search_query, _ = get_search_query(SearchType.HYBRID) retriever.search(query_text=query_text, query_vector=query_vector, top_k=top_k) @@ -161,7 +161,7 @@ def test_hybrid_retriever_return_properties(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.HYBRID, return_properties) + search_query, _ = get_search_query(SearchType.HYBRID, return_properties) records = retriever.search(query_text=query_text, top_k=top_k) @@ -206,7 +206,7 @@ def test_hybrid_cypher_retrieval_query_with_params(_verify_version_mock, driver) None, None, ] - search_query = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) + search_query, _ = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) records = retriever.search( query_text=query_text, diff --git a/tests/unit/retrievers/test_vector.py b/tests/unit/retrievers/test_vector.py index 69c1f615..9be9da60 100644 --- a/tests/unit/retrievers/test_vector.py +++ b/tests/unit/retrievers/test_vector.py @@ -34,8 +34,9 @@ def test_vector_cypher_retriever_initialization(driver): mock_verify.assert_called_once() +@patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_vector_happy_path(_verify_version_mock, driver): +def test_similarity_search_vector_happy_path(_verify_version_mock, _fetch_index_infos, driver): index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] @@ -46,14 +47,14 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR) + search_query, _ = get_search_query(SearchType.VECTOR) records = retriever.search(query_vector=query_vector, top_k=top_k) retriever.driver.execute_query.assert_called_once_with( search_query, { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": query_vector, }, @@ -61,8 +62,9 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver): assert records == [VectorSearchRecord(node="dummy-node", score=1.0)] +@patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_text_happy_path(_verify_version_mock, driver): +def test_similarity_search_text_happy_path(_verify_version_mock, _fetch_index_infos, driver): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -75,7 +77,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR) + search_query, _ = get_search_query(SearchType.VECTOR) records = retriever.search(query_text=query_text, top_k=top_k) @@ -83,7 +85,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): driver.execute_query.assert_called_once_with( search_query, { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": embed_query_vector, }, @@ -92,8 +94,9 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): assert records == [VectorSearchRecord(node="dummy-node", score=1.0)] +@patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_text_return_properties(_verify_version_mock, driver): +def test_similarity_search_text_return_properties(_verify_version_mock, _fetch_index_infos, driver): embed_query_vector = [1.0 for _ in range(3)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -111,7 +114,7 @@ def test_similarity_search_text_return_properties(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR, return_properties) + search_query, _ = get_search_query(SearchType.VECTOR, return_properties) records = retriever.search(query_text=query_text, top_k=top_k) @@ -119,7 +122,7 @@ def test_similarity_search_text_return_properties(_verify_version_mock, driver): driver.execute_query.assert_called_once_with( search_query.rstrip(), { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": embed_query_vector, }, @@ -175,8 +178,9 @@ def test_vector_cypher_retriever_search_both_text_and_vector(vector_cypher_retri ) +@patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_vector_bad_results(_verify_version_mock, driver): +def test_similarity_search_vector_bad_results(_verify_version_mock, _fetch_index_infos, driver): index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] @@ -187,7 +191,7 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR) + search_query, _ = get_search_query(SearchType.VECTOR) with pytest.raises(ValueError): retriever.search(query_vector=query_vector, top_k=top_k) @@ -195,15 +199,16 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver): retriever.driver.execute_query.assert_called_once_with( search_query, { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": query_vector, }, ) +@patch("neo4j_genai.VectorCypherRetriever._fetch_index_infos") @patch("neo4j_genai.VectorCypherRetriever._verify_version") -def test_retrieval_query_happy_path(_verify_version_mock, driver): +def test_retrieval_query_happy_path(_verify_version_mock, _fetch_index_infos, driver): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -221,7 +226,7 @@ def test_retrieval_query_happy_path(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) + search_query, _ = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) records = retriever.search( query_text=query_text, @@ -232,7 +237,7 @@ def test_retrieval_query_happy_path(_verify_version_mock, driver): driver.execute_query.assert_called_once_with( search_query, { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": embed_query_vector, }, @@ -240,8 +245,9 @@ def test_retrieval_query_happy_path(_verify_version_mock, driver): assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}] +@patch("neo4j_genai.VectorCypherRetriever._fetch_index_infos") @patch("neo4j_genai.VectorCypherRetriever._verify_version") -def test_retrieval_query_with_params(_verify_version_mock, driver): +def test_retrieval_query_with_params(_verify_version_mock, _fetch_index_infos, driver): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -265,7 +271,7 @@ def test_retrieval_query_with_params(_verify_version_mock, driver): None, None, ] - search_query = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) + search_query, _ = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) records = retriever.search( query_text=query_text, @@ -278,7 +284,7 @@ def test_retrieval_query_with_params(_verify_version_mock, driver): driver.execute_query.assert_called_once_with( search_query, { - "index_name": index_name, + "vector_index_name": index_name, "top_k": top_k, "query_vector": embed_query_vector, "param": "dummy-param", @@ -288,8 +294,9 @@ def test_retrieval_query_with_params(_verify_version_mock, driver): assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}] +@patch("neo4j_genai.VectorCypherRetriever._fetch_index_infos") @patch("neo4j_genai.VectorCypherRetriever._verify_version") -def test_retrieval_query_cypher_error(_verify_version_mock, driver): +def test_retrieval_query_cypher_error(_verify_version_mock, _fetch_index_infos, driver): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector diff --git a/tests/unit/test_neo4j_queries.py b/tests/unit/test_neo4j_queries.py index 3ce7c774..d20185b2 100644 --- a/tests/unit/test_neo4j_queries.py +++ b/tests/unit/test_neo4j_queries.py @@ -13,17 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from neo4j_genai.neo4j_queries import get_search_query, get_query_tail +from neo4j_genai.neo4j_queries import get_search_query from neo4j_genai.types import SearchType def test_vector_search_basic(): expected = ( - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " - "YIELD node, score" + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " + "YIELD node, score " + "RETURN node, score" ) - result = get_search_query(SearchType.VECTOR) + result, params = get_search_query(SearchType.VECTOR) assert result.strip() == expected.strip() + assert params == {} def test_hybrid_search_basic(): @@ -41,28 +43,28 @@ def test_hybrid_search_basic(): "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " "RETURN node, score" ) - result = get_search_query(SearchType.HYBRID) + result, _ = get_search_query(SearchType.HYBRID) assert result.strip() == expected.strip() def test_vector_search_with_properties(): properties = ["name", "age"] expected = ( - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " "YIELD node, score " "RETURN node {.name, .age} as node, score" ) - result = get_search_query(SearchType.VECTOR, return_properties=properties) + result, _ = get_search_query(SearchType.VECTOR, return_properties=properties) assert result.strip() == expected.strip() def test_vector_search_with_retrieval_query(): retrieval_query = "MATCH (n) RETURN n LIMIT 10" expected = ( - "CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) " + "CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) " "YIELD node, score " + retrieval_query ) - result = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) + result, _ = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) assert result.strip() == expected.strip() @@ -82,7 +84,7 @@ def test_hybrid_search_with_retrieval_query(): "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " + retrieval_query ) - result = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) + result, _ = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) assert result.strip() == expected.strip() @@ -102,52 +104,5 @@ def test_hybrid_search_with_properties(): "WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k " "RETURN node {.name, .age} as node, score" ) - result = get_search_query(SearchType.HYBRID, return_properties=properties) - assert result.strip() == expected.strip() - - -def test_get_query_tail_with_retrieval_query(): - retrieval_query = "MATCH (n) RETURN n LIMIT 10" - expected = retrieval_query - result = get_query_tail(retrieval_query=retrieval_query) - assert result.strip() == expected.strip() - - -def test_get_query_tail_with_properties(): - properties = ["name", "age"] - expected = "RETURN node {.name, .age} as node, score" - result = get_query_tail(return_properties=properties) - assert result.strip() == expected.strip() - - -def test_get_query_tail_with_fallback(): - fallback = "HELLO" - expected = fallback - result = get_query_tail(fallback_return=fallback) - assert result.strip() == expected.strip() - - -def test_get_query_tail_ordering_all(): - retrieval_query = "MATCH (n) RETURN n LIMIT 10" - properties = ["name", "age"] - fallback = "HELLO" - - expected = retrieval_query - result = get_query_tail( - retrieval_query=retrieval_query, - return_properties=properties, - fallback_return=fallback, - ) - assert result.strip() == expected.strip() - - -def test_get_query_tail_ordering_no_retrieval_query(): - properties = ["name", "age"] - fallback = "HELLO" - - expected = "RETURN node {.name, .age} as node, score" - result = get_query_tail( - return_properties=properties, - fallback_return=fallback, - ) + result, _ = get_search_query(SearchType.HYBRID, return_properties=properties) assert result.strip() == expected.strip() From 6e90039ad48d8c60a23c84ddb5b10190d9e768fe Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 6 May 2024 18:00:38 +0200 Subject: [PATCH 19/38] Ruff --- src/neo4j_genai/neo4j_queries.py | 33 +++++++++++---- src/neo4j_genai/retrievers/filters.py | 59 ++++++++++++++++----------- tests/unit/retrievers/test_filters.py | 20 ++++++--- tests/unit/retrievers/test_hybrid.py | 4 +- tests/unit/retrievers/test_vector.py | 24 ++++++++--- 5 files changed, 97 insertions(+), 43 deletions(-) diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index e3cf4149..5897fa5a 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -54,7 +54,12 @@ def _get_hybrid_query() -> str: ) -def _get_filtered_vector_query(filters: dict[str, Any], node_label: str, embedding_node_property: str, embedding_dimension: int) -> tuple[str, dict[str, Any]]: +def _get_filtered_vector_query( + filters: dict[str, Any], + node_label: str, + embedding_node_property: str, + embedding_dimension: int, +) -> tuple[str, dict[str, Any]]: where_filters, query_params = construct_metadata_filter(filters, node_alias="node") base_query = BASE_VECTOR_EXACT_QUERY.format( node_label=node_label, @@ -64,15 +69,25 @@ def _get_filtered_vector_query(filters: dict[str, Any], node_label: str, embeddi embedding_node_property=embedding_node_property, ) query_params["embedding_dimension"] = embedding_dimension - return f"""{base_query} + return ( + f"""{base_query} AND ({where_filters}) {vector_query} - """, query_params + """, + query_params, + ) -def _get_vector_query(filters: dict[str, Any], node_label: str, embedding_node_property: str, embedding_dimension: int) -> tuple[str, dict[str, Any]]: +def _get_vector_query( + filters: dict[str, Any], + node_label: str, + embedding_node_property: str, + embedding_dimension: int, +) -> tuple[str, dict[str, Any]]: if filters: - return _get_filtered_vector_query(filters, node_label, embedding_node_property, embedding_dimension) + return _get_filtered_vector_query( + filters, node_label, embedding_node_property, embedding_dimension + ) return VECTOR_INDEX_QUERY, {} @@ -91,10 +106,14 @@ def get_search_query( query = _get_hybrid_query() params = {} elif search_type == SearchType.VECTOR: - query, params = _get_vector_query(filters, node_label, embedding_node_property, embedding_dimension) + query, params = _get_vector_query( + filters, node_label, embedding_node_property, embedding_dimension + ) else: raise ValueError(f"Search type is not supported: {search_type}") - query_tail = _get_query_tail(retrieval_query, return_properties, fallback_return="RETURN node, score") + query_tail = _get_query_tail( + retrieval_query, return_properties, fallback_return="RETURN node, score" + ) return " ".join([query, query_tail]), params diff --git a/src/neo4j_genai/retrievers/filters.py b/src/neo4j_genai/retrievers/filters.py index 358a92fc..0919c237 100644 --- a/src/neo4j_genai/retrievers/filters.py +++ b/src/neo4j_genai/retrievers/filters.py @@ -12,12 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Filters format: -{"property_name": "property_value"} - - -""" from typing import Any, Type from collections import Counter @@ -34,6 +28,7 @@ class Operator: - cleaned_value: a method to make sure the provided parameter values are consistent with the operator (e.g. LIKE operator only works with string values) """ + CYPHER_OPERATOR = None def __init__(self, node_alias=DEFAULT_NODE_ALIAS): @@ -96,7 +91,6 @@ def cleaned_value(self, value): class ILikeOperator(LikeOperator): - def lhs(self, field): return f"toLower({self.node_alias}.`{field}`)" @@ -139,9 +133,7 @@ def cleaned_value(self, value): LOGICAL_OPERATORS = {OPERATOR_AND, OPERATOR_OR} SUPPORTED_OPERATORS = ( - set(COMPARISONS_TO_NATIVE) - .union(LOGICAL_OPERATORS) - .union({OPERATOR_BETWEEN}) + set(COMPARISONS_TO_NATIVE).union(LOGICAL_OPERATORS).union({OPERATOR_BETWEEN}) ) @@ -174,7 +166,13 @@ def add(self, key, value): return param_name -def _single_condition_cypher(field: str, native_operator_class: Type[Operator], value: Any, param_store: ParameterStore, node_alias: str) -> str: +def _single_condition_cypher( + field: str, + native_operator_class: Type[Operator], + value: Any, + param_store: ParameterStore, + node_alias: str, +) -> str: """Return Cypher for field operator value NB: the param_store argument is mutable, it will be updated in this function """ @@ -185,8 +183,10 @@ def _single_condition_cypher(field: str, native_operator_class: Type[Operator], def _handle_field_filter( - field: str, value: Any, param_store: ParameterStore, - node_alias: str = DEFAULT_NODE_ALIAS + field: str, + value: Any, + param_store: ParameterStore, + node_alias: str = DEFAULT_NODE_ALIAS, ) -> str: """Create a filter for a specific field. @@ -254,10 +254,14 @@ def _handle_field_filter( return query_snippet # all the other operators are handled through their own classes: native_op_class = COMPARISONS_TO_NATIVE[operator] - return _single_condition_cypher(field, native_op_class, filter_value, param_store, node_alias) + return _single_condition_cypher( + field, native_op_class, filter_value, param_store, node_alias + ) -def _construct_metadata_filter(filter: dict[str, Any], param_store: ParameterStore, node_alias: str) -> str: +def _construct_metadata_filter( + filter: dict[str, Any], param_store: ParameterStore, node_alias: str +) -> str: """Construct a metadata filter. This is a recursive function parsing the filter dict Args: @@ -275,19 +279,21 @@ def _construct_metadata_filter(filter: dict[str, Any], param_store: ParameterSto raise ValueError() # if we have more than one entry, this is an implicit "AND" filter if len(filter) > 1: - return _construct_metadata_filter({OPERATOR_AND: [{k: v} for k, v in filter.items()]}, param_store, node_alias) + return _construct_metadata_filter( + {OPERATOR_AND: [{k: v} for k, v in filter.items()]}, param_store, node_alias + ) # The only operators allowed at the top level are $AND and $OR # First check if an operator or a field key, value = list(filter.items())[0] if not key.startswith("$"): # it's not an operator, must be a field - return _handle_field_filter(key, filter[key], param_store, node_alias=node_alias) + return _handle_field_filter( + key, filter[key], param_store, node_alias=node_alias + ) # Here we handle the $and and $or operators if not isinstance(value, list): - raise ValueError( - f"Expected a list, but got {type(value)} for value: {value}" - ) + raise ValueError(f"Expected a list, but got {type(value)} for value: {value}") if key.lower() == OPERATOR_AND: cypher_operator = " AND " elif key.lower() == OPERATOR_OR: @@ -295,12 +301,17 @@ def _construct_metadata_filter(filter: dict[str, Any], param_store: ParameterSto else: raise ValueError(f"Unsupported filter {filter}") query = cypher_operator.join( - [f"({ _construct_metadata_filter(el, param_store, node_alias)})" for el in value] + [ + f"({ _construct_metadata_filter(el, param_store, node_alias)})" + for el in value + ] ) return query -def construct_metadata_filter(filter: dict[str, Any], node_alias: str = DEFAULT_NODE_ALIAS) -> tuple[str, dict]: +def construct_metadata_filter( + filter: dict[str, Any], node_alias: str = DEFAULT_NODE_ALIAS +) -> tuple[str, dict]: """Construct the cypher filter snippet based on a filter dict Args: @@ -312,4 +323,6 @@ def construct_metadata_filter(filter: dict[str, Any], node_alias: str = DEFAULT_ contains the query parameters """ param_store = ParameterStore() - return _construct_metadata_filter(filter, param_store, node_alias=node_alias), param_store.params + return _construct_metadata_filter( + filter, param_store, node_alias=node_alias + ), param_store.params diff --git a/tests/unit/retrievers/test_filters.py b/tests/unit/retrievers/test_filters.py index 536f9491..fd562118 100644 --- a/tests/unit/retrievers/test_filters.py +++ b/tests/unit/retrievers/test_filters.py @@ -137,9 +137,17 @@ def test_filter_or_condition(): def test_filter_and_or_combined(): - filters = {"$and": [{"$or": [{"field_1": "string_value"}, {"field_2": True}]}, {"field_3": 11}]} - query, params = construct_metadata_filter(filters) - assert query == "((node.`field_1` = $param_0) OR (node.`field_2` = $param_1)) AND (node.`field_3` = $param_2)" + filters = { + "$and": [ + {"$or": [{"field_1": "string_value"}, {"field_2": True}]}, + {"field_3": 11}, + ] + } + query, params = construct_metadata_filter(filters) + assert ( + query + == "((node.`field_1` = $param_0) OR (node.`field_2` = $param_1)) AND (node.`field_3` = $param_2)" + ) assert params == {"param_0": "string_value", "param_1": True, "param_2": 11} @@ -147,16 +155,16 @@ def test_filter_and_or_combined(): def test_field_name_with_dollar_sign(): filters = {"$field": "value"} with pytest.raises(ValueError): - construct_metadata_filter(filters) + construct_metadata_filter(filters) def test_and_no_list(): filters = {"$and": {}} with pytest.raises(ValueError): - construct_metadata_filter(filters) + construct_metadata_filter(filters) def test_unsupported_operator(): filters = {"field": {"$unsupported": "value"}} with pytest.raises(ValueError): - construct_metadata_filter(filters) + construct_metadata_filter(filters) diff --git a/tests/unit/retrievers/test_hybrid.py b/tests/unit/retrievers/test_hybrid.py index 093364a6..79486835 100644 --- a/tests/unit/retrievers/test_hybrid.py +++ b/tests/unit/retrievers/test_hybrid.py @@ -206,7 +206,9 @@ def test_hybrid_cypher_retrieval_query_with_params(_verify_version_mock, driver) None, None, ] - search_query, _ = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query) + search_query, _ = get_search_query( + SearchType.HYBRID, retrieval_query=retrieval_query + ) records = retriever.search( query_text=query_text, diff --git a/tests/unit/retrievers/test_vector.py b/tests/unit/retrievers/test_vector.py index 9be9da60..c3fd1ade 100644 --- a/tests/unit/retrievers/test_vector.py +++ b/tests/unit/retrievers/test_vector.py @@ -36,7 +36,9 @@ def test_vector_cypher_retriever_initialization(driver): @patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_vector_happy_path(_verify_version_mock, _fetch_index_infos, driver): +def test_similarity_search_vector_happy_path( + _verify_version_mock, _fetch_index_infos, driver +): index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] @@ -64,7 +66,9 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, _fetch_index_ @patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_text_happy_path(_verify_version_mock, _fetch_index_infos, driver): +def test_similarity_search_text_happy_path( + _verify_version_mock, _fetch_index_infos, driver +): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -96,7 +100,9 @@ def test_similarity_search_text_happy_path(_verify_version_mock, _fetch_index_in @patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_text_return_properties(_verify_version_mock, _fetch_index_infos, driver): +def test_similarity_search_text_return_properties( + _verify_version_mock, _fetch_index_infos, driver +): embed_query_vector = [1.0 for _ in range(3)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector @@ -180,7 +186,9 @@ def test_vector_cypher_retriever_search_both_text_and_vector(vector_cypher_retri @patch("neo4j_genai.VectorRetriever._fetch_index_infos") @patch("neo4j_genai.VectorRetriever._verify_version") -def test_similarity_search_vector_bad_results(_verify_version_mock, _fetch_index_infos, driver): +def test_similarity_search_vector_bad_results( + _verify_version_mock, _fetch_index_infos, driver +): index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] @@ -226,7 +234,9 @@ def test_retrieval_query_happy_path(_verify_version_mock, _fetch_index_infos, dr None, None, ] - search_query, _ = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) + search_query, _ = get_search_query( + SearchType.VECTOR, retrieval_query=retrieval_query + ) records = retriever.search( query_text=query_text, @@ -271,7 +281,9 @@ def test_retrieval_query_with_params(_verify_version_mock, _fetch_index_infos, d None, None, ] - search_query, _ = get_search_query(SearchType.VECTOR, retrieval_query=retrieval_query) + search_query, _ = get_search_query( + SearchType.VECTOR, retrieval_query=retrieval_query + ) records = retriever.search( query_text=query_text, From d3106af2a3283d34cc90492509ae348e766423b6 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 6 May 2024 18:04:56 +0200 Subject: [PATCH 20/38] Back to the normal dimension size in e2e tests --- tests/e2e/conftest.py | 6 +++--- tests/e2e/test_vector_e2e.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 4df466e9..5e5f3f97 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -60,7 +60,7 @@ def setup_neo4j(driver): vector_index_name, label="Document", property="propertyKey", - dimensions=10, + dimensions=1536, similarity_fn="euclidean", ) @@ -70,7 +70,7 @@ def setup_neo4j(driver): ) # Insert 10 vectors and authors - vector = [random.random() for _ in range(10)] + vector = [random.random() for _ in range(1536)] def random_str(n: int) -> str: return "".join([random.choice(string.ascii_letters) for _ in range(n)]) @@ -92,6 +92,6 @@ def random_str(n: int) -> str: "id": str(uuid.uuid4()), "i": i, "vector": vector, - "authorName": random_str(10), + "authorName": random_str(1536), } driver.execute_query(insert_query, parameters) diff --git a/tests/e2e/test_vector_e2e.py b/tests/e2e/test_vector_e2e.py index baeae191..608dd4d0 100644 --- a/tests/e2e/test_vector_e2e.py +++ b/tests/e2e/test_vector_e2e.py @@ -113,7 +113,7 @@ def test_vector_retriever_filters(driver): top_k = 2 results = retriever.search( - query_vector=[1.0 for _ in range(10)], + query_vector=[1.0 for _ in range(1536)], filters={"int_property": {"$gt": 2}}, top_k=top_k, ) From 53d622f5785a0fa5c4fc9afef84bd88545b96a0a Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 6 May 2024 18:21:49 +0200 Subject: [PATCH 21/38] Improved docstrings + include an example --- examples/vector_search_with_filters.py | 72 ++++++++++++++++++++++++++ src/neo4j_genai/neo4j_queries.py | 59 +++++++++++++++++++-- src/neo4j_genai/retrievers/filters.py | 37 +++++++++---- 3 files changed, 155 insertions(+), 13 deletions(-) create mode 100644 examples/vector_search_with_filters.py diff --git a/examples/vector_search_with_filters.py b/examples/vector_search_with_filters.py new file mode 100644 index 00000000..bf5fa444 --- /dev/null +++ b/examples/vector_search_with_filters.py @@ -0,0 +1,72 @@ +from neo4j import GraphDatabase +from neo4j_genai import VectorRetriever + +import random +import string +from neo4j_genai.embedder import Embedder +from neo4j_genai.indexes import create_vector_index + + +URI = "neo4j://localhost:7687" +AUTH = ("neo4j", "password") + +INDEX_NAME = "embedding-name" +DIMENSION = 1536 + +# Connect to Neo4j database +driver = GraphDatabase.driver(URI, auth=AUTH) + + +# Create Embedder object +class CustomEmbedder(Embedder): + def embed_query(self, text: str) -> list[float]: + return [random.random() for _ in range(DIMENSION)] + + +# Generate random strings +def random_str(n: int) -> str: + return "".join([random.choice(string.ascii_letters) for _ in range(n)]) + + +embedder = CustomEmbedder() + +# Creating the index +create_vector_index( + driver, + INDEX_NAME, + label="Document", + property="propertyKey", + dimensions=DIMENSION, + similarity_fn="euclidean", +) + +# Initialize the retriever +retriever = VectorRetriever(driver, INDEX_NAME, embedder) + +# Upsert the query +vector = [random.random() for _ in range(DIMENSION)] +insert_query = ( + "MERGE (doc:Document {id: $id})" + "ON CREATE SET doc.int_property = $id, " + " doc.short_text_property = toString($id)" + "WITH doc " + "CALL db.create.setNodeVectorProperty(doc, 'propertyKey', $vector)" + "WITH doc " + "MERGE (author:Author {name: $authorName})" + "MERGE (doc)-[:AUTHORED_BY]->(author)" + "RETURN doc, author" +) +parameters = { + "id": random.randint(0, 10000), + "vector": vector, + "authorName": random_str(10), +} +driver.execute_query(insert_query, parameters) + +# Perform the search +query_text = "Find me a book about Fremen" +print( + retriever.search( + query_text=query_text, top_k=1, filters={"int_property": {"$gt": 100}} + ) +) diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index 5897fa5a..52bbf332 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -60,6 +60,18 @@ def _get_filtered_vector_query( embedding_node_property: str, embedding_dimension: int, ) -> tuple[str, dict[str, Any]]: + """Build Cypher query for vector search with filters + Uses exact KNN. + + Args: + filters (dict[str, Any]): filters used to pre-filter the nodes before vector search + node_label (str): node label we want to search for + embedding_node_property (str): the name of the property holding the embeddings + embedding_dimension (int): the dimension of the embeddings + + Returns: + tuple[str, dict[str, Any]]: query and parameters + """ where_filters, query_params = construct_metadata_filter(filters, node_alias="node") base_query = BASE_VECTOR_EXACT_QUERY.format( node_label=node_label, @@ -71,19 +83,31 @@ def _get_filtered_vector_query( query_params["embedding_dimension"] = embedding_dimension return ( f"""{base_query} - AND ({where_filters}) - {vector_query} + AND ({where_filters}) + {vector_query} """, query_params, ) def _get_vector_query( - filters: dict[str, Any], + filters: Optional[dict[str, Any]], node_label: str, embedding_node_property: str, embedding_dimension: int, ) -> tuple[str, dict[str, Any]]: + """Build the vector query with or without filters + + Args: + filters (dict[str, Any]): filters used to pre-filter the nodes before vector search + node_label (str): node label we want to search for + embedding_node_property (str): the name of the property holding the embeddings + embedding_dimension (int): the dimension of the embeddings + + Returns: + tuple[str, dict[str, Any]]: query and parameters + + """ if filters: return _get_filtered_vector_query( filters, node_label, embedding_node_property, embedding_dimension @@ -100,6 +124,23 @@ def get_search_query( embedding_dimension: Optional[int] = None, filters: Optional[dict[str, Any]] = None, ) -> tuple[str, dict[str, Any]]: + """Build the search query, including pre-filtering if needed, and return clause. + + Args + search_type: Search type we want to search for: + return_properties (list[str]): list of property names to return. + It can't be provided together with retrieval_query. + retrieval_query (str): the query to use to retrieve the search results + It can't be provided together with return_properties. + node_label (str): node label we want to search for + embedding_node_property (str): the name of the property holding the embeddings + embedding_dimension (int): the dimension of the embeddings + filters (dict[str, Any]): filters used to pre-filter the nodes before vector search + + Returns: + tuple[str, dict[str, Any]]: query and parameters + + """ if search_type == SearchType.HYBRID: if filters: raise Exception("Filters is not supported with Hybrid Search") @@ -122,6 +163,18 @@ def _get_query_tail( return_properties: Optional[list[str]] = None, fallback_return: Optional[str] = None, ) -> str: + """Build the RETURN statement after the search is performed + + Args + return_properties (list[str]): list of property names to return. + It can't be provided together with retrieval_query. + retrieval_query (str): the query to use to retrieve the search results + It can't be provided together with return_properties. + fallback_return (str): the fallback return statement to use to retrieve the search results + + Returns: + str: the RETURN statement + """ if retrieval_query: return retrieval_query if return_properties: diff --git a/src/neo4j_genai/retrievers/filters.py b/src/neo4j_genai/retrievers/filters.py index 0919c237..fc052fa7 100644 --- a/src/neo4j_genai/retrievers/filters.py +++ b/src/neo4j_genai/retrievers/filters.py @@ -148,8 +148,11 @@ def __init__(self): self.params = {} def _get_params_name(self, key="param"): - """NB: the counter parameter is there in purpose, will be modified in the function - to remember the count of each parameter + """Find parameter name so that param names are unique. + This function adds a suffix to the key corresponding to the number + of times the key have been used in the query. + E.g. + node.age >= $param_0 AND node.age <= $param_1 :param p: :param counter: @@ -161,6 +164,9 @@ def _get_params_name(self, key="param"): return param_name def add(self, key, value): + """This function adds a new parameter to the param dict. + It returns the name of the parameter to be used as a placeholder + in the cypher query, e.g. $param_0""" param_name = self._get_params_name() self.params[param_name] = value return param_name @@ -173,10 +179,21 @@ def _single_condition_cypher( param_store: ParameterStore, node_alias: str, ) -> str: - """Return Cypher for field operator value + """Return Cypher for field operator value. + + Args: + field: the name of the field being filtered + native_operator_class: the operator class that will be used to generate + the Cypher query + value: filtered value + param_store: ParameterStore objet that will be updated in this function + node_alias: name of the node being filtered in the Cypher query + Returns: + str: the Cypher condition, e.g. node.`property` = $param_0 + NB: the param_store argument is mutable, it will be updated in this function """ - native_op = native_operator_class() + native_op = native_operator_class(node_alias=node_alias) param_name = param_store.add(field, native_op.cleaned_value(value)) query_snippet = f"{native_op.lhs(field)} {native_op.CYPHER_OPERATOR} ${param_name}" return query_snippet @@ -196,11 +213,11 @@ def _handle_field_filter( If provided as is then this will be an equality filter If provided as a dictionary then this will be a filter, the key will be the operator and the value will be the value to filter by - param_store: - node_alias: + param_store: ParameterStore objet that will be updated in this function + node_alias: name of the node being filtered in the Cypher query Returns - - Cypher filter snippet* + str: Cypher filter snippet NB: the param_store argument is mutable, it will be updated in this function """ @@ -266,11 +283,11 @@ def _construct_metadata_filter( Args: filter: A dictionary representing the filter condition. - param_store: A ParamStore object that will deal with parameter naming and saving along the process - node_alias: a string used as alias for the node the filters will be applied to (must come from earlier in the query) + param_store: ParameterStore objet that will be updated in this function + node_alias: name of the node being filtered in the Cypher query Returns: - str + str: the Cypher WHERE clause NB: the param_store argument is mutable, it will be updated in this function """ From c9bd3db220bac7dbf24e945f7e5686dd0a87e8c5 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 7 May 2024 09:20:33 +0200 Subject: [PATCH 22/38] Re-add tests for the _get_query_tail function (deleted by mistake) --- tests/unit/test_neo4j_queries.py | 49 +++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_neo4j_queries.py b/tests/unit/test_neo4j_queries.py index d20185b2..0d420c51 100644 --- a/tests/unit/test_neo4j_queries.py +++ b/tests/unit/test_neo4j_queries.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from neo4j_genai.neo4j_queries import get_search_query +from neo4j_genai.neo4j_queries import get_search_query, _get_query_tail from neo4j_genai.types import SearchType @@ -106,3 +106,50 @@ def test_hybrid_search_with_properties(): ) result, _ = get_search_query(SearchType.HYBRID, return_properties=properties) assert result.strip() == expected.strip() + + +def test_get_query_tail_with_retrieval_query(): + retrieval_query = "MATCH (n) RETURN n LIMIT 10" + expected = retrieval_query + result = _get_query_tail(retrieval_query=retrieval_query) + assert result.strip() == expected.strip() + + +def test_get_query_tail_with_properties(): + properties = ["name", "age"] + expected = "RETURN node {.name, .age} as node, score" + result = _get_query_tail(return_properties=properties) + assert result.strip() == expected.strip() + + +def test_get_query_tail_with_fallback(): + fallback = "HELLO" + expected = fallback + result = _get_query_tail(fallback_return=fallback) + assert result.strip() == expected.strip() + + +def test_get_query_tail_ordering_all(): + retrieval_query = "MATCH (n) RETURN n LIMIT 10" + properties = ["name", "age"] + fallback = "HELLO" + + expected = retrieval_query + result = _get_query_tail( + retrieval_query=retrieval_query, + return_properties=properties, + fallback_return=fallback, + ) + assert result.strip() == expected.strip() + + +def test_get_query_tail_ordering_no_retrieval_query(): + properties = ["name", "age"] + fallback = "HELLO" + + expected = "RETURN node {.name, .age} as node, score" + result = _get_query_tail( + return_properties=properties, + fallback_return=fallback, + ) + assert result.strip() == expected.strip() From 7851f2d5fb5685282027df3123d44258e94f0682 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 7 May 2024 13:26:35 +0200 Subject: [PATCH 23/38] Update docstrings, move filters file, rename function --- src/neo4j_genai/{retrievers => }/filters.py | 34 ++++++++-------- src/neo4j_genai/neo4j_queries.py | 4 +- tests/unit/retrievers/test_filters.py | 44 ++++++++++----------- 3 files changed, 42 insertions(+), 40 deletions(-) rename src/neo4j_genai/{retrievers => }/filters.py (92%) diff --git a/src/neo4j_genai/retrievers/filters.py b/src/neo4j_genai/filters.py similarity index 92% rename from src/neo4j_genai/retrievers/filters.py rename to src/neo4j_genai/filters.py index fc052fa7..16699cc2 100644 --- a/src/neo4j_genai/retrievers/filters.py +++ b/src/neo4j_genai/filters.py @@ -154,9 +154,10 @@ def _get_params_name(self, key="param"): E.g. node.age >= $param_0 AND node.age <= $param_1 - :param p: - :param counter: - :return: + Args: + key (str): The prefix for the parameter name + Returns: + The full unique parameter name """ # key = slugify(key.replace(".", "_"), separator="_") param_name = f"{key}_{self._counter[key]}" @@ -182,14 +183,14 @@ def _single_condition_cypher( """Return Cypher for field operator value. Args: - field: the name of the field being filtered - native_operator_class: the operator class that will be used to generate + field: The name of the field being filtered + native_operator_class: The operator class that will be used to generate the Cypher query value: filtered value param_store: ParameterStore objet that will be updated in this function - node_alias: name of the node being filtered in the Cypher query + node_alias: Name of the node being filtered in the Cypher query Returns: - str: the Cypher condition, e.g. node.`property` = $param_0 + str: The Cypher condition, e.g. node.`property` = $param_0 NB: the param_store argument is mutable, it will be updated in this function """ @@ -208,13 +209,13 @@ def _handle_field_filter( """Create a filter for a specific field. Args: - field: name of field - value: value to filter + field: Name of field + value: Value to filter If provided as is then this will be an equality filter If provided as a dictionary then this will be a filter, the key will be the operator and the value will be the value to filter by param_store: ParameterStore objet that will be updated in this function - node_alias: name of the node being filtered in the Cypher query + node_alias: Name of the node being filtered in the Cypher query Returns str: Cypher filter snippet @@ -284,16 +285,16 @@ def _construct_metadata_filter( Args: filter: A dictionary representing the filter condition. param_store: ParameterStore objet that will be updated in this function - node_alias: name of the node being filtered in the Cypher query + node_alias: Name of the node being filtered in the Cypher query Returns: - str: the Cypher WHERE clause + str: The Cypher WHERE clause NB: the param_store argument is mutable, it will be updated in this function """ if not isinstance(filter, dict): - raise ValueError() + raise ValueError(f"Filter must be a dictionary, received {type(filter)}") # if we have more than one entry, this is an implicit "AND" filter if len(filter) > 1: return _construct_metadata_filter( @@ -326,14 +327,15 @@ def _construct_metadata_filter( return query -def construct_metadata_filter( +def get_metadata_filter( filter: dict[str, Any], node_alias: str = DEFAULT_NODE_ALIAS ) -> tuple[str, dict]: """Construct the cypher filter snippet based on a filter dict Args: - filter: a dict of filters - node_alias: the node the filters must be applied on + filter (dict): The filters to be converted to Cypher + node_alias (str): The alias of node the filters must be applied on + in the Cypher query Return: A tuple of str, dict where the string is the cypher query and the dict diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index 52bbf332..014ebb4a 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -15,7 +15,7 @@ from typing import Optional, Any from neo4j_genai.types import SearchType -from neo4j_genai.retrievers.filters import construct_metadata_filter +from neo4j_genai.filters import get_metadata_filter VECTOR_INDEX_QUERY = ( @@ -72,7 +72,7 @@ def _get_filtered_vector_query( Returns: tuple[str, dict[str, Any]]: query and parameters """ - where_filters, query_params = construct_metadata_filter(filters, node_alias="node") + where_filters, query_params = get_metadata_filter(filters, node_alias="node") base_query = BASE_VECTOR_EXACT_QUERY.format( node_label=node_label, embedding_node_property=embedding_node_property, diff --git a/tests/unit/retrievers/test_filters.py b/tests/unit/retrievers/test_filters.py index fd562118..b6eb0e63 100644 --- a/tests/unit/retrievers/test_filters.py +++ b/tests/unit/retrievers/test_filters.py @@ -14,124 +14,124 @@ # limitations under the License. import pytest -from neo4j_genai.retrievers.filters import construct_metadata_filter +from neo4j_genai.filters import get_metadata_filter def test_filter_single_field_string(): filters = {"field": "string_value"} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` = $param_0" assert params == {"param_0": "string_value"} def test_filter_single_field_int(): filters = {"field": 28} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` = $param_0" assert params == {"param_0": 28} def test_filter_single_field_bool(): filters = {"field": False} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` = $param_0" assert params == {"param_0": False} def test_filter_explicit_eq_operator(): filters = {"field": {"$eq": "string_value"}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` = $param_0" assert params == {"param_0": "string_value"} def test_filter_neq_operator(): filters = {"field": {"$ne": "string_value"}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` <> $param_0" assert params == {"param_0": "string_value"} def test_filter_lt_operator(): filters = {"field": {"$lt": 1}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` < $param_0" assert params == {"param_0": 1} def test_filter_gt_operator(): filters = {"field": {"$gt": 1}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` > $param_0" assert params == {"param_0": 1} def test_filter_lte_operator(): filters = {"field": {"$lte": 1}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` <= $param_0" assert params == {"param_0": 1} def test_filter_gte_operator(): filters = {"field": {"$gte": 1}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` >= $param_0" assert params == {"param_0": 1} def test_filter_in_operator(): filters = {"field": {"$in": ["a", "b"]}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` IN $param_0" assert params == {"param_0": ["a", "b"]} def test_filter_not_in_operator(): filters = {"field": {"$nin": ["a", "b"]}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` NOT IN $param_0" assert params == {"param_0": ["a", "b"]} def test_filter_like_operator(): filters = {"field": {"$like": "some_value"}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "node.`field` CONTAINS $param_0" assert params == {"param_0": "some_value"} def test_filter_ilike_operator(): filters = {"field": {"$ilike": "Some Value"}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "toLower(node.`field`) CONTAINS $param_0" assert params == {"param_0": "some value"} def test_filter_between_operator(): filters = {"field": {"$between": [0, 1]}} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "$param_0 <= node.`field` <= $param_1" assert params == {"param_0": 0, "param_1": 1} def test_filter_implicit_and_condition(): filters = {"field_1": "string_value", "field_2": True} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "(node.`field_1` = $param_0) AND (node.`field_2` = $param_1)" assert params == {"param_0": "string_value", "param_1": True} def test_filter_explicit_and_condition(): filters = {"$and": [{"field_1": "string_value"}, {"field_2": True}]} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "(node.`field_1` = $param_0) AND (node.`field_2` = $param_1)" assert params == {"param_0": "string_value", "param_1": True} def test_filter_or_condition(): filters = {"$or": [{"field_1": "string_value"}, {"field_2": True}]} - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert query == "(node.`field_1` = $param_0) OR (node.`field_2` = $param_1)" assert params == {"param_0": "string_value", "param_1": True} @@ -143,7 +143,7 @@ def test_filter_and_or_combined(): {"field_3": 11}, ] } - query, params = construct_metadata_filter(filters) + query, params = get_metadata_filter(filters) assert ( query == "((node.`field_1` = $param_0) OR (node.`field_2` = $param_1)) AND (node.`field_3` = $param_2)" @@ -155,16 +155,16 @@ def test_filter_and_or_combined(): def test_field_name_with_dollar_sign(): filters = {"$field": "value"} with pytest.raises(ValueError): - construct_metadata_filter(filters) + get_metadata_filter(filters) def test_and_no_list(): filters = {"$and": {}} with pytest.raises(ValueError): - construct_metadata_filter(filters) + get_metadata_filter(filters) def test_unsupported_operator(): filters = {"field": {"$unsupported": "value"}} with pytest.raises(ValueError): - construct_metadata_filter(filters) + get_metadata_filter(filters) From e9e321189d5b7584f52f6658999c5388e383bebc Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 6 May 2024 17:34:50 +0200 Subject: [PATCH 24/38] Support for filters in Vector and VectorCypher retrievers --- src/neo4j_genai/neo4j_queries.py | 57 +---- src/neo4j_genai/retrievers/filters.py | 315 ++++++++++++++++++++++++++ tests/e2e/conftest.py | 4 +- 3 files changed, 326 insertions(+), 50 deletions(-) create mode 100644 src/neo4j_genai/retrievers/filters.py diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index 014ebb4a..1f56d9ea 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -15,7 +15,7 @@ from typing import Optional, Any from neo4j_genai.types import SearchType -from neo4j_genai.filters import get_metadata_filter +from neo4j_genai.retrievers.filters import get_metadata_filter VECTOR_INDEX_QUERY = ( @@ -54,25 +54,8 @@ def _get_hybrid_query() -> str: ) -def _get_filtered_vector_query( - filters: dict[str, Any], - node_label: str, - embedding_node_property: str, - embedding_dimension: int, -) -> tuple[str, dict[str, Any]]: - """Build Cypher query for vector search with filters - Uses exact KNN. - - Args: - filters (dict[str, Any]): filters used to pre-filter the nodes before vector search - node_label (str): node label we want to search for - embedding_node_property (str): the name of the property holding the embeddings - embedding_dimension (int): the dimension of the embeddings - - Returns: - tuple[str, dict[str, Any]]: query and parameters - """ - where_filters, query_params = get_metadata_filter(filters, node_alias="node") +def _get_filtered_vector_query(filters: dict[str, Any], node_label: str, embedding_node_property: str, embedding_dimension: int) -> tuple[str, dict[str, Any]]: + where_filters, query_params = construct_metadata_filter(filters, node_alias="node") base_query = BASE_VECTOR_EXACT_QUERY.format( node_label=node_label, embedding_node_property=embedding_node_property, @@ -81,37 +64,15 @@ def _get_filtered_vector_query( embedding_node_property=embedding_node_property, ) query_params["embedding_dimension"] = embedding_dimension - return ( - f"""{base_query} - AND ({where_filters}) - {vector_query} - """, - query_params, - ) + return f"""{base_query} + AND ({where_filters}) + {vector_query} + """, query_params -def _get_vector_query( - filters: Optional[dict[str, Any]], - node_label: str, - embedding_node_property: str, - embedding_dimension: int, -) -> tuple[str, dict[str, Any]]: - """Build the vector query with or without filters - - Args: - filters (dict[str, Any]): filters used to pre-filter the nodes before vector search - node_label (str): node label we want to search for - embedding_node_property (str): the name of the property holding the embeddings - embedding_dimension (int): the dimension of the embeddings - - Returns: - tuple[str, dict[str, Any]]: query and parameters - - """ +def _get_vector_query(filters: dict[str, Any], node_label: str, embedding_node_property: str, embedding_dimension: int) -> tuple[str, dict[str, Any]]: if filters: - return _get_filtered_vector_query( - filters, node_label, embedding_node_property, embedding_dimension - ) + return _get_filtered_vector_query(filters, node_label, embedding_node_property, embedding_dimension) return VECTOR_INDEX_QUERY, {} diff --git a/src/neo4j_genai/retrievers/filters.py b/src/neo4j_genai/retrievers/filters.py new file mode 100644 index 00000000..358a92fc --- /dev/null +++ b/src/neo4j_genai/retrievers/filters.py @@ -0,0 +1,315 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Filters format: +{"property_name": "property_value"} + + +""" +from typing import Any, Type +from collections import Counter + + +DEFAULT_NODE_ALIAS = "node" + + +class Operator: + """Operator classes are helper classes to build the Cypher queries + from a filter like {"field_name": "field_value"} + They implement two important methods: + - lhs: (left hand side): the node + property to be filtered on + + optional operations on it (see ILikeOperator for instance) + - cleaned_value: a method to make sure the provided parameter values are + consistent with the operator (e.g. LIKE operator only works with string values) + """ + CYPHER_OPERATOR = None + + def __init__(self, node_alias=DEFAULT_NODE_ALIAS): + self.node_alias = node_alias + + def lhs(self, field): + return f"{self.node_alias}.`{field}`" + + def cleaned_value(self, value): + return value + + +class EqOperator(Operator): + CYPHER_OPERATOR = "=" + + +class NeqOperator(Operator): + CYPHER_OPERATOR = "<>" + + +class LtOperator(Operator): + CYPHER_OPERATOR = "<" + + +class GtOperator(Operator): + CYPHER_OPERATOR = ">" + + +class LteOperator(Operator): + CYPHER_OPERATOR = "<=" + + +class GteOperator(Operator): + CYPHER_OPERATOR = ">=" + + +class InOperator(Operator): + CYPHER_OPERATOR = "IN" + + def cleaned_value(self, value): + for val in value: + if not isinstance(val, (str, int, float)): + raise NotImplementedError( + f"Unsupported type: {type(val)} for value: {val}" + ) + return value + + +class NinOperator(InOperator): + CYPHER_OPERATOR = "NOT IN" + + +class LikeOperator(Operator): + CYPHER_OPERATOR = "CONTAINS" + + def cleaned_value(self, value): + if not isinstance(value, str): + raise ValueError(f"Expected string value, got {type(value)}: {value}") + return value.rstrip("%") + + +class ILikeOperator(LikeOperator): + + def lhs(self, field): + return f"toLower({self.node_alias}.`{field}`)" + + def cleaned_value(self, value): + value = super().cleaned_value(value) + return value.lower() + + +OPERATOR_PREFIX = "$" + +OPERATOR_EQ = "$eq" +OPERATOR_NE = "$ne" +OPERATOR_LT = "$lt" +OPERATOR_LTE = "$lte" +OPERATOR_GT = "$gt" +OPERATOR_GTE = "$gte" +OPERATOR_BETWEEN = "$between" +OPERATOR_IN = "$in" +OPERATOR_NIN = "$nin" +OPERATOR_LIKE = "$like" +OPERATOR_ILIKE = "$ilike" + +OPERATOR_AND = "$and" +OPERATOR_OR = "$or" + +COMPARISONS_TO_NATIVE = { + OPERATOR_EQ: EqOperator, + OPERATOR_NE: NeqOperator, + OPERATOR_LT: LtOperator, + OPERATOR_LTE: LteOperator, + OPERATOR_GT: GtOperator, + OPERATOR_GTE: GteOperator, + OPERATOR_IN: InOperator, + OPERATOR_NIN: NinOperator, + OPERATOR_LIKE: LikeOperator, + OPERATOR_ILIKE: ILikeOperator, +} + + +LOGICAL_OPERATORS = {OPERATOR_AND, OPERATOR_OR} + +SUPPORTED_OPERATORS = ( + set(COMPARISONS_TO_NATIVE) + .union(LOGICAL_OPERATORS) + .union({OPERATOR_BETWEEN}) +) + + +class ParameterStore: + """ + Store parameters for a given query. + Determine the parameter name depending on a parameter counter + """ + + def __init__(self): + self._counter = Counter() + self.params = {} + + def _get_params_name(self, key="param"): + """NB: the counter parameter is there in purpose, will be modified in the function + to remember the count of each parameter + + :param p: + :param counter: + :return: + """ + # key = slugify(key.replace(".", "_"), separator="_") + param_name = f"{key}_{self._counter[key]}" + self._counter[key] += 1 + return param_name + + def add(self, key, value): + param_name = self._get_params_name() + self.params[param_name] = value + return param_name + + +def _single_condition_cypher(field: str, native_operator_class: Type[Operator], value: Any, param_store: ParameterStore, node_alias: str) -> str: + """Return Cypher for field operator value + NB: the param_store argument is mutable, it will be updated in this function + """ + native_op = native_operator_class() + param_name = param_store.add(field, native_op.cleaned_value(value)) + query_snippet = f"{native_op.lhs(field)} {native_op.CYPHER_OPERATOR} ${param_name}" + return query_snippet + + +def _handle_field_filter( + field: str, value: Any, param_store: ParameterStore, + node_alias: str = DEFAULT_NODE_ALIAS +) -> str: + """Create a filter for a specific field. + + Args: + field: name of field + value: value to filter + If provided as is then this will be an equality filter + If provided as a dictionary then this will be a filter, the key + will be the operator and the value will be the value to filter by + param_store: + node_alias: + + Returns + - Cypher filter snippet* + + NB: the param_store argument is mutable, it will be updated in this function + """ + # first, perform some sanity checks + if not isinstance(field, str): + raise ValueError( + f"Field should be a string but got: {type(field)} with value: {field}" + ) + + if field.startswith(OPERATOR_PREFIX): + raise ValueError( + f"Invalid filter condition. Expected a field but got an operator: " + f"{field}" + ) + + # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters + if not field.isidentifier(): + raise ValueError(f"Invalid field name: {field}. Expected a valid identifier.") + + if isinstance(value, dict): + # This is a filter specification e.g. {"$gte": 0} + if len(value) != 1: + raise ValueError( + "Invalid filter condition. Expected a value which " + "is a dictionary with a single key that corresponds to an operator " + f"but got a dictionary with {len(value)} keys. The first few " + f"keys are: {list(value.keys())[:3]}" + ) + operator, filter_value = list(value.items())[0] + operator = operator.lower() + # Verify that that operator is an operator + if operator not in SUPPORTED_OPERATORS: + raise ValueError( + f"Invalid operator: {operator}. " + f"Expected one of {SUPPORTED_OPERATORS}" + ) + else: # if value is not dict, then we assume an equality operator + operator = OPERATOR_EQ + filter_value = value + + # now everything is set, we can start and build the query + # special case for the BETWEEN operator that requires + # two tests (lower_bound <= value <= higher_bound) + if operator == OPERATOR_BETWEEN: + low, high = filter_value + param_name_low = param_store.add(field, low) + param_name_high = param_store.add(field, high) + query_snippet = ( + f"${param_name_low} <= {DEFAULT_NODE_ALIAS}.`{field}` <= ${param_name_high}" + ) + return query_snippet + # all the other operators are handled through their own classes: + native_op_class = COMPARISONS_TO_NATIVE[operator] + return _single_condition_cypher(field, native_op_class, filter_value, param_store, node_alias) + + +def _construct_metadata_filter(filter: dict[str, Any], param_store: ParameterStore, node_alias: str) -> str: + """Construct a metadata filter. This is a recursive function parsing the filter dict + + Args: + filter: A dictionary representing the filter condition. + param_store: A ParamStore object that will deal with parameter naming and saving along the process + node_alias: a string used as alias for the node the filters will be applied to (must come from earlier in the query) + + Returns: + str + + NB: the param_store argument is mutable, it will be updated in this function + """ + + if not isinstance(filter, dict): + raise ValueError() + # if we have more than one entry, this is an implicit "AND" filter + if len(filter) > 1: + return _construct_metadata_filter({OPERATOR_AND: [{k: v} for k, v in filter.items()]}, param_store, node_alias) + # The only operators allowed at the top level are $AND and $OR + # First check if an operator or a field + key, value = list(filter.items())[0] + if not key.startswith("$"): + # it's not an operator, must be a field + return _handle_field_filter(key, filter[key], param_store, node_alias=node_alias) + + # Here we handle the $and and $or operators + if not isinstance(value, list): + raise ValueError( + f"Expected a list, but got {type(value)} for value: {value}" + ) + if key.lower() == OPERATOR_AND: + cypher_operator = " AND " + elif key.lower() == OPERATOR_OR: + cypher_operator = " OR " + else: + raise ValueError(f"Unsupported filter {filter}") + query = cypher_operator.join( + [f"({ _construct_metadata_filter(el, param_store, node_alias)})" for el in value] + ) + return query + + +def construct_metadata_filter(filter: dict[str, Any], node_alias: str = DEFAULT_NODE_ALIAS) -> tuple[str, dict]: + """Construct the cypher filter snippet based on a filter dict + + Args: + filter: a dict of filters + node_alias: the node the filters must be applied on + + Return: + A tuple of str, dict where the string is the cypher query and the dict + contains the query parameters + """ + param_store = ParameterStore() + return _construct_metadata_filter(filter, param_store, node_alias=node_alias), param_store.params diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 5e5f3f97..4fdf8574 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -60,7 +60,7 @@ def setup_neo4j(driver): vector_index_name, label="Document", property="propertyKey", - dimensions=1536, + dimensions=10, similarity_fn="euclidean", ) @@ -70,7 +70,7 @@ def setup_neo4j(driver): ) # Insert 10 vectors and authors - vector = [random.random() for _ in range(1536)] + vector = [random.random() for _ in range(10)] def random_str(n: int) -> str: return "".join([random.choice(string.ascii_letters) for _ in range(n)]) From f9ea4e7ca36f9397737521ea56d1cdd641d9a676 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 6 May 2024 18:00:38 +0200 Subject: [PATCH 25/38] Ruff --- src/neo4j_genai/neo4j_queries.py | 25 +++++++++--- src/neo4j_genai/retrievers/filters.py | 59 ++++++++++++++++----------- 2 files changed, 56 insertions(+), 28 deletions(-) diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index 1f56d9ea..74f17985 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -54,7 +54,12 @@ def _get_hybrid_query() -> str: ) -def _get_filtered_vector_query(filters: dict[str, Any], node_label: str, embedding_node_property: str, embedding_dimension: int) -> tuple[str, dict[str, Any]]: +def _get_filtered_vector_query( + filters: dict[str, Any], + node_label: str, + embedding_node_property: str, + embedding_dimension: int, +) -> tuple[str, dict[str, Any]]: where_filters, query_params = construct_metadata_filter(filters, node_alias="node") base_query = BASE_VECTOR_EXACT_QUERY.format( node_label=node_label, @@ -64,15 +69,25 @@ def _get_filtered_vector_query(filters: dict[str, Any], node_label: str, embeddi embedding_node_property=embedding_node_property, ) query_params["embedding_dimension"] = embedding_dimension - return f"""{base_query} + return ( + f"""{base_query} AND ({where_filters}) {vector_query} - """, query_params + """, + query_params, + ) -def _get_vector_query(filters: dict[str, Any], node_label: str, embedding_node_property: str, embedding_dimension: int) -> tuple[str, dict[str, Any]]: +def _get_vector_query( + filters: dict[str, Any], + node_label: str, + embedding_node_property: str, + embedding_dimension: int, +) -> tuple[str, dict[str, Any]]: if filters: - return _get_filtered_vector_query(filters, node_label, embedding_node_property, embedding_dimension) + return _get_filtered_vector_query( + filters, node_label, embedding_node_property, embedding_dimension + ) return VECTOR_INDEX_QUERY, {} diff --git a/src/neo4j_genai/retrievers/filters.py b/src/neo4j_genai/retrievers/filters.py index 358a92fc..0919c237 100644 --- a/src/neo4j_genai/retrievers/filters.py +++ b/src/neo4j_genai/retrievers/filters.py @@ -12,12 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Filters format: -{"property_name": "property_value"} - - -""" from typing import Any, Type from collections import Counter @@ -34,6 +28,7 @@ class Operator: - cleaned_value: a method to make sure the provided parameter values are consistent with the operator (e.g. LIKE operator only works with string values) """ + CYPHER_OPERATOR = None def __init__(self, node_alias=DEFAULT_NODE_ALIAS): @@ -96,7 +91,6 @@ def cleaned_value(self, value): class ILikeOperator(LikeOperator): - def lhs(self, field): return f"toLower({self.node_alias}.`{field}`)" @@ -139,9 +133,7 @@ def cleaned_value(self, value): LOGICAL_OPERATORS = {OPERATOR_AND, OPERATOR_OR} SUPPORTED_OPERATORS = ( - set(COMPARISONS_TO_NATIVE) - .union(LOGICAL_OPERATORS) - .union({OPERATOR_BETWEEN}) + set(COMPARISONS_TO_NATIVE).union(LOGICAL_OPERATORS).union({OPERATOR_BETWEEN}) ) @@ -174,7 +166,13 @@ def add(self, key, value): return param_name -def _single_condition_cypher(field: str, native_operator_class: Type[Operator], value: Any, param_store: ParameterStore, node_alias: str) -> str: +def _single_condition_cypher( + field: str, + native_operator_class: Type[Operator], + value: Any, + param_store: ParameterStore, + node_alias: str, +) -> str: """Return Cypher for field operator value NB: the param_store argument is mutable, it will be updated in this function """ @@ -185,8 +183,10 @@ def _single_condition_cypher(field: str, native_operator_class: Type[Operator], def _handle_field_filter( - field: str, value: Any, param_store: ParameterStore, - node_alias: str = DEFAULT_NODE_ALIAS + field: str, + value: Any, + param_store: ParameterStore, + node_alias: str = DEFAULT_NODE_ALIAS, ) -> str: """Create a filter for a specific field. @@ -254,10 +254,14 @@ def _handle_field_filter( return query_snippet # all the other operators are handled through their own classes: native_op_class = COMPARISONS_TO_NATIVE[operator] - return _single_condition_cypher(field, native_op_class, filter_value, param_store, node_alias) + return _single_condition_cypher( + field, native_op_class, filter_value, param_store, node_alias + ) -def _construct_metadata_filter(filter: dict[str, Any], param_store: ParameterStore, node_alias: str) -> str: +def _construct_metadata_filter( + filter: dict[str, Any], param_store: ParameterStore, node_alias: str +) -> str: """Construct a metadata filter. This is a recursive function parsing the filter dict Args: @@ -275,19 +279,21 @@ def _construct_metadata_filter(filter: dict[str, Any], param_store: ParameterSto raise ValueError() # if we have more than one entry, this is an implicit "AND" filter if len(filter) > 1: - return _construct_metadata_filter({OPERATOR_AND: [{k: v} for k, v in filter.items()]}, param_store, node_alias) + return _construct_metadata_filter( + {OPERATOR_AND: [{k: v} for k, v in filter.items()]}, param_store, node_alias + ) # The only operators allowed at the top level are $AND and $OR # First check if an operator or a field key, value = list(filter.items())[0] if not key.startswith("$"): # it's not an operator, must be a field - return _handle_field_filter(key, filter[key], param_store, node_alias=node_alias) + return _handle_field_filter( + key, filter[key], param_store, node_alias=node_alias + ) # Here we handle the $and and $or operators if not isinstance(value, list): - raise ValueError( - f"Expected a list, but got {type(value)} for value: {value}" - ) + raise ValueError(f"Expected a list, but got {type(value)} for value: {value}") if key.lower() == OPERATOR_AND: cypher_operator = " AND " elif key.lower() == OPERATOR_OR: @@ -295,12 +301,17 @@ def _construct_metadata_filter(filter: dict[str, Any], param_store: ParameterSto else: raise ValueError(f"Unsupported filter {filter}") query = cypher_operator.join( - [f"({ _construct_metadata_filter(el, param_store, node_alias)})" for el in value] + [ + f"({ _construct_metadata_filter(el, param_store, node_alias)})" + for el in value + ] ) return query -def construct_metadata_filter(filter: dict[str, Any], node_alias: str = DEFAULT_NODE_ALIAS) -> tuple[str, dict]: +def construct_metadata_filter( + filter: dict[str, Any], node_alias: str = DEFAULT_NODE_ALIAS +) -> tuple[str, dict]: """Construct the cypher filter snippet based on a filter dict Args: @@ -312,4 +323,6 @@ def construct_metadata_filter(filter: dict[str, Any], node_alias: str = DEFAULT_ contains the query parameters """ param_store = ParameterStore() - return _construct_metadata_filter(filter, param_store, node_alias=node_alias), param_store.params + return _construct_metadata_filter( + filter, param_store, node_alias=node_alias + ), param_store.params From 1e556a07f726e8ceea51af01a81dbd6acfeafbef Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 6 May 2024 18:04:56 +0200 Subject: [PATCH 26/38] Back to the normal dimension size in e2e tests --- tests/e2e/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 4fdf8574..5e5f3f97 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -60,7 +60,7 @@ def setup_neo4j(driver): vector_index_name, label="Document", property="propertyKey", - dimensions=10, + dimensions=1536, similarity_fn="euclidean", ) @@ -70,7 +70,7 @@ def setup_neo4j(driver): ) # Insert 10 vectors and authors - vector = [random.random() for _ in range(10)] + vector = [random.random() for _ in range(1536)] def random_str(n: int) -> str: return "".join([random.choice(string.ascii_letters) for _ in range(n)]) From cad9d4c91a0de62e1be1ca08e50af4dc984f961d Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 6 May 2024 18:21:49 +0200 Subject: [PATCH 27/38] Improved docstrings + include an example --- src/neo4j_genai/neo4j_queries.py | 30 +++++++++++++++++++--- src/neo4j_genai/retrievers/filters.py | 37 +++++++++++++++++++-------- 2 files changed, 54 insertions(+), 13 deletions(-) diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index 74f17985..7a191b2a 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -60,6 +60,18 @@ def _get_filtered_vector_query( embedding_node_property: str, embedding_dimension: int, ) -> tuple[str, dict[str, Any]]: + """Build Cypher query for vector search with filters + Uses exact KNN. + + Args: + filters (dict[str, Any]): filters used to pre-filter the nodes before vector search + node_label (str): node label we want to search for + embedding_node_property (str): the name of the property holding the embeddings + embedding_dimension (int): the dimension of the embeddings + + Returns: + tuple[str, dict[str, Any]]: query and parameters + """ where_filters, query_params = construct_metadata_filter(filters, node_alias="node") base_query = BASE_VECTOR_EXACT_QUERY.format( node_label=node_label, @@ -71,19 +83,31 @@ def _get_filtered_vector_query( query_params["embedding_dimension"] = embedding_dimension return ( f"""{base_query} - AND ({where_filters}) - {vector_query} + AND ({where_filters}) + {vector_query} """, query_params, ) def _get_vector_query( - filters: dict[str, Any], + filters: Optional[dict[str, Any]], node_label: str, embedding_node_property: str, embedding_dimension: int, ) -> tuple[str, dict[str, Any]]: + """Build the vector query with or without filters + + Args: + filters (dict[str, Any]): filters used to pre-filter the nodes before vector search + node_label (str): node label we want to search for + embedding_node_property (str): the name of the property holding the embeddings + embedding_dimension (int): the dimension of the embeddings + + Returns: + tuple[str, dict[str, Any]]: query and parameters + + """ if filters: return _get_filtered_vector_query( filters, node_label, embedding_node_property, embedding_dimension diff --git a/src/neo4j_genai/retrievers/filters.py b/src/neo4j_genai/retrievers/filters.py index 0919c237..fc052fa7 100644 --- a/src/neo4j_genai/retrievers/filters.py +++ b/src/neo4j_genai/retrievers/filters.py @@ -148,8 +148,11 @@ def __init__(self): self.params = {} def _get_params_name(self, key="param"): - """NB: the counter parameter is there in purpose, will be modified in the function - to remember the count of each parameter + """Find parameter name so that param names are unique. + This function adds a suffix to the key corresponding to the number + of times the key have been used in the query. + E.g. + node.age >= $param_0 AND node.age <= $param_1 :param p: :param counter: @@ -161,6 +164,9 @@ def _get_params_name(self, key="param"): return param_name def add(self, key, value): + """This function adds a new parameter to the param dict. + It returns the name of the parameter to be used as a placeholder + in the cypher query, e.g. $param_0""" param_name = self._get_params_name() self.params[param_name] = value return param_name @@ -173,10 +179,21 @@ def _single_condition_cypher( param_store: ParameterStore, node_alias: str, ) -> str: - """Return Cypher for field operator value + """Return Cypher for field operator value. + + Args: + field: the name of the field being filtered + native_operator_class: the operator class that will be used to generate + the Cypher query + value: filtered value + param_store: ParameterStore objet that will be updated in this function + node_alias: name of the node being filtered in the Cypher query + Returns: + str: the Cypher condition, e.g. node.`property` = $param_0 + NB: the param_store argument is mutable, it will be updated in this function """ - native_op = native_operator_class() + native_op = native_operator_class(node_alias=node_alias) param_name = param_store.add(field, native_op.cleaned_value(value)) query_snippet = f"{native_op.lhs(field)} {native_op.CYPHER_OPERATOR} ${param_name}" return query_snippet @@ -196,11 +213,11 @@ def _handle_field_filter( If provided as is then this will be an equality filter If provided as a dictionary then this will be a filter, the key will be the operator and the value will be the value to filter by - param_store: - node_alias: + param_store: ParameterStore objet that will be updated in this function + node_alias: name of the node being filtered in the Cypher query Returns - - Cypher filter snippet* + str: Cypher filter snippet NB: the param_store argument is mutable, it will be updated in this function """ @@ -266,11 +283,11 @@ def _construct_metadata_filter( Args: filter: A dictionary representing the filter condition. - param_store: A ParamStore object that will deal with parameter naming and saving along the process - node_alias: a string used as alias for the node the filters will be applied to (must come from earlier in the query) + param_store: ParameterStore objet that will be updated in this function + node_alias: name of the node being filtered in the Cypher query Returns: - str + str: the Cypher WHERE clause NB: the param_store argument is mutable, it will be updated in this function """ From b1bf005d93fc0bc84b55f65ee748ed69a31769a1 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 7 May 2024 13:26:35 +0200 Subject: [PATCH 28/38] Update docstrings, move filters file, rename function --- src/neo4j_genai/neo4j_queries.py | 2 +- src/neo4j_genai/retrievers/filters.py | 345 -------------------------- 2 files changed, 1 insertion(+), 346 deletions(-) delete mode 100644 src/neo4j_genai/retrievers/filters.py diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index 7a191b2a..4d90457a 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -72,7 +72,7 @@ def _get_filtered_vector_query( Returns: tuple[str, dict[str, Any]]: query and parameters """ - where_filters, query_params = construct_metadata_filter(filters, node_alias="node") + where_filters, query_params = get_metadata_filter(filters, node_alias="node") base_query = BASE_VECTOR_EXACT_QUERY.format( node_label=node_label, embedding_node_property=embedding_node_property, diff --git a/src/neo4j_genai/retrievers/filters.py b/src/neo4j_genai/retrievers/filters.py deleted file mode 100644 index fc052fa7..00000000 --- a/src/neo4j_genai/retrievers/filters.py +++ /dev/null @@ -1,345 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# https://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Type -from collections import Counter - - -DEFAULT_NODE_ALIAS = "node" - - -class Operator: - """Operator classes are helper classes to build the Cypher queries - from a filter like {"field_name": "field_value"} - They implement two important methods: - - lhs: (left hand side): the node + property to be filtered on - + optional operations on it (see ILikeOperator for instance) - - cleaned_value: a method to make sure the provided parameter values are - consistent with the operator (e.g. LIKE operator only works with string values) - """ - - CYPHER_OPERATOR = None - - def __init__(self, node_alias=DEFAULT_NODE_ALIAS): - self.node_alias = node_alias - - def lhs(self, field): - return f"{self.node_alias}.`{field}`" - - def cleaned_value(self, value): - return value - - -class EqOperator(Operator): - CYPHER_OPERATOR = "=" - - -class NeqOperator(Operator): - CYPHER_OPERATOR = "<>" - - -class LtOperator(Operator): - CYPHER_OPERATOR = "<" - - -class GtOperator(Operator): - CYPHER_OPERATOR = ">" - - -class LteOperator(Operator): - CYPHER_OPERATOR = "<=" - - -class GteOperator(Operator): - CYPHER_OPERATOR = ">=" - - -class InOperator(Operator): - CYPHER_OPERATOR = "IN" - - def cleaned_value(self, value): - for val in value: - if not isinstance(val, (str, int, float)): - raise NotImplementedError( - f"Unsupported type: {type(val)} for value: {val}" - ) - return value - - -class NinOperator(InOperator): - CYPHER_OPERATOR = "NOT IN" - - -class LikeOperator(Operator): - CYPHER_OPERATOR = "CONTAINS" - - def cleaned_value(self, value): - if not isinstance(value, str): - raise ValueError(f"Expected string value, got {type(value)}: {value}") - return value.rstrip("%") - - -class ILikeOperator(LikeOperator): - def lhs(self, field): - return f"toLower({self.node_alias}.`{field}`)" - - def cleaned_value(self, value): - value = super().cleaned_value(value) - return value.lower() - - -OPERATOR_PREFIX = "$" - -OPERATOR_EQ = "$eq" -OPERATOR_NE = "$ne" -OPERATOR_LT = "$lt" -OPERATOR_LTE = "$lte" -OPERATOR_GT = "$gt" -OPERATOR_GTE = "$gte" -OPERATOR_BETWEEN = "$between" -OPERATOR_IN = "$in" -OPERATOR_NIN = "$nin" -OPERATOR_LIKE = "$like" -OPERATOR_ILIKE = "$ilike" - -OPERATOR_AND = "$and" -OPERATOR_OR = "$or" - -COMPARISONS_TO_NATIVE = { - OPERATOR_EQ: EqOperator, - OPERATOR_NE: NeqOperator, - OPERATOR_LT: LtOperator, - OPERATOR_LTE: LteOperator, - OPERATOR_GT: GtOperator, - OPERATOR_GTE: GteOperator, - OPERATOR_IN: InOperator, - OPERATOR_NIN: NinOperator, - OPERATOR_LIKE: LikeOperator, - OPERATOR_ILIKE: ILikeOperator, -} - - -LOGICAL_OPERATORS = {OPERATOR_AND, OPERATOR_OR} - -SUPPORTED_OPERATORS = ( - set(COMPARISONS_TO_NATIVE).union(LOGICAL_OPERATORS).union({OPERATOR_BETWEEN}) -) - - -class ParameterStore: - """ - Store parameters for a given query. - Determine the parameter name depending on a parameter counter - """ - - def __init__(self): - self._counter = Counter() - self.params = {} - - def _get_params_name(self, key="param"): - """Find parameter name so that param names are unique. - This function adds a suffix to the key corresponding to the number - of times the key have been used in the query. - E.g. - node.age >= $param_0 AND node.age <= $param_1 - - :param p: - :param counter: - :return: - """ - # key = slugify(key.replace(".", "_"), separator="_") - param_name = f"{key}_{self._counter[key]}" - self._counter[key] += 1 - return param_name - - def add(self, key, value): - """This function adds a new parameter to the param dict. - It returns the name of the parameter to be used as a placeholder - in the cypher query, e.g. $param_0""" - param_name = self._get_params_name() - self.params[param_name] = value - return param_name - - -def _single_condition_cypher( - field: str, - native_operator_class: Type[Operator], - value: Any, - param_store: ParameterStore, - node_alias: str, -) -> str: - """Return Cypher for field operator value. - - Args: - field: the name of the field being filtered - native_operator_class: the operator class that will be used to generate - the Cypher query - value: filtered value - param_store: ParameterStore objet that will be updated in this function - node_alias: name of the node being filtered in the Cypher query - Returns: - str: the Cypher condition, e.g. node.`property` = $param_0 - - NB: the param_store argument is mutable, it will be updated in this function - """ - native_op = native_operator_class(node_alias=node_alias) - param_name = param_store.add(field, native_op.cleaned_value(value)) - query_snippet = f"{native_op.lhs(field)} {native_op.CYPHER_OPERATOR} ${param_name}" - return query_snippet - - -def _handle_field_filter( - field: str, - value: Any, - param_store: ParameterStore, - node_alias: str = DEFAULT_NODE_ALIAS, -) -> str: - """Create a filter for a specific field. - - Args: - field: name of field - value: value to filter - If provided as is then this will be an equality filter - If provided as a dictionary then this will be a filter, the key - will be the operator and the value will be the value to filter by - param_store: ParameterStore objet that will be updated in this function - node_alias: name of the node being filtered in the Cypher query - - Returns - str: Cypher filter snippet - - NB: the param_store argument is mutable, it will be updated in this function - """ - # first, perform some sanity checks - if not isinstance(field, str): - raise ValueError( - f"Field should be a string but got: {type(field)} with value: {field}" - ) - - if field.startswith(OPERATOR_PREFIX): - raise ValueError( - f"Invalid filter condition. Expected a field but got an operator: " - f"{field}" - ) - - # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters - if not field.isidentifier(): - raise ValueError(f"Invalid field name: {field}. Expected a valid identifier.") - - if isinstance(value, dict): - # This is a filter specification e.g. {"$gte": 0} - if len(value) != 1: - raise ValueError( - "Invalid filter condition. Expected a value which " - "is a dictionary with a single key that corresponds to an operator " - f"but got a dictionary with {len(value)} keys. The first few " - f"keys are: {list(value.keys())[:3]}" - ) - operator, filter_value = list(value.items())[0] - operator = operator.lower() - # Verify that that operator is an operator - if operator not in SUPPORTED_OPERATORS: - raise ValueError( - f"Invalid operator: {operator}. " - f"Expected one of {SUPPORTED_OPERATORS}" - ) - else: # if value is not dict, then we assume an equality operator - operator = OPERATOR_EQ - filter_value = value - - # now everything is set, we can start and build the query - # special case for the BETWEEN operator that requires - # two tests (lower_bound <= value <= higher_bound) - if operator == OPERATOR_BETWEEN: - low, high = filter_value - param_name_low = param_store.add(field, low) - param_name_high = param_store.add(field, high) - query_snippet = ( - f"${param_name_low} <= {DEFAULT_NODE_ALIAS}.`{field}` <= ${param_name_high}" - ) - return query_snippet - # all the other operators are handled through their own classes: - native_op_class = COMPARISONS_TO_NATIVE[operator] - return _single_condition_cypher( - field, native_op_class, filter_value, param_store, node_alias - ) - - -def _construct_metadata_filter( - filter: dict[str, Any], param_store: ParameterStore, node_alias: str -) -> str: - """Construct a metadata filter. This is a recursive function parsing the filter dict - - Args: - filter: A dictionary representing the filter condition. - param_store: ParameterStore objet that will be updated in this function - node_alias: name of the node being filtered in the Cypher query - - Returns: - str: the Cypher WHERE clause - - NB: the param_store argument is mutable, it will be updated in this function - """ - - if not isinstance(filter, dict): - raise ValueError() - # if we have more than one entry, this is an implicit "AND" filter - if len(filter) > 1: - return _construct_metadata_filter( - {OPERATOR_AND: [{k: v} for k, v in filter.items()]}, param_store, node_alias - ) - # The only operators allowed at the top level are $AND and $OR - # First check if an operator or a field - key, value = list(filter.items())[0] - if not key.startswith("$"): - # it's not an operator, must be a field - return _handle_field_filter( - key, filter[key], param_store, node_alias=node_alias - ) - - # Here we handle the $and and $or operators - if not isinstance(value, list): - raise ValueError(f"Expected a list, but got {type(value)} for value: {value}") - if key.lower() == OPERATOR_AND: - cypher_operator = " AND " - elif key.lower() == OPERATOR_OR: - cypher_operator = " OR " - else: - raise ValueError(f"Unsupported filter {filter}") - query = cypher_operator.join( - [ - f"({ _construct_metadata_filter(el, param_store, node_alias)})" - for el in value - ] - ) - return query - - -def construct_metadata_filter( - filter: dict[str, Any], node_alias: str = DEFAULT_NODE_ALIAS -) -> tuple[str, dict]: - """Construct the cypher filter snippet based on a filter dict - - Args: - filter: a dict of filters - node_alias: the node the filters must be applied on - - Return: - A tuple of str, dict where the string is the cypher query and the dict - contains the query parameters - """ - param_store = ParameterStore() - return _construct_metadata_filter( - filter, param_store, node_alias=node_alias - ), param_store.params From ae072fbcc0e03def117f367684bb9faaf2ac90e1 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 7 May 2024 15:13:52 +0200 Subject: [PATCH 29/38] More unit tests --- src/neo4j_genai/filters.py | 10 +- tests/unit/retrievers/test_filters.py | 332 +++++++++++++++++++++++++- 2 files changed, 334 insertions(+), 8 deletions(-) diff --git a/src/neo4j_genai/filters.py b/src/neo4j_genai/filters.py index 16699cc2..2f301ed6 100644 --- a/src/neo4j_genai/filters.py +++ b/src/neo4j_genai/filters.py @@ -71,9 +71,7 @@ class InOperator(Operator): def cleaned_value(self, value): for val in value: if not isinstance(val, (str, int, float)): - raise NotImplementedError( - f"Unsupported type: {type(val)} for value: {val}" - ) + raise ValueError(f"Unsupported type: {type(val)} for value: {val}") return value @@ -178,7 +176,7 @@ def _single_condition_cypher( native_operator_class: Type[Operator], value: Any, param_store: ParameterStore, - node_alias: str, + node_alias: str = DEFAULT_NODE_ALIAS, ) -> str: """Return Cypher for field operator value. @@ -263,6 +261,10 @@ def _handle_field_filter( # special case for the BETWEEN operator that requires # two tests (lower_bound <= value <= higher_bound) if operator == OPERATOR_BETWEEN: + if len(filter_value) != 2: + raise ValueError( + f"Expected lower and upper bounds in a list, got {filter_value}" + ) low, high = filter_value param_name_low = param_store.add(field, low) param_name_high = param_store.add(field, high) diff --git a/tests/unit/retrievers/test_filters.py b/tests/unit/retrievers/test_filters.py index b6eb0e63..0124eefe 100644 --- a/tests/unit/retrievers/test_filters.py +++ b/tests/unit/retrievers/test_filters.py @@ -12,9 +12,324 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import patch + import pytest -from neo4j_genai.filters import get_metadata_filter +from neo4j_genai.filters import ( + get_metadata_filter, + _single_condition_cypher, + _handle_field_filter, + _construct_metadata_filter, + EqOperator, + NeqOperator, + LtOperator, + GtOperator, + LteOperator, + GteOperator, + InOperator, + NinOperator, + LikeOperator, + ILikeOperator, + ParameterStore, +) + + +@pytest.fixture(scope="function") +def param_store_empty(): + return ParameterStore() + + +def test_param_store(): + ps = ParameterStore() + assert ps.params == {} + ps.add("", 1) + assert ps.params == {"param_0": 1} + ps.add("", "some value") + assert ps.params == {"param_0": 1, "param_1": "some value"} + + +def test_single_condition_cypher_eq(param_store_empty): + generated = _single_condition_cypher( + "field", EqOperator, "value", param_store=param_store_empty + ) + assert generated == "node.`field` = $param_0" + assert param_store_empty.params == {"param_0": "value"} + + +def test_single_condition_cypher_eq_node_alias(param_store_empty): + generated = _single_condition_cypher( + "field", EqOperator, "value", node_alias="n", param_store=param_store_empty + ) + assert generated == "n.`field` = $param_0" + assert param_store_empty.params == {"param_0": "value"} + + +def test_single_condition_cypher_neq(param_store_empty): + generated = _single_condition_cypher( + "field", NeqOperator, "value", param_store=param_store_empty + ) + assert generated == "node.`field` <> $param_0" + assert param_store_empty.params == {"param_0": "value"} + + +def test_single_condition_cypher_lt(param_store_empty): + generated = _single_condition_cypher( + "field", LtOperator, 10, param_store=param_store_empty + ) + assert generated == "node.`field` < $param_0" + assert param_store_empty.params == {"param_0": 10} + + +def test_single_condition_cypher_gt(param_store_empty): + generated = _single_condition_cypher( + "field", GtOperator, 10, param_store=param_store_empty + ) + assert generated == "node.`field` > $param_0" + assert param_store_empty.params == {"param_0": 10} + + +def test_single_condition_cypher_lte(param_store_empty): + generated = _single_condition_cypher( + "field", LteOperator, 10, param_store=param_store_empty + ) + assert generated == "node.`field` <= $param_0" + assert param_store_empty.params == {"param_0": 10} + + +def test_single_condition_cypher_gte(param_store_empty): + generated = _single_condition_cypher( + "field", GteOperator, 10, param_store=param_store_empty + ) + assert generated == "node.`field` >= $param_0" + assert param_store_empty.params == {"param_0": 10} + + +def test_single_condition_cypher_in_int(param_store_empty): + generated = _single_condition_cypher( + "field", InOperator, [1, 2, 3], param_store=param_store_empty + ) + assert generated == "node.`field` IN $param_0" + assert param_store_empty.params == {"param_0": [1, 2, 3]} + + +def test_single_condition_cypher_in_str(param_store_empty): + generated = _single_condition_cypher( + "field", InOperator, ["a", "b", "c"], param_store=param_store_empty + ) + assert generated == "node.`field` IN $param_0" + assert param_store_empty.params == {"param_0": ["a", "b", "c"]} + + +def test_single_condition_cypher_in_invalid_type(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _single_condition_cypher( + "field", + InOperator, + [ + {"my_tuple"}, + ], + param_store=param_store_empty, + ) + assert "Unsupported type: " in str(excinfo) + + +def test_single_condition_cypher_nin(param_store_empty): + generated = _single_condition_cypher( + "field", NinOperator, ["a", "b", "c"], param_store=param_store_empty + ) + assert generated == "node.`field` NOT IN $param_0" + assert param_store_empty.params == {"param_0": ["a", "b", "c"]} + + +def test_single_condition_cypher_like(param_store_empty): + generated = _single_condition_cypher( + "field", LikeOperator, "value", param_store=param_store_empty + ) + assert generated == "node.`field` CONTAINS $param_0" + assert param_store_empty.params == {"param_0": "value"} + + +def test_single_condition_cypher_ilike(param_store_empty): + generated = _single_condition_cypher( + "field", ILikeOperator, "My Value", param_store=param_store_empty + ) + assert generated == "toLower(node.`field`) CONTAINS $param_0" + assert param_store_empty.params == {"param_0": "my value"} + + +def test_single_condition_cypher_like_not_a_string(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _single_condition_cypher( + "field", ILikeOperator, 1, param_store=param_store_empty + ) + assert "Expected string value, got " in str(excinfo) + + +def test_handle_field_filter_not_a_string(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _handle_field_filter(1, "value", param_store=param_store_empty) + assert "Field should be a string but got: with value: 1" in str( + excinfo + ) + + +def test_handle_field_filter_field_start_with_dollar_sign(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _handle_field_filter("$field_name", "value", param_store=param_store_empty) + assert ( + "Invalid filter condition. Expected a field but got an operator: $field_name" + in str(excinfo) + ) + + +def test_handle_field_filter_bad_field_name(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _handle_field_filter("bad+field?name", "value", param_store=param_store_empty) + assert "Invalid field name: bad+field?name. Expected a valid identifier." in str( + excinfo + ) + + +def test_handle_field_filter_bad_value(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _handle_field_filter( + "field", + value={"operator1": "value1", "operator2": "value2"}, + param_store=param_store_empty, + ) + assert "Invalid filter condition" in str(excinfo) + + +def test_handle_field_filter_bad_operator_name(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _handle_field_filter( + "field", value={"$invalid": "value"}, param_store=param_store_empty + ) + assert "Invalid operator: $invalid" in str(excinfo) + + +def test_handle_field_filter_operator_between(param_store_empty): + generated = _handle_field_filter( + "field", value={"$between": [0, 1]}, param_store=param_store_empty + ) + assert generated == "$param_0 <= node.`field` <= $param_1" + assert param_store_empty.params == {"param_0": 0, "param_1": 1} + + +def test_handle_field_filter_operator_between_not_enough_parameters(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _handle_field_filter( + "field", + value={ + "$between": [ + 0, + ] + }, + param_store=param_store_empty, + ) + assert "Expected lower and upper bounds in a list, got [0]" in str(excinfo) + + +@patch("neo4j_genai.filters._single_condition_cypher", return_value="condition") +def test_handle_field_filter_implicit_eq( + _single_condition_cypher_mocked, param_store_empty +): + generated = _handle_field_filter( + "field", value="some_value", param_store=param_store_empty + ) + _single_condition_cypher_mocked.assert_called_once_with( + "field", EqOperator, "some_value", param_store_empty, "node" + ) + assert generated == "condition" + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_eq(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter( + "field", value={"$eq": "some_value"}, param_store=param_store_empty + ) + _single_condition_cypher_mocked.assert_called_once_with( + "field", EqOperator, "some_value", param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_neq(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter( + "field", value={"$ne": "some_value"}, param_store=param_store_empty + ) + _single_condition_cypher_mocked.assert_called_once_with( + "field", NeqOperator, "some_value", param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_lt(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter("field", value={"$lt": 1}, param_store=param_store_empty) + _single_condition_cypher_mocked.assert_called_once_with( + "field", LtOperator, 1, param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_gt(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter("field", value={"$gt": 1}, param_store=param_store_empty) + _single_condition_cypher_mocked.assert_called_once_with( + "field", GtOperator, 1, param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_lte(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter("field", value={"$lte": 1}, param_store=param_store_empty) + _single_condition_cypher_mocked.assert_called_once_with( + "field", LteOperator, 1, param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_gte(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter("field", value={"$gte": 1}, param_store=param_store_empty) + _single_condition_cypher_mocked.assert_called_once_with( + "field", GteOperator, 1, param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_in(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter("field", value={"$in": [1, 2]}, param_store=param_store_empty) + _single_condition_cypher_mocked.assert_called_once_with( + "field", InOperator, [1, 2], param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_nin(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter("field", value={"$nin": [1, 2]}, param_store=param_store_empty) + _single_condition_cypher_mocked.assert_called_once_with( + "field", NinOperator, [1, 2], param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_like(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter( + "field", value={"$like": "value"}, param_store=param_store_empty + ) + _single_condition_cypher_mocked.assert_called_once_with( + "field", LikeOperator, "value", param_store_empty, "node" + ) + + +@patch("neo4j_genai.filters._single_condition_cypher") +def test_handle_field_filter_ilike(_single_condition_cypher_mocked, param_store_empty): + _handle_field_filter( + "field", value={"$ilike": "value"}, param_store=param_store_empty + ) + _single_condition_cypher_mocked.assert_called_once_with( + "field", ILikeOperator, "value", param_store_empty, "node" + ) def test_filter_single_field_string(): @@ -129,6 +444,15 @@ def test_filter_explicit_and_condition(): assert params == {"param_0": "string_value", "param_1": True} +def test_filter_explicit_and_condition_with_operator(): + filters = { + "$and": [{"field_1": {"$ne": "string_value"}}, {"field_2": {"$in": [1, 2]}}] + } + query, params = get_metadata_filter(filters) + assert query == "(node.`field_1` <> $param_0) AND (node.`field_2` IN $param_1)" + assert params == {"param_0": "string_value", "param_1": [1, 2]} + + def test_filter_or_condition(): filters = {"$or": [{"field_1": "string_value"}, {"field_2": True}]} query, params = get_metadata_filter(filters) @@ -144,9 +468,9 @@ def test_filter_and_or_combined(): ] } query, params = get_metadata_filter(filters) - assert ( - query - == "((node.`field_1` = $param_0) OR (node.`field_2` = $param_1)) AND (node.`field_3` = $param_2)" + assert query == ( + "((node.`field_1` = $param_0) OR (node.`field_2` = $param_1)) " + "AND (node.`field_3` = $param_2)" ) assert params == {"param_0": "string_value", "param_1": True, "param_2": 11} From 6ee5ab2c2b17cb2eb48ab5e3649fcaa84057e294 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 7 May 2024 15:47:17 +0200 Subject: [PATCH 30/38] More unit tests for filters --- src/neo4j_genai/filters.py | 4 +- tests/unit/{retrievers => }/test_filters.py | 95 ++++++++++++++++----- 2 files changed, 74 insertions(+), 25 deletions(-) rename tests/unit/{retrievers => }/test_filters.py (81%) diff --git a/src/neo4j_genai/filters.py b/src/neo4j_genai/filters.py index 2f301ed6..ebb062d0 100644 --- a/src/neo4j_genai/filters.py +++ b/src/neo4j_genai/filters.py @@ -296,7 +296,7 @@ def _construct_metadata_filter( """ if not isinstance(filter, dict): - raise ValueError(f"Filter must be a dictionary, received {type(filter)}") + raise ValueError(f"Filter must be a dictionary, got {type(filter)}") # if we have more than one entry, this is an implicit "AND" filter if len(filter) > 1: return _construct_metadata_filter( @@ -319,7 +319,7 @@ def _construct_metadata_filter( elif key.lower() == OPERATOR_OR: cypher_operator = " OR " else: - raise ValueError(f"Unsupported filter {filter}") + raise ValueError(f"Unsupported operator: {key}") query = cypher_operator.join( [ f"({ _construct_metadata_filter(el, param_store, node_alias)})" diff --git a/tests/unit/retrievers/test_filters.py b/tests/unit/test_filters.py similarity index 81% rename from tests/unit/retrievers/test_filters.py rename to tests/unit/test_filters.py index 0124eefe..2fba20b9 100644 --- a/tests/unit/retrievers/test_filters.py +++ b/tests/unit/test_filters.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import patch +from unittest.mock import patch, call import pytest @@ -332,119 +332,168 @@ def test_handle_field_filter_ilike(_single_condition_cypher_mocked, param_store_ ) -def test_filter_single_field_string(): +@patch("neo4j_genai.filters._handle_field_filter") +def test_construct_metadata_filter_filter_is_not_a_dict(_handle_field_filter_mock, param_store_empty): + with pytest.raises(ValueError) as excinfo: + _construct_metadata_filter([], param_store_empty, node_alias="n") + assert "Filter must be a dictionary, got " in str(excinfo) + + +@patch("neo4j_genai.filters._handle_field_filter") +def test_construct_metadata_filter_no_operator(_handle_field_filter_mock, param_store_empty): + _construct_metadata_filter({"field": "value"}, param_store_empty, node_alias="n") + _handle_field_filter_mock.assert_called_once_with( + "field", "value", param_store_empty, node_alias="n" + ) + + +@patch("neo4j_genai.filters._construct_metadata_filter") +def test_construct_metadata_filter_implicit_and(_construct_metadata_filter_mock, param_store_empty): + _construct_metadata_filter({"field_1": "value_1", "field_2": "value_2"}, param_store_empty, node_alias="n") + _construct_metadata_filter_mock.assert_has_calls([ + call({"$and": [{"field_1": "value_1"}, {"field_2": "value_2"}]}, param_store_empty, "n"), + ]) + + +@patch("neo4j_genai.filters._construct_metadata_filter", side_effect=["filter1", "filter2"]) +def test_construct_metadata_filter_explicit_and(_construct_metadata_filter_mock, param_store_empty): + generated = _construct_metadata_filter({"$and": [{"field_1": "value_1"}, {"field_2": "value_2"}]}, param_store_empty, node_alias="n") + _construct_metadata_filter_mock.assert_has_calls([ + call({"field_1": "value_1"}, param_store_empty, "n"), + call({"field_2": "value_2"}, param_store_empty, "n") + ]) + assert generated == "(filter1) AND (filter2)" + + +@patch("neo4j_genai.filters._construct_metadata_filter", side_effect=["filter1", "filter2"]) +def test_construct_metadata_filter_or(_construct_metadata_filter_mock, param_store_empty): + generated = _construct_metadata_filter({"$or": [{"field_1": "value_1"}, {"field_2": "value_2"}]}, param_store_empty, node_alias="n") + _construct_metadata_filter_mock.assert_has_calls([ + call({"field_1": "value_1"}, param_store_empty, "n"), + call({"field_2": "value_2"}, param_store_empty, "n") + ]) + assert generated == "(filter1) OR (filter2)" + + +def test_construct_metadata_filter_invalid_operator(param_store_empty): + with pytest.raises(ValueError) as excinfo: + _construct_metadata_filter({"$invalid": [{}, {}]}, param_store_empty, node_alias="n") + assert "Unsupported operator: $invalid" in str(excinfo) + + +def test_get_metadata_filter_single_field_string(): filters = {"field": "string_value"} query, params = get_metadata_filter(filters) assert query == "node.`field` = $param_0" assert params == {"param_0": "string_value"} -def test_filter_single_field_int(): +def test_get_metadata_filter_single_field_int(): filters = {"field": 28} query, params = get_metadata_filter(filters) assert query == "node.`field` = $param_0" assert params == {"param_0": 28} -def test_filter_single_field_bool(): +def test_get_metadata_filter_single_field_bool(): filters = {"field": False} query, params = get_metadata_filter(filters) assert query == "node.`field` = $param_0" assert params == {"param_0": False} -def test_filter_explicit_eq_operator(): +def test_get_metadata_filter_explicit_eq_operator(): filters = {"field": {"$eq": "string_value"}} query, params = get_metadata_filter(filters) assert query == "node.`field` = $param_0" assert params == {"param_0": "string_value"} -def test_filter_neq_operator(): +def test_get_metadata_filter_neq_operator(): filters = {"field": {"$ne": "string_value"}} query, params = get_metadata_filter(filters) assert query == "node.`field` <> $param_0" assert params == {"param_0": "string_value"} -def test_filter_lt_operator(): +def test_get_metadata_filter_lt_operator(): filters = {"field": {"$lt": 1}} query, params = get_metadata_filter(filters) assert query == "node.`field` < $param_0" assert params == {"param_0": 1} -def test_filter_gt_operator(): +def test_get_metadata_filter_gt_operator(): filters = {"field": {"$gt": 1}} query, params = get_metadata_filter(filters) assert query == "node.`field` > $param_0" assert params == {"param_0": 1} -def test_filter_lte_operator(): +def test_get_metadata_filter_lte_operator(): filters = {"field": {"$lte": 1}} query, params = get_metadata_filter(filters) assert query == "node.`field` <= $param_0" assert params == {"param_0": 1} -def test_filter_gte_operator(): +def test_get_metadata_filter_gte_operator(): filters = {"field": {"$gte": 1}} query, params = get_metadata_filter(filters) assert query == "node.`field` >= $param_0" assert params == {"param_0": 1} -def test_filter_in_operator(): +def test_get_metadata_filter_in_operator(): filters = {"field": {"$in": ["a", "b"]}} query, params = get_metadata_filter(filters) assert query == "node.`field` IN $param_0" assert params == {"param_0": ["a", "b"]} -def test_filter_not_in_operator(): +def test_get_metadata_filter_not_in_operator(): filters = {"field": {"$nin": ["a", "b"]}} query, params = get_metadata_filter(filters) assert query == "node.`field` NOT IN $param_0" assert params == {"param_0": ["a", "b"]} -def test_filter_like_operator(): +def test_get_metadata_filter_like_operator(): filters = {"field": {"$like": "some_value"}} query, params = get_metadata_filter(filters) assert query == "node.`field` CONTAINS $param_0" assert params == {"param_0": "some_value"} -def test_filter_ilike_operator(): +def test_get_metadata_filter_ilike_operator(): filters = {"field": {"$ilike": "Some Value"}} query, params = get_metadata_filter(filters) assert query == "toLower(node.`field`) CONTAINS $param_0" assert params == {"param_0": "some value"} -def test_filter_between_operator(): +def test_get_metadata_filter_between_operator(): filters = {"field": {"$between": [0, 1]}} query, params = get_metadata_filter(filters) assert query == "$param_0 <= node.`field` <= $param_1" assert params == {"param_0": 0, "param_1": 1} -def test_filter_implicit_and_condition(): +def test_get_metadata_filter_implicit_and_condition(): filters = {"field_1": "string_value", "field_2": True} query, params = get_metadata_filter(filters) assert query == "(node.`field_1` = $param_0) AND (node.`field_2` = $param_1)" assert params == {"param_0": "string_value", "param_1": True} -def test_filter_explicit_and_condition(): +def test_get_metadata_filter_explicit_and_condition(): filters = {"$and": [{"field_1": "string_value"}, {"field_2": True}]} query, params = get_metadata_filter(filters) assert query == "(node.`field_1` = $param_0) AND (node.`field_2` = $param_1)" assert params == {"param_0": "string_value", "param_1": True} -def test_filter_explicit_and_condition_with_operator(): +def test_get_metadata_filter_explicit_and_condition_with_operator(): filters = { "$and": [{"field_1": {"$ne": "string_value"}}, {"field_2": {"$in": [1, 2]}}] } @@ -453,14 +502,14 @@ def test_filter_explicit_and_condition_with_operator(): assert params == {"param_0": "string_value", "param_1": [1, 2]} -def test_filter_or_condition(): +def test_get_metadata_filter_or_condition(): filters = {"$or": [{"field_1": "string_value"}, {"field_2": True}]} query, params = get_metadata_filter(filters) assert query == "(node.`field_1` = $param_0) OR (node.`field_2` = $param_1)" assert params == {"param_0": "string_value", "param_1": True} -def test_filter_and_or_combined(): +def test_get_metadata_filter_and_or_combined(): filters = { "$and": [ {"$or": [{"field_1": "string_value"}, {"field_2": True}]}, @@ -476,19 +525,19 @@ def test_filter_and_or_combined(): # now testing bad filters -def test_field_name_with_dollar_sign(): +def test_get_metadata_filter_field_name_with_dollar_sign(): filters = {"$field": "value"} with pytest.raises(ValueError): get_metadata_filter(filters) -def test_and_no_list(): +def test_get_metadata_filter_and_no_list(): filters = {"$and": {}} with pytest.raises(ValueError): get_metadata_filter(filters) -def test_unsupported_operator(): +def test_get_metadata_filter_unsupported_operator(): filters = {"field": {"$unsupported": "value"}} with pytest.raises(ValueError): get_metadata_filter(filters) From 2daa2a333685d0da3a830f01502ceb2828b1657b Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 7 May 2024 15:57:28 +0200 Subject: [PATCH 31/38] Increase test coverage for queries --- src/neo4j_genai/neo4j_queries.py | 10 +--- tests/unit/test_filters.py | 84 +++++++++++++++++++++++--------- tests/unit/test_neo4j_queries.py | 50 +++++++++++++++++++ 3 files changed, 112 insertions(+), 32 deletions(-) diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index 4d90457a..d7f3af58 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -81,13 +81,7 @@ def _get_filtered_vector_query( embedding_node_property=embedding_node_property, ) query_params["embedding_dimension"] = embedding_dimension - return ( - f"""{base_query} - AND ({where_filters}) - {vector_query} - """, - query_params, - ) + return f"{base_query} AND ({where_filters}) {vector_query}", query_params def _get_vector_query( @@ -155,7 +149,7 @@ def get_search_query( query_tail = _get_query_tail( retrieval_query, return_properties, fallback_return="RETURN node, score" ) - return " ".join([query, query_tail]), params + return f"{query} {query_tail}", params def _get_query_tail( diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py index 2fba20b9..8dff3570 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -333,14 +333,18 @@ def test_handle_field_filter_ilike(_single_condition_cypher_mocked, param_store_ @patch("neo4j_genai.filters._handle_field_filter") -def test_construct_metadata_filter_filter_is_not_a_dict(_handle_field_filter_mock, param_store_empty): +def test_construct_metadata_filter_filter_is_not_a_dict( + _handle_field_filter_mock, param_store_empty +): with pytest.raises(ValueError) as excinfo: _construct_metadata_filter([], param_store_empty, node_alias="n") assert "Filter must be a dictionary, got " in str(excinfo) @patch("neo4j_genai.filters._handle_field_filter") -def test_construct_metadata_filter_no_operator(_handle_field_filter_mock, param_store_empty): +def test_construct_metadata_filter_no_operator( + _handle_field_filter_mock, param_store_empty +): _construct_metadata_filter({"field": "value"}, param_store_empty, node_alias="n") _handle_field_filter_mock.assert_called_once_with( "field", "value", param_store_empty, node_alias="n" @@ -348,36 +352,68 @@ def test_construct_metadata_filter_no_operator(_handle_field_filter_mock, param_ @patch("neo4j_genai.filters._construct_metadata_filter") -def test_construct_metadata_filter_implicit_and(_construct_metadata_filter_mock, param_store_empty): - _construct_metadata_filter({"field_1": "value_1", "field_2": "value_2"}, param_store_empty, node_alias="n") - _construct_metadata_filter_mock.assert_has_calls([ - call({"$and": [{"field_1": "value_1"}, {"field_2": "value_2"}]}, param_store_empty, "n"), - ]) - - -@patch("neo4j_genai.filters._construct_metadata_filter", side_effect=["filter1", "filter2"]) -def test_construct_metadata_filter_explicit_and(_construct_metadata_filter_mock, param_store_empty): - generated = _construct_metadata_filter({"$and": [{"field_1": "value_1"}, {"field_2": "value_2"}]}, param_store_empty, node_alias="n") - _construct_metadata_filter_mock.assert_has_calls([ - call({"field_1": "value_1"}, param_store_empty, "n"), - call({"field_2": "value_2"}, param_store_empty, "n") - ]) +def test_construct_metadata_filter_implicit_and( + _construct_metadata_filter_mock, param_store_empty +): + _construct_metadata_filter( + {"field_1": "value_1", "field_2": "value_2"}, param_store_empty, node_alias="n" + ) + _construct_metadata_filter_mock.assert_has_calls( + [ + call( + {"$and": [{"field_1": "value_1"}, {"field_2": "value_2"}]}, + param_store_empty, + "n", + ), + ] + ) + + +@patch( + "neo4j_genai.filters._construct_metadata_filter", side_effect=["filter1", "filter2"] +) +def test_construct_metadata_filter_explicit_and( + _construct_metadata_filter_mock, param_store_empty +): + generated = _construct_metadata_filter( + {"$and": [{"field_1": "value_1"}, {"field_2": "value_2"}]}, + param_store_empty, + node_alias="n", + ) + _construct_metadata_filter_mock.assert_has_calls( + [ + call({"field_1": "value_1"}, param_store_empty, "n"), + call({"field_2": "value_2"}, param_store_empty, "n"), + ] + ) assert generated == "(filter1) AND (filter2)" -@patch("neo4j_genai.filters._construct_metadata_filter", side_effect=["filter1", "filter2"]) -def test_construct_metadata_filter_or(_construct_metadata_filter_mock, param_store_empty): - generated = _construct_metadata_filter({"$or": [{"field_1": "value_1"}, {"field_2": "value_2"}]}, param_store_empty, node_alias="n") - _construct_metadata_filter_mock.assert_has_calls([ - call({"field_1": "value_1"}, param_store_empty, "n"), - call({"field_2": "value_2"}, param_store_empty, "n") - ]) +@patch( + "neo4j_genai.filters._construct_metadata_filter", side_effect=["filter1", "filter2"] +) +def test_construct_metadata_filter_or( + _construct_metadata_filter_mock, param_store_empty +): + generated = _construct_metadata_filter( + {"$or": [{"field_1": "value_1"}, {"field_2": "value_2"}]}, + param_store_empty, + node_alias="n", + ) + _construct_metadata_filter_mock.assert_has_calls( + [ + call({"field_1": "value_1"}, param_store_empty, "n"), + call({"field_2": "value_2"}, param_store_empty, "n"), + ] + ) assert generated == "(filter1) OR (filter2)" def test_construct_metadata_filter_invalid_operator(param_store_empty): with pytest.raises(ValueError) as excinfo: - _construct_metadata_filter({"$invalid": [{}, {}]}, param_store_empty, node_alias="n") + _construct_metadata_filter( + {"$invalid": [{}, {}]}, param_store_empty, node_alias="n" + ) assert "Unsupported operator: $invalid" in str(excinfo) diff --git a/tests/unit/test_neo4j_queries.py b/tests/unit/test_neo4j_queries.py index 0d420c51..0ef2b68e 100644 --- a/tests/unit/test_neo4j_queries.py +++ b/tests/unit/test_neo4j_queries.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import patch from neo4j_genai.neo4j_queries import get_search_query, _get_query_tail from neo4j_genai.types import SearchType @@ -68,6 +69,55 @@ def test_vector_search_with_retrieval_query(): assert result.strip() == expected.strip() +@patch("neo4j_genai.neo4j_queries.get_metadata_filter", return_value=["True", {}]) +def test_vector_search_with_filters(_mock): + expected = ( + "MATCH (node:`Label`) " + "WHERE node.`vector` IS NOT NULL " + "AND size(node.`vector`) = toInteger($embedding_dimension)" + " AND (True) " + "WITH node, " + "vector.similarity.cosine(node.`vector`, $query_vector) AS score " + "ORDER BY score DESC LIMIT $top_k" + " RETURN node, score" + ) + result, params = get_search_query( + SearchType.VECTOR, + node_label="Label", + embedding_node_property="vector", + embedding_dimension=1, + filters={"field": "value"}, + ) + assert result.strip() == expected.strip() + assert params == {"embedding_dimension": 1} + + +@patch( + "neo4j_genai.neo4j_queries.get_metadata_filter", + return_value=["True", {"param": "value"}], +) +def test_vector_search_with_params_from_filters(_mock): + expected = ( + "MATCH (node:`Label`) " + "WHERE node.`vector` IS NOT NULL " + "AND size(node.`vector`) = toInteger($embedding_dimension)" + " AND (True) " + "WITH node, " + "vector.similarity.cosine(node.`vector`, $query_vector) AS score " + "ORDER BY score DESC LIMIT $top_k" + " RETURN node, score" + ) + result, params = get_search_query( + SearchType.VECTOR, + node_label="Label", + embedding_node_property="vector", + embedding_dimension=1, + filters={"field": "value"}, + ) + assert result.strip() == expected.strip() + assert params == {"embedding_dimension": 1, "param": "value"} + + def test_hybrid_search_with_retrieval_query(): retrieval_query = "MATCH (n) RETURN n LIMIT 10" expected = ( From a5e8980fef277fb5ab778481998cfe3b3ec15779 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 7 May 2024 19:25:48 +0200 Subject: [PATCH 32/38] Simplification, formatting --- src/neo4j_genai/filters.py | 19 ++++++++----------- src/neo4j_genai/retrievers/base.py | 12 +++++++----- tests/unit/test_filters.py | 19 +++++++++---------- 3 files changed, 24 insertions(+), 26 deletions(-) diff --git a/src/neo4j_genai/filters.py b/src/neo4j_genai/filters.py index ebb062d0..100b9bd0 100644 --- a/src/neo4j_genai/filters.py +++ b/src/neo4j_genai/filters.py @@ -35,7 +35,8 @@ def __init__(self, node_alias=DEFAULT_NODE_ALIAS): self.node_alias = node_alias def lhs(self, field): - return f"{self.node_alias}.`{field}`" + escaped_field = field.replace("`", "``") + return f"{self.node_alias}.`{escaped_field}`" def cleaned_value(self, value): return value @@ -145,7 +146,7 @@ def __init__(self): self._counter = Counter() self.params = {} - def _get_params_name(self, key="param"): + def _get_params_name(self): """Find parameter name so that param names are unique. This function adds a suffix to the key corresponding to the number of times the key have been used in the query. @@ -157,12 +158,12 @@ def _get_params_name(self, key="param"): Returns: The full unique parameter name """ - # key = slugify(key.replace(".", "_"), separator="_") + key = "param" param_name = f"{key}_{self._counter[key]}" self._counter[key] += 1 return param_name - def add(self, key, value): + def add(self, value): """This function adds a new parameter to the param dict. It returns the name of the parameter to be used as a placeholder in the cypher query, e.g. $param_0""" @@ -193,7 +194,7 @@ def _single_condition_cypher( NB: the param_store argument is mutable, it will be updated in this function """ native_op = native_operator_class(node_alias=node_alias) - param_name = param_store.add(field, native_op.cleaned_value(value)) + param_name = param_store.add(native_op.cleaned_value(value)) query_snippet = f"{native_op.lhs(field)} {native_op.CYPHER_OPERATOR} ${param_name}" return query_snippet @@ -232,10 +233,6 @@ def _handle_field_filter( f"{field}" ) - # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters - if not field.isidentifier(): - raise ValueError(f"Invalid field name: {field}. Expected a valid identifier.") - if isinstance(value, dict): # This is a filter specification e.g. {"$gte": 0} if len(value) != 1: @@ -266,8 +263,8 @@ def _handle_field_filter( f"Expected lower and upper bounds in a list, got {filter_value}" ) low, high = filter_value - param_name_low = param_store.add(field, low) - param_name_high = param_store.add(field, high) + param_name_low = param_store.add(low) + param_name_high = param_store.add(high) query_snippet = ( f"${param_name_low} <= {DEFAULT_NODE_ALIAS}.`{field}` <= ${param_name_high}" ) diff --git a/src/neo4j_genai/retrievers/base.py b/src/neo4j_genai/retrievers/base.py index e478e3a5..cbf8abb3 100644 --- a/src/neo4j_genai/retrievers/base.py +++ b/src/neo4j_genai/retrievers/base.py @@ -60,11 +60,13 @@ def search(self, *args, **kwargs) -> Any: def _fetch_index_infos(self): """Fetch the node label and embedding property from the index definition""" - query = """SHOW VECTOR INDEXES -YIELD name, labelsOrTypes, properties, options -WHERE name = $index_name -RETURN labelsOrTypes as labels, properties, options.indexConfig.`vector.dimensions` as dimensions - """ + query = ( + "SHOW VECTOR INDEXES " + "YIELD name, labelsOrTypes, properties, options " + "WHERE name = $index_name " + "RETURN labelsOrTypes as labels, properties, " + "options.indexConfig.`vector.dimensions` as dimensions" + ) result = self.driver.execute_query(query, {"index_name": self.index_name}) try: result = result.records[0] diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py index 8dff3570..9979846e 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -43,9 +43,9 @@ def param_store_empty(): def test_param_store(): ps = ParameterStore() assert ps.params == {} - ps.add("", 1) + ps.add(1) assert ps.params == {"param_0": 1} - ps.add("", "some value") + ps.add("some value") assert ps.params == {"param_0": 1, "param_1": "some value"} @@ -166,6 +166,13 @@ def test_single_condition_cypher_like_not_a_string(param_store_empty): assert "Expected string value, got " in str(excinfo) +def test_single_condition_cypher_escaped_field_name(param_store_empty): + generated = _single_condition_cypher( + "na`me", EqOperator, "value", param_store=param_store_empty + ) + assert generated == "node.`na``me` = $param_0" + + def test_handle_field_filter_not_a_string(param_store_empty): with pytest.raises(ValueError) as excinfo: _handle_field_filter(1, "value", param_store=param_store_empty) @@ -183,14 +190,6 @@ def test_handle_field_filter_field_start_with_dollar_sign(param_store_empty): ) -def test_handle_field_filter_bad_field_name(param_store_empty): - with pytest.raises(ValueError) as excinfo: - _handle_field_filter("bad+field?name", "value", param_store=param_store_empty) - assert "Invalid field name: bad+field?name. Expected a valid identifier." in str( - excinfo - ) - - def test_handle_field_filter_bad_value(param_store_empty): with pytest.raises(ValueError) as excinfo: _handle_field_filter( From 6d1991984a2c009c8a7d45dca9793a4d3a2207f3 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 13 May 2024 11:55:26 +0200 Subject: [PATCH 33/38] Fix path after merge --- src/neo4j_genai/neo4j_queries.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neo4j_genai/neo4j_queries.py b/src/neo4j_genai/neo4j_queries.py index d7f3af58..e974a047 100644 --- a/src/neo4j_genai/neo4j_queries.py +++ b/src/neo4j_genai/neo4j_queries.py @@ -15,7 +15,7 @@ from typing import Optional, Any from neo4j_genai.types import SearchType -from neo4j_genai.retrievers.filters import get_metadata_filter +from neo4j_genai.filters import get_metadata_filter VECTOR_INDEX_QUERY = ( From 2b89aff3865e3987c89bf5de0947288559c21c70 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 13 May 2024 12:51:23 +0200 Subject: [PATCH 34/38] Use backticks only if field is non identifier --- src/neo4j_genai/filters.py | 26 ++++++++++-- tests/unit/test_filters.py | 86 +++++++++++++++++++++++--------------- 2 files changed, 75 insertions(+), 37 deletions(-) diff --git a/src/neo4j_genai/filters.py b/src/neo4j_genai/filters.py index 100b9bd0..56507afd 100644 --- a/src/neo4j_genai/filters.py +++ b/src/neo4j_genai/filters.py @@ -34,9 +34,26 @@ class Operator: def __init__(self, node_alias=DEFAULT_NODE_ALIAS): self.node_alias = node_alias + @staticmethod + def safe_field_cypher(field_name: str) -> str: + """This method must be used to escape a field name if + necessary to build a valid Cypher query. See: + https://neo4j.com/docs/cypher-manual/current/syntax/naming/ + + Args: + field_name (str): The initial unescaped field name + + Returns: + The field name potentially surrounded with backticks if needed + """ + if field_name.isidentifier(): + return field_name + escaped_field = field_name.replace("`", "``") + return f"`{escaped_field}`" + def lhs(self, field): - escaped_field = field.replace("`", "``") - return f"{self.node_alias}.`{escaped_field}`" + safe_field_cypher = self.safe_field_cypher(field) + return f"{self.node_alias}.{safe_field_cypher}" def cleaned_value(self, value): return value @@ -91,7 +108,8 @@ def cleaned_value(self, value): class ILikeOperator(LikeOperator): def lhs(self, field): - return f"toLower({self.node_alias}.`{field}`)" + safe_field_cypher = self.safe_field_cypher(field) + return f"toLower({self.node_alias}.{safe_field_cypher})" def cleaned_value(self, value): value = super().cleaned_value(value) @@ -266,7 +284,7 @@ def _handle_field_filter( param_name_low = param_store.add(low) param_name_high = param_store.add(high) query_snippet = ( - f"${param_name_low} <= {DEFAULT_NODE_ALIAS}.`{field}` <= ${param_name_high}" + f"${param_name_low} <= {DEFAULT_NODE_ALIAS}.{Operator.safe_field_cypher(field)} <= ${param_name_high}" ) return query_snippet # all the other operators are handled through their own classes: diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py index 9979846e..66d3d6c8 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -21,6 +21,7 @@ _single_condition_cypher, _handle_field_filter, _construct_metadata_filter, + Operator, EqOperator, NeqOperator, LtOperator, @@ -49,11 +50,30 @@ def test_param_store(): assert ps.params == {"param_0": 1, "param_1": "some value"} +def test_operator_field_escape(): + assert Operator.safe_field_cypher("name") == "name" + assert Operator.safe_field_cypher("_name") == "_name" + assert Operator.safe_field_cypher("na_me123") == "na_me123" + # escape if using separators different from underscore + assert Operator.safe_field_cypher("na-me") == "`na-me`" + assert Operator.safe_field_cypher("na me") == "`na me`" + assert Operator.safe_field_cypher("na.me") == "`na.me`" + # escape if name starts with a non alpha character + assert Operator.safe_field_cypher("1name") == "`1name`" + assert Operator.safe_field_cypher("?name") == "`?name`" + # escape if name contains special characters + assert Operator.safe_field_cypher("n*ame") == "`n*ame`" + assert Operator.safe_field_cypher("na_me123%") == "`na_me123%`" + assert Operator.safe_field_cypher("\name") == "`\name`" + # escape the escape character + assert Operator.safe_field_cypher("na`me") == "`na``me`" + + def test_single_condition_cypher_eq(param_store_empty): generated = _single_condition_cypher( "field", EqOperator, "value", param_store=param_store_empty ) - assert generated == "node.`field` = $param_0" + assert generated == "node.field = $param_0" assert param_store_empty.params == {"param_0": "value"} @@ -61,7 +81,7 @@ def test_single_condition_cypher_eq_node_alias(param_store_empty): generated = _single_condition_cypher( "field", EqOperator, "value", node_alias="n", param_store=param_store_empty ) - assert generated == "n.`field` = $param_0" + assert generated == "n.field = $param_0" assert param_store_empty.params == {"param_0": "value"} @@ -69,7 +89,7 @@ def test_single_condition_cypher_neq(param_store_empty): generated = _single_condition_cypher( "field", NeqOperator, "value", param_store=param_store_empty ) - assert generated == "node.`field` <> $param_0" + assert generated == "node.field <> $param_0" assert param_store_empty.params == {"param_0": "value"} @@ -77,7 +97,7 @@ def test_single_condition_cypher_lt(param_store_empty): generated = _single_condition_cypher( "field", LtOperator, 10, param_store=param_store_empty ) - assert generated == "node.`field` < $param_0" + assert generated == "node.field < $param_0" assert param_store_empty.params == {"param_0": 10} @@ -85,7 +105,7 @@ def test_single_condition_cypher_gt(param_store_empty): generated = _single_condition_cypher( "field", GtOperator, 10, param_store=param_store_empty ) - assert generated == "node.`field` > $param_0" + assert generated == "node.field > $param_0" assert param_store_empty.params == {"param_0": 10} @@ -93,7 +113,7 @@ def test_single_condition_cypher_lte(param_store_empty): generated = _single_condition_cypher( "field", LteOperator, 10, param_store=param_store_empty ) - assert generated == "node.`field` <= $param_0" + assert generated == "node.field <= $param_0" assert param_store_empty.params == {"param_0": 10} @@ -101,7 +121,7 @@ def test_single_condition_cypher_gte(param_store_empty): generated = _single_condition_cypher( "field", GteOperator, 10, param_store=param_store_empty ) - assert generated == "node.`field` >= $param_0" + assert generated == "node.field >= $param_0" assert param_store_empty.params == {"param_0": 10} @@ -109,7 +129,7 @@ def test_single_condition_cypher_in_int(param_store_empty): generated = _single_condition_cypher( "field", InOperator, [1, 2, 3], param_store=param_store_empty ) - assert generated == "node.`field` IN $param_0" + assert generated == "node.field IN $param_0" assert param_store_empty.params == {"param_0": [1, 2, 3]} @@ -117,7 +137,7 @@ def test_single_condition_cypher_in_str(param_store_empty): generated = _single_condition_cypher( "field", InOperator, ["a", "b", "c"], param_store=param_store_empty ) - assert generated == "node.`field` IN $param_0" + assert generated == "node.field IN $param_0" assert param_store_empty.params == {"param_0": ["a", "b", "c"]} @@ -138,7 +158,7 @@ def test_single_condition_cypher_nin(param_store_empty): generated = _single_condition_cypher( "field", NinOperator, ["a", "b", "c"], param_store=param_store_empty ) - assert generated == "node.`field` NOT IN $param_0" + assert generated == "node.field NOT IN $param_0" assert param_store_empty.params == {"param_0": ["a", "b", "c"]} @@ -146,7 +166,7 @@ def test_single_condition_cypher_like(param_store_empty): generated = _single_condition_cypher( "field", LikeOperator, "value", param_store=param_store_empty ) - assert generated == "node.`field` CONTAINS $param_0" + assert generated == "node.field CONTAINS $param_0" assert param_store_empty.params == {"param_0": "value"} @@ -154,7 +174,7 @@ def test_single_condition_cypher_ilike(param_store_empty): generated = _single_condition_cypher( "field", ILikeOperator, "My Value", param_store=param_store_empty ) - assert generated == "toLower(node.`field`) CONTAINS $param_0" + assert generated == "toLower(node.field) CONTAINS $param_0" assert param_store_empty.params == {"param_0": "my value"} @@ -212,7 +232,7 @@ def test_handle_field_filter_operator_between(param_store_empty): generated = _handle_field_filter( "field", value={"$between": [0, 1]}, param_store=param_store_empty ) - assert generated == "$param_0 <= node.`field` <= $param_1" + assert generated == "$param_0 <= node.field <= $param_1" assert param_store_empty.params == {"param_0": 0, "param_1": 1} @@ -419,112 +439,112 @@ def test_construct_metadata_filter_invalid_operator(param_store_empty): def test_get_metadata_filter_single_field_string(): filters = {"field": "string_value"} query, params = get_metadata_filter(filters) - assert query == "node.`field` = $param_0" + assert query == "node.field = $param_0" assert params == {"param_0": "string_value"} def test_get_metadata_filter_single_field_int(): filters = {"field": 28} query, params = get_metadata_filter(filters) - assert query == "node.`field` = $param_0" + assert query == "node.field = $param_0" assert params == {"param_0": 28} def test_get_metadata_filter_single_field_bool(): filters = {"field": False} query, params = get_metadata_filter(filters) - assert query == "node.`field` = $param_0" + assert query == "node.field = $param_0" assert params == {"param_0": False} def test_get_metadata_filter_explicit_eq_operator(): filters = {"field": {"$eq": "string_value"}} query, params = get_metadata_filter(filters) - assert query == "node.`field` = $param_0" + assert query == "node.field = $param_0" assert params == {"param_0": "string_value"} def test_get_metadata_filter_neq_operator(): filters = {"field": {"$ne": "string_value"}} query, params = get_metadata_filter(filters) - assert query == "node.`field` <> $param_0" + assert query == "node.field <> $param_0" assert params == {"param_0": "string_value"} def test_get_metadata_filter_lt_operator(): filters = {"field": {"$lt": 1}} query, params = get_metadata_filter(filters) - assert query == "node.`field` < $param_0" + assert query == "node.field < $param_0" assert params == {"param_0": 1} def test_get_metadata_filter_gt_operator(): filters = {"field": {"$gt": 1}} query, params = get_metadata_filter(filters) - assert query == "node.`field` > $param_0" + assert query == "node.field > $param_0" assert params == {"param_0": 1} def test_get_metadata_filter_lte_operator(): filters = {"field": {"$lte": 1}} query, params = get_metadata_filter(filters) - assert query == "node.`field` <= $param_0" + assert query == "node.field <= $param_0" assert params == {"param_0": 1} def test_get_metadata_filter_gte_operator(): filters = {"field": {"$gte": 1}} query, params = get_metadata_filter(filters) - assert query == "node.`field` >= $param_0" + assert query == "node.field >= $param_0" assert params == {"param_0": 1} def test_get_metadata_filter_in_operator(): filters = {"field": {"$in": ["a", "b"]}} query, params = get_metadata_filter(filters) - assert query == "node.`field` IN $param_0" + assert query == "node.field IN $param_0" assert params == {"param_0": ["a", "b"]} def test_get_metadata_filter_not_in_operator(): filters = {"field": {"$nin": ["a", "b"]}} query, params = get_metadata_filter(filters) - assert query == "node.`field` NOT IN $param_0" + assert query == "node.field NOT IN $param_0" assert params == {"param_0": ["a", "b"]} def test_get_metadata_filter_like_operator(): filters = {"field": {"$like": "some_value"}} query, params = get_metadata_filter(filters) - assert query == "node.`field` CONTAINS $param_0" + assert query == "node.field CONTAINS $param_0" assert params == {"param_0": "some_value"} def test_get_metadata_filter_ilike_operator(): filters = {"field": {"$ilike": "Some Value"}} query, params = get_metadata_filter(filters) - assert query == "toLower(node.`field`) CONTAINS $param_0" + assert query == "toLower(node.field) CONTAINS $param_0" assert params == {"param_0": "some value"} def test_get_metadata_filter_between_operator(): filters = {"field": {"$between": [0, 1]}} query, params = get_metadata_filter(filters) - assert query == "$param_0 <= node.`field` <= $param_1" + assert query == "$param_0 <= node.field <= $param_1" assert params == {"param_0": 0, "param_1": 1} def test_get_metadata_filter_implicit_and_condition(): filters = {"field_1": "string_value", "field_2": True} query, params = get_metadata_filter(filters) - assert query == "(node.`field_1` = $param_0) AND (node.`field_2` = $param_1)" + assert query == "(node.field_1 = $param_0) AND (node.field_2 = $param_1)" assert params == {"param_0": "string_value", "param_1": True} def test_get_metadata_filter_explicit_and_condition(): filters = {"$and": [{"field_1": "string_value"}, {"field_2": True}]} query, params = get_metadata_filter(filters) - assert query == "(node.`field_1` = $param_0) AND (node.`field_2` = $param_1)" + assert query == "(node.field_1 = $param_0) AND (node.field_2 = $param_1)" assert params == {"param_0": "string_value", "param_1": True} @@ -533,14 +553,14 @@ def test_get_metadata_filter_explicit_and_condition_with_operator(): "$and": [{"field_1": {"$ne": "string_value"}}, {"field_2": {"$in": [1, 2]}}] } query, params = get_metadata_filter(filters) - assert query == "(node.`field_1` <> $param_0) AND (node.`field_2` IN $param_1)" + assert query == "(node.field_1 <> $param_0) AND (node.field_2 IN $param_1)" assert params == {"param_0": "string_value", "param_1": [1, 2]} def test_get_metadata_filter_or_condition(): filters = {"$or": [{"field_1": "string_value"}, {"field_2": True}]} query, params = get_metadata_filter(filters) - assert query == "(node.`field_1` = $param_0) OR (node.`field_2` = $param_1)" + assert query == "(node.field_1 = $param_0) OR (node.field_2 = $param_1)" assert params == {"param_0": "string_value", "param_1": True} @@ -553,8 +573,8 @@ def test_get_metadata_filter_and_or_combined(): } query, params = get_metadata_filter(filters) assert query == ( - "((node.`field_1` = $param_0) OR (node.`field_2` = $param_1)) " - "AND (node.`field_3` = $param_2)" + "((node.field_1 = $param_0) OR (node.field_2 = $param_1)) " + "AND (node.field_3 = $param_2)" ) assert params == {"param_0": "string_value", "param_1": True, "param_2": 11} From 6379a848d91b3b1b2872583d37f38729c461e9aa Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 13 May 2024 15:28:43 +0200 Subject: [PATCH 35/38] ruff --- src/neo4j_genai/filters.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/neo4j_genai/filters.py b/src/neo4j_genai/filters.py index 56507afd..52c6dd60 100644 --- a/src/neo4j_genai/filters.py +++ b/src/neo4j_genai/filters.py @@ -283,9 +283,7 @@ def _handle_field_filter( low, high = filter_value param_name_low = param_store.add(low) param_name_high = param_store.add(high) - query_snippet = ( - f"${param_name_low} <= {DEFAULT_NODE_ALIAS}.{Operator.safe_field_cypher(field)} <= ${param_name_high}" - ) + query_snippet = f"${param_name_low} <= {DEFAULT_NODE_ALIAS}.{Operator.safe_field_cypher(field)} <= ${param_name_high}" return query_snippet # all the other operators are handled through their own classes: native_op_class = COMPARISONS_TO_NATIVE[operator] From 37b75a0780584be3e5b746ffea043dffac6fbadf Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 14 May 2024 10:13:28 +0200 Subject: [PATCH 36/38] Replace field name validity check by regex (same than: https://github.com/neo4j/cypher-builder/blob/main/src/utils/escape.ts#L54-L63) --- src/neo4j_genai/filters.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/neo4j_genai/filters.py b/src/neo4j_genai/filters.py index 52c6dd60..6af73857 100644 --- a/src/neo4j_genai/filters.py +++ b/src/neo4j_genai/filters.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re from typing import Any, Type from collections import Counter @@ -44,9 +45,11 @@ def safe_field_cypher(field_name: str) -> str: field_name (str): The initial unescaped field name Returns: - The field name potentially surrounded with backticks if needed + The field name potentially surrounded with backticks if needed, + ready to be inserted into a Cypher query. """ - if field_name.isidentifier(): + pattern = r'^[a-z_][0-9a-z_]*$' + if re.match(pattern, field_name, re.IGNORECASE): return field_name escaped_field = field_name.replace("`", "``") return f"`{escaped_field}`" From 235edaabb73f62d9da84a9d09881886131d8b0d5 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 14 May 2024 10:15:41 +0200 Subject: [PATCH 37/38] Update docstring --- src/neo4j_genai/filters.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/neo4j_genai/filters.py b/src/neo4j_genai/filters.py index 6af73857..5a03be40 100644 --- a/src/neo4j_genai/filters.py +++ b/src/neo4j_genai/filters.py @@ -350,6 +350,9 @@ def get_metadata_filter( ) -> tuple[str, dict]: """Construct the cypher filter snippet based on a filter dict + Note: the _construct_metadata_filter function is not thread-safe because + of the ParameterStore object. + Args: filter (dict): The filters to be converted to Cypher node_alias (str): The alias of node the filters must be applied on From ff25b93be372ec51520b9ecdeb496c2f4ebb0468 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 14 May 2024 10:17:41 +0200 Subject: [PATCH 38/38] ruff --- src/neo4j_genai/filters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neo4j_genai/filters.py b/src/neo4j_genai/filters.py index 5a03be40..ec80ac8c 100644 --- a/src/neo4j_genai/filters.py +++ b/src/neo4j_genai/filters.py @@ -48,7 +48,7 @@ def safe_field_cypher(field_name: str) -> str: The field name potentially surrounded with backticks if needed, ready to be inserted into a Cypher query. """ - pattern = r'^[a-z_][0-9a-z_]*$' + pattern = r"^[a-z_][0-9a-z_]*$" if re.match(pattern, field_name, re.IGNORECASE): return field_name escaped_field = field_name.replace("`", "``")