Skip to content

Commit

Permalink
Fix bugs and add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
Amna Mubashar authored and Amna Mubashar committed Jul 9, 2024
1 parent ceefc4d commit d0cc61a
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 116 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Any, List, Optional, Union
from typing import List, Optional, Union

from haystack.utils.filters import COMPARISON_OPERATORS, LOGICAL_OPERATORS, FilterError
from qdrant_client.http import models
Expand All @@ -10,8 +10,18 @@

def convert_filters_to_qdrant(
filter_term: Optional[Union[List[dict], dict, models.Filter]] = None, is_parent_call: bool = True
) -> Optional[Union[models.Filter, List[models.Filter]]]:
"""Converts Haystack filters to the format used by Qdrant."""
) -> Optional[Union[models.Filter, List[models.Filter], List[models.Condition]]]:
"""Converts Haystack filters to the format used by Qdrant.
:param filter_term: the haystack filter to be converted to qdrant.
:param is_parent_call: indicates if this is the top-level call to the function. If True, the function returns
a single models.Filter object; if False, it may return a list of filters or conditions for further processing.
:returns: a single Qdrant Filter in the parent call or a list of such Filters in recursive calls.
:raises FilterError: If invalid filter criteria are provided or if an unknown operator is encountered.
"""

if isinstance(filter_term, models.Filter):
return filter_term
Expand All @@ -21,7 +31,9 @@ def convert_filters_to_qdrant(
must_clauses: List[models.Filter] = []
should_clauses: List[models.Filter] = []
must_not_clauses: List[models.Filter] = []
same_operator_flag = False # For same operators on each level, we need nested clauses
# Indicates if there are multiple same LOGICAL OPERATORS on each level
# and prevents them from being combined
same_operator_flag = False
conditions, qdrant_filter, current_level_operators = (
[],
[],
Expand All @@ -31,10 +43,12 @@ def convert_filters_to_qdrant(
if isinstance(filter_term, dict):
filter_term = [filter_term]

# ======== IDENTIFY FILTER ITEMS ON EACH LEVEL ========

for item in filter_term:
operator = item.get("operator")

# Check for same operators on each level
# Check for repeated similar operators on each level
same_operator_flag = operator in current_level_operators and operator in LOGICAL_OPERATORS
if not same_operator_flag:
current_level_operators.append(operator)
Expand All @@ -51,7 +65,8 @@ def convert_filters_to_qdrant(
# Recursively process nested conditions
current_filter = convert_filters_to_qdrant(item.get("conditions", []), is_parent_call=False) or []

# Append or nest clauses based on same_operator_flag
# When same_operator_flag is set to True,
# ensure each clause is appended as an independent list to avoid merging distinct clauses.
if operator == "AND":
must_clauses = [must_clauses, current_filter] if same_operator_flag else must_clauses + current_filter
elif operator == "OR":
Expand Down Expand Up @@ -83,64 +98,93 @@ def convert_filters_to_qdrant(
msg = f"Unknown operator {operator} used in filters"
raise FilterError(msg)

# Handle same operators on each level by building nested payloads
# ======== PROCESS FILTER ITEMS ON EACH LEVEL ========

# If same logical operators have separate clauses, create separate filters
if same_operator_flag:
qdrant_filter = build_payload_for_same_operators(must_clauses, should_clauses, must_not_clauses, qdrant_filter)
if not is_parent_call:
return qdrant_filter
# Append built payload if any clauses are present
qdrant_filter = build_filters_for_repeated_operators(
must_clauses, should_clauses, must_not_clauses, qdrant_filter
)

# else append a single Filter for existing clauses
elif must_clauses or should_clauses or must_not_clauses:
qdrant_filter.append(build_payload(must_clauses, should_clauses, must_not_clauses))
qdrant_filter.append(
models.Filter(
must=must_clauses or None,
should=should_clauses or None,
must_not=must_not_clauses or None,
)
)

# Handle the parent call case to ensure a single Filter is returned
# In case of parent call, a single Filter is returned
if is_parent_call:
# If qdrant_filter holds exactly one Filter in the parent call,
# return that Filter directly instead of building a new one.
if len(qdrant_filter) == 1 and isinstance(qdrant_filter[0], models.Filter):
return qdrant_filter[0]
else:
must_clauses.extend(conditions)
return build_payload(must_clauses, should_clauses, must_not_clauses)
return models.Filter(
must=must_clauses or None,
should=should_clauses or None,
must_not=must_not_clauses or None,
)

# Store conditions of each level in output of the loop
if conditions:
elif conditions:
qdrant_filter.extend(conditions)

return qdrant_filter


def build_payload(
    must_clauses: List[models.Condition],
    should_clauses: List[models.Condition],
    must_not_clauses: List[models.Condition],
) -> models.Filter:
    """
    Assembles a single Qdrant Filter from the three clause lists.

    :param must_clauses: conditions passed as the Filter's `must` argument.
    :param should_clauses: conditions passed as the Filter's `should` argument.
    :param must_not_clauses: conditions passed as the Filter's `must_not` argument.
    :returns: a Filter in which every empty (falsy) clause list is replaced by None.
    """
    # Falsy clause lists are normalised to None before being handed to the Filter.
    clause_kwargs = {
        "must": must_clauses or None,
        "should": should_clauses or None,
        "must_not": must_not_clauses or None,
    }
    return models.Filter(**clause_kwargs)


def build_payload_for_same_operators(
must_clauses: List[models.Condition],
should_clauses: List[models.Condition],
must_not_clauses: List[models.Condition],
output_filter: List[Any],
def build_filters_for_repeated_operators(
    must_clauses,
    should_clauses,
    must_not_clauses,
    qdrant_filter,
) -> List[models.Filter]:
    """
    Flattens the nested lists of clauses by creating separate Filters for each clause of a logical operator.

    A clause list counts as nested when at least one of its items is itself a list. For every nested
    clause list, one Filter is appended per item of that list; the other two clause lists are passed
    through unchanged. Clause lists that are not nested produce no Filter here.

    :param must_clauses: a nested list of must clauses or an empty list.
    :param should_clauses: a nested list of should clauses or an empty list.
    :param must_not_clauses: a nested list of must_not clauses or an empty list.
    :param qdrant_filter: a list where the generated Filter objects will be appended.
        This list will be modified in-place.
    :returns: the modified `qdrant_filter` list with appended generated Filter objects.
    """

    if any(isinstance(group, list) for group in must_clauses):
        for group in must_clauses:
            qdrant_filter.append(
                models.Filter(
                    must=group or None,
                    should=should_clauses or None,
                    must_not=must_not_clauses or None,
                )
            )

    if any(isinstance(group, list) for group in should_clauses):
        for group in should_clauses:
            qdrant_filter.append(
                models.Filter(
                    must=must_clauses or None,
                    should=group or None,
                    must_not=must_not_clauses or None,
                )
            )

    if any(isinstance(group, list) for group in must_not_clauses):
        # Iterate the must_not groups themselves: looping over must_clauses here
        # would build Filters from the wrong groups (or none at all when
        # must_clauses is empty).
        for group in must_not_clauses:
            qdrant_filter.append(
                models.Filter(
                    must=must_clauses or None,
                    should=should_clauses or None,
                    must_not=group or None,
                )
            )

    return qdrant_filter


def _parse_comparison_operation(
Expand Down Expand Up @@ -197,7 +241,7 @@ def _build_ne_condition(key: str, value: models.ValueVariants) -> models.Conditi
must_not=[
(
models.FieldCondition(key=key, match=models.MatchText(text=value))
if isinstance(value, str) and " " in value
if isinstance(value, str) and " " not in value
else models.FieldCondition(key=key, match=models.MatchValue(value=value))
)
]
Expand Down
77 changes: 7 additions & 70 deletions integrations/qdrant/tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,15 @@ def test_not_operator(self, document_store, filterable_docs):
def test_filter_criteria(self, document_store):
documents = [
Document(
content="This is a test document 1.",
content="This is test document 1.",
meta={"file_name": "file1", "classification": {"details": {"category1": 0.9, "category2": 0.3}}},
),
Document(
content="This is a test document 2.",
content="This is test document 2.",
meta={"file_name": "file2", "classification": {"details": {"category1": 0.1, "category2": 0.7}}},
),
Document(
content="This is a test document 3.",
content="This is test document 3.",
meta={"file_name": "file3", "classification": {"details": {"category1": 0.7, "category2": 0.9}}},
),
]
Expand Down Expand Up @@ -105,87 +105,24 @@ def test_filter_criteria(self, document_store):
],
)

def test_advanced_filter_criteria(self, document_store):
def test_complex_filter_criteria(self, document_store):
documents = [
Document(
content="This is a test document 1.",
meta={"file_name": "file1", "classification": {"details": {"category1": 0.8, "category3": 0.2}}},
),
Document(
content="This is a test document 2.",
meta={"file_name": "file2", "classification": {"details": {"category1": 0.3, "category3": 0.95}}},
),
Document(
content="This is a test document 3.",
meta={"file_name": "file3", "classification": {"details": {"category2": 0.6, "category3": 0.85}}},
),
Document(
content="This is a test document 4.",
meta={"file_name": "file4", "classification": {"details": {"category2": 0.88, "category3": 0.4}}},
),
]

document_store.write_documents(documents)
filter_criteria = {
"operator": "AND",
"conditions": [
{
"operator": "OR",
"conditions": [
{"field": "meta.file_name", "operator": "in", "value": ["file1", "file3"]},
{"field": "meta.file_name", "operator": "in", "value": ["file4"]},
],
},
{
"operator": "AND",
"conditions": [
{
"operator": "OR",
"conditions": [
{"field": "meta.classification.details.category1", "operator": ">=", "value": 0.75},
{"field": "meta.classification.details.category2", "operator": ">=", "value": 0.85},
],
},
{"field": "meta.classification.details.category3", "operator": "<", "value": 0.9},
],
},
],
}
result = document_store.filter_documents(filter_criteria)
self.assert_documents_are_equal(
result,
[
d
for d in documents
if (
(d.meta.get("file_name") in ["file1", "file3"] or d.meta.get("file_name") == "file4")
and (
(d.meta.get("classification").get("details").get("category1", 0) >= 0.75)
or (d.meta.get("classification").get("details").get("category2", 0) >= 0.85)
)
and (d.meta.get("classification").get("details").get("category3") < 0.9)
)
],
)

def test_filter_criteria_complex(self, document_store):
documents = [
Document(
content="Complex document 1.",
content="This is test document 1.",
meta={
"file_name": "file1",
"classification": {"details": {"category1": 0.45, "category2": 0.5, "category3": 0.2}},
},
),
Document(
content="Complex document 2.",
content="This is test document 2.",
meta={
"file_name": "file2",
"classification": {"details": {"category1": 0.95, "category2": 0.85, "category3": 0.4}},
},
),
Document(
content="Complex document 3.",
content="This is test document 3.",
meta={
"file_name": "file3",
"classification": {"details": {"category1": 0.85, "category2": 0.7, "category3": 0.95}},
Expand Down

0 comments on commit d0cc61a

Please sign in to comment.