diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py index 6231a5db4..69fd7cbbd 100644 --- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, List, Optional, Union +from typing import List, Optional, Union from haystack.utils.filters import COMPARISON_OPERATORS, LOGICAL_OPERATORS, FilterError from qdrant_client.http import models @@ -10,8 +10,18 @@ def convert_filters_to_qdrant( filter_term: Optional[Union[List[dict], dict, models.Filter]] = None, is_parent_call: bool = True -) -> Optional[Union[models.Filter, List[models.Filter]]]: - """Converts Haystack filters to the format used by Qdrant.""" +) -> Optional[Union[models.Filter, List[models.Filter], List[models.Condition]]]: + """Converts Haystack filters to the format used by Qdrant. + + :param filter_term: the haystack filter to be converted to qdrant. + :param is_parent_call: indicates if this is the top-level call to the function. If True, the function returns + a single models.Filter object; if False, it may return a list of filters or conditions for further processing. + + :returns: a single Qdrant Filter in the parent call or a list of such Filters in recursive calls. + + :raises FilterError: If the invalid filter criteria is provided or if an unknown operator is encountered. + + """ if isinstance(filter_term, models.Filter): return filter_term @@ -21,7 +31,9 @@ def convert_filters_to_qdrant( must_clauses: List[models.Filter] = [] should_clauses: List[models.Filter] = [] must_not_clauses: List[models.Filter] = [] - same_operator_flag = False # For same operators on each level, we need nested clauses + # Indicates if there are multiple same LOGICAL OPERATORS on each level + # and prevents them from being combined + same_operator_flag = False conditions, qdrant_filter, current_level_operators = ( [], [], @@ -31,10 +43,12 @@ def convert_filters_to_qdrant( if isinstance(filter_term, dict): filter_term = [filter_term] + # ======== IDENTIFY FILTER ITEMS ON EACH LEVEL ======== + for item in filter_term: operator = item.get("operator") - # Check for same operators on each level + # Check for repeated similar operators on each level same_operator_flag = operator in current_level_operators and operator in LOGICAL_OPERATORS if not same_operator_flag: current_level_operators.append(operator) @@ -51,7 +65,8 @@ def convert_filters_to_qdrant( # Recursively process nested conditions current_filter = convert_filters_to_qdrant(item.get("conditions", []), is_parent_call=False) or [] - # Append or nest clauses based on same_operator_flag + # When same_operator_flag is set to True, + # ensure each clause is appended as an independent list to avoid merging distinct clauses. if operator == "AND": must_clauses = [must_clauses, current_filter] if same_operator_flag else must_clauses + current_filter elif operator == "OR": @@ -83,16 +98,25 @@ def convert_filters_to_qdrant( msg = f"Unknown operator {operator} used in filters" raise FilterError(msg) - # Handle same operators on each level by building nested payloads + # ======== PROCESS FILTER ITEMS ON EACH LEVEL ======== + + # If same logical operators have separate clauses, create separate filters if same_operator_flag: - qdrant_filter = build_payload_for_same_operators(must_clauses, should_clauses, must_not_clauses, qdrant_filter) - if not is_parent_call: - return qdrant_filter - # Append built payload if any clauses are present + qdrant_filter = build_filters_for_repeated_operators( + must_clauses, should_clauses, must_not_clauses, qdrant_filter + ) + + # else append a single Filter for existing clauses elif must_clauses or should_clauses or must_not_clauses: - qdrant_filter.append(build_payload(must_clauses, should_clauses, must_not_clauses)) + qdrant_filter.append( + models.Filter( + must=must_clauses or None, + should=should_clauses or None, + must_not=must_not_clauses or None, + ) + ) - # Handle the parent call case to ensure a single Filter is returned + # In case of parent call, a single Filter is returned if is_parent_call: # If qdrant_filter has just a single Filter in parent call, # then it might be returned instead. @@ -100,47 +124,67 @@ def convert_filters_to_qdrant( return qdrant_filter[0] else: must_clauses.extend(conditions) - return build_payload(must_clauses, should_clauses, must_not_clauses) + return models.Filter( + must=must_clauses or None, + should=should_clauses or None, + must_not=must_not_clauses or None, + ) # Store conditions of each level in output of the loop - if conditions: + elif conditions: qdrant_filter.extend(conditions) return qdrant_filter -def build_payload( - must_clauses: List[models.Condition], - should_clauses: List[models.Condition], - must_not_clauses: List[models.Condition], -) -> models.Filter: - - return models.Filter( - must=must_clauses or None, - should=should_clauses or None, - must_not=must_not_clauses or None, - ) - - -def build_payload_for_same_operators( - must_clauses: List[models.Condition], - should_clauses: List[models.Condition], - must_not_clauses: List[models.Condition], - output_filter: List[Any], +def build_filters_for_repeated_operators( + must_clauses, + should_clauses, + must_not_clauses, + qdrant_filter, ) -> List[models.Filter]: + """ + Flattens the nested lists of clauses by creating separate Filters for each clause of a logical operator. + + :param must_clauses: a nested list of must clauses or an empty list. + :param should_clauses: a nested list of should clauses or an empty list. + :param must_not_clauses: a nested list of must_not clauses or an empty list. + :param qdrant_filter: a list where the generated Filter objects will be appended. + This list will be modified in-place. + + + :returns: the modified `qdrant_filter` list with appended generated Filter objects. + """ + + if any(isinstance(i, list) for i in must_clauses): + for i in must_clauses: + qdrant_filter.append( + models.Filter( + must=i or None, + should=should_clauses or None, + must_not=must_not_clauses or None, + ) + ) + if any(isinstance(i, list) for i in should_clauses): + for i in should_clauses: + qdrant_filter.append( + models.Filter( + must=must_clauses or None, + should=i or None, + must_not=must_not_clauses or None, + ) + ) + if any(isinstance(i, list) for i in must_not_clauses): + for i in must_clauses: + qdrant_filter.append( + models.Filter( + must=must_clauses or None, + should=should_clauses or None, + must_not=i or None, + ) + ) - clause_types = [ - (must_clauses, should_clauses, must_not_clauses), - (should_clauses, must_clauses, must_not_clauses), - (must_not_clauses, must_clauses, should_clauses), - ] - - for clauses, arg1, arg2 in clause_types: - if any(isinstance(i, list) for i in clauses): - for clause in clauses: - output_filter.append(build_payload(clause, arg1, arg2)) - - return output_filter + return qdrant_filter def _parse_comparison_operation( @@ -197,7 +241,7 @@ def _build_ne_condition(key: str, value: models.ValueVariants) -> models.Conditi must_not=[ ( models.FieldCondition(key=key, match=models.MatchText(text=value)) - if isinstance(value, str) and " " in value + if isinstance(value, str) and " " not in value else models.FieldCondition(key=key, match=models.MatchValue(value=value)) ) ] diff --git a/integrations/qdrant/tests/test_filters.py b/integrations/qdrant/tests/test_filters.py index 61b7dbcea..fd070bda9 100644 --- a/integrations/qdrant/tests/test_filters.py +++ b/integrations/qdrant/tests/test_filters.py @@ -64,15 +64,15 @@ def test_not_operator(self, document_store, filterable_docs): def test_filter_criteria(self, document_store): documents = [ Document( - content="This is a test document 1.", + content="This is test document 1.", meta={"file_name": "file1", "classification": {"details": {"category1": 0.9, "category2": 0.3}}}, ), Document( - content="This is a test document 2.", + content="This is test document 2.", meta={"file_name": "file2", "classification": {"details": {"category1": 0.1, "category2": 0.7}}}, ), Document( - content="This is a test document 3.", + content="This is test document 3.", meta={"file_name": "file3", "classification": {"details": {"category1": 0.7, "category2": 0.9}}}, ), ] @@ -105,87 +105,24 @@ def test_filter_criteria(self, document_store): ], ) - def test_advanced_filter_criteria(self, document_store): + def test_complex_filter_criteria(self, document_store): documents = [ Document( - content="This is a test document 1.", - meta={"file_name": "file1", "classification": {"details": {"category1": 0.8, "category3": 0.2}}}, - ), - Document( - content="This is a test document 2.", - meta={"file_name": "file2", "classification": {"details": {"category1": 0.3, "category3": 0.95}}}, - ), - Document( - content="This is a test document 3.", - meta={"file_name": "file3", "classification": {"details": {"category2": 0.6, "category3": 0.85}}}, - ), - Document( - content="This is a test document 4.", - meta={"file_name": "file4", "classification": {"details": {"category2": 0.88, "category3": 0.4}}}, - ), - ] - - document_store.write_documents(documents) - filter_criteria = { - "operator": "AND", - "conditions": [ - { - "operator": "OR", - "conditions": [ - {"field": "meta.file_name", "operator": "in", "value": ["file1", "file3"]}, - {"field": "meta.file_name", "operator": "in", "value": ["file4"]}, - ], - }, - { - "operator": "AND", - "conditions": [ - { - "operator": "OR", - "conditions": [ - {"field": "meta.classification.details.category1", "operator": ">=", "value": 0.75}, - {"field": "meta.classification.details.category2", "operator": ">=", "value": 0.85}, - ], - }, - {"field": "meta.classification.details.category3", "operator": "<", "value": 0.9}, - ], - }, - ], - } - result = document_store.filter_documents(filter_criteria) - self.assert_documents_are_equal( - result, - [ - d - for d in documents - if ( - (d.meta.get("file_name") in ["file1", "file3"] or d.meta.get("file_name") == "file4") - and ( - (d.meta.get("classification").get("details").get("category1", 0) >= 0.75) - or (d.meta.get("classification").get("details").get("category2", 0) >= 0.85) - ) - and (d.meta.get("classification").get("details").get("category3") < 0.9) - ) - ], - ) - - def test_filter_criteria_complex(self, document_store): - documents = [ - Document( - content="Complex document 1.", + content="This is test document 1.", meta={ "file_name": "file1", "classification": {"details": {"category1": 0.45, "category2": 0.5, "category3": 0.2}}, }, ), Document( - content="Complex document 2.", + content="This is test document 2.", meta={ "file_name": "file2", "classification": {"details": {"category1": 0.95, "category2": 0.85, "category3": 0.4}}, }, ), Document( - content="Complex document 3.", + content="This is test document 3.", meta={ "file_name": "file3", "classification": {"details": {"category1": 0.85, "category2": 0.7, "category3": 0.95}},