Skip to content

Commit

Permalink
Type hints in pgvector document store updated for 3.8 compability (#704)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ezhvsalate authored May 6, 2024
1 parent 04fb950 commit 9659b13
Showing 1 changed file with 12 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
from datetime import datetime
from itertools import chain
from typing import Any, Dict, List
from typing import Any, Dict, List, Tuple

from haystack.errors import FilterError
from pandas import DataFrame
Expand All @@ -22,7 +22,7 @@
NO_VALUE = "no_value"


def _convert_filters_to_where_clause_and_params(filters: Dict[str, Any]) -> tuple[SQL, tuple]:
def _convert_filters_to_where_clause_and_params(filters: Dict[str, Any]) -> Tuple[SQL, Tuple]:
"""
Convert Haystack filters to a WHERE clause and a tuple of params to query PostgreSQL.
"""
Expand All @@ -37,7 +37,7 @@ def _convert_filters_to_where_clause_and_params(filters: Dict[str, Any]) -> tupl
return where_clause, params


def _parse_logical_condition(condition: Dict[str, Any]) -> tuple[str, List[Any]]:
def _parse_logical_condition(condition: Dict[str, Any]) -> Tuple[str, List[Any]]:
if "operator" not in condition:
msg = f"'operator' key missing in {condition}"
raise FilterError(msg)
Expand Down Expand Up @@ -77,7 +77,7 @@ def _parse_logical_condition(condition: Dict[str, Any]) -> tuple[str, List[Any]]
return sql_query, values


def _parse_comparison_condition(condition: Dict[str, Any]) -> tuple[str, List[Any]]:
def _parse_comparison_condition(condition: Dict[str, Any]) -> Tuple[str, List[Any]]:
field: str = condition["field"]
if "operator" not in condition:
msg = f"'operator' key missing in {condition}"
Expand Down Expand Up @@ -132,20 +132,20 @@ def _treat_meta_field(field: str, value: Any) -> str:
return field


def _equal(field: str, value: Any) -> tuple[str, Any]:
def _equal(field: str, value: Any) -> Tuple[str, Any]:
if value is None:
# NO_VALUE is a placeholder that will be removed in _convert_filters_to_where_clause_and_params
return f"{field} IS NULL", NO_VALUE
return f"{field} = %s", value


def _not_equal(field: str, value: Any) -> tuple[str, Any]:
def _not_equal(field: str, value: Any) -> Tuple[str, Any]:
# we use IS DISTINCT FROM to correctly handle NULL values
# (not handled by !=)
return f"{field} IS DISTINCT FROM %s", value


def _greater_than(field: str, value: Any) -> tuple[str, Any]:
def _greater_than(field: str, value: Any) -> Tuple[str, Any]:
if isinstance(value, str):
try:
datetime.fromisoformat(value)
Expand All @@ -162,7 +162,7 @@ def _greater_than(field: str, value: Any) -> tuple[str, Any]:
return f"{field} > %s", value


def _greater_than_equal(field: str, value: Any) -> tuple[str, Any]:
def _greater_than_equal(field: str, value: Any) -> Tuple[str, Any]:
if isinstance(value, str):
try:
datetime.fromisoformat(value)
Expand All @@ -179,7 +179,7 @@ def _greater_than_equal(field: str, value: Any) -> tuple[str, Any]:
return f"{field} >= %s", value


def _less_than(field: str, value: Any) -> tuple[str, Any]:
def _less_than(field: str, value: Any) -> Tuple[str, Any]:
if isinstance(value, str):
try:
datetime.fromisoformat(value)
Expand All @@ -196,7 +196,7 @@ def _less_than(field: str, value: Any) -> tuple[str, Any]:
return f"{field} < %s", value


def _less_than_equal(field: str, value: Any) -> tuple[str, Any]:
def _less_than_equal(field: str, value: Any) -> Tuple[str, Any]:
if isinstance(value, str):
try:
datetime.fromisoformat(value)
Expand All @@ -213,15 +213,15 @@ def _less_than_equal(field: str, value: Any) -> tuple[str, Any]:
return f"{field} <= %s", value


def _not_in(field: str, value: Any) -> tuple[str, List]:
def _not_in(field: str, value: Any) -> Tuple[str, List]:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone"
raise FilterError(msg)

return f"{field} IS NULL OR {field} != ALL(%s)", [value]


def _in(field: str, value: Any) -> tuple[str, List]:
def _in(field: str, value: Any) -> Tuple[str, List]:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'in' comparator in Pinecone"
raise FilterError(msg)
Expand Down

0 comments on commit 9659b13

Please sign in to comment.