diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 3509bbf17..70460b733 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -67,10 +67,12 @@ _BAD_DIR_STRING: str -_BAD_OP_NAN_NULL: str +_BAD_OP_NAN: str +_BAD_OP_NULL: str _BAD_OP_STRING: str _COMPARISON_OPERATORS: Dict[str, Any] _EQ_OP: str +_NEQ_OP: str _INVALID_CURSOR_TRANSFORM: str _INVALID_WHERE_TRANSFORM: str _MISMATCH_CURSOR_W_ORDER_BY: str @@ -80,12 +82,13 @@ _EQ_OP = "==" +_NEQ_OP = "!=" _operator_enum = StructuredQuery.FieldFilter.Operator _COMPARISON_OPERATORS = { "<": _operator_enum.LESS_THAN, "<=": _operator_enum.LESS_THAN_OR_EQUAL, _EQ_OP: _operator_enum.EQUAL, - "!=": _operator_enum.NOT_EQUAL, + _NEQ_OP: _operator_enum.NOT_EQUAL, ">=": _operator_enum.GREATER_THAN_OR_EQUAL, ">": _operator_enum.GREATER_THAN, "array_contains": _operator_enum.ARRAY_CONTAINS, @@ -104,7 +107,8 @@ _operator_enum.NOT_IN, ) _BAD_OP_STRING = "Operator string {!r} is invalid. Valid choices are: {}." -_BAD_OP_NAN_NULL = 'Only an equality filter ("==") can be used with None or NaN values' +_BAD_OP_NAN = 'Only an equality filter ("==") can be used with NaN values' +_BAD_OP_NULL = 'Only equality ("==") or not-equal ("!=") filters can be used with None values' _INVALID_WHERE_TRANSFORM = "Transforms cannot be used as where values." _BAD_DIR_STRING = "Invalid direction {!r}. Must be one of {!r} or {!r}." _INVALID_CURSOR_TRANSFORM = "Transforms cannot be used as cursor values." @@ -144,13 +148,16 @@ def __init__(self, field_path, op_string, value=None): self.value = value if value is None: - if op_string != _EQ_OP: - raise ValueError(_BAD_OP_NAN_NULL) - self.op_string = StructuredQuery.UnaryFilter.Operator.IS_NULL + if op_string == _EQ_OP: + self.op_string = StructuredQuery.UnaryFilter.Operator.IS_NULL + elif op_string == _NEQ_OP: + self.op_string = StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL + else: + raise ValueError(_BAD_OP_NULL) elif _isnan(value): if op_string != _EQ_OP: - raise ValueError(_BAD_OP_NAN_NULL) + raise ValueError(_BAD_OP_NAN) self.op_string = StructuredQuery.UnaryFilter.Operator.IS_NAN elif isinstance(value, (transforms.Sentinel, transforms._ValueList)): raise ValueError(_INVALID_WHERE_TRANSFORM) @@ -479,15 +486,20 @@ def where( stacklevel=2, ) if value is None: - if op_string != _EQ_OP: - raise ValueError(_BAD_OP_NAN_NULL) + if op_string == _EQ_OP: + op = StructuredQuery.UnaryFilter.Operator.IS_NULL + elif op_string == _NEQ_OP: + op = StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL + else: + raise ValueError(_BAD_OP_NULL) + filter_pb = query.StructuredQuery.UnaryFilter( field=query.StructuredQuery.FieldReference(field_path=field_path), - op=StructuredQuery.UnaryFilter.Operator.IS_NULL, + op=op ) elif _isnan(value): if op_string != _EQ_OP: - raise ValueError(_BAD_OP_NAN_NULL) + raise ValueError(_BAD_OP_NAN) filter_pb = query.StructuredQuery.UnaryFilter( field=query.StructuredQuery.FieldReference(field_path=field_path), op=StructuredQuery.UnaryFilter.Operator.IS_NAN, diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index 24caa5e40..62cd7457f 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -301,6 +301,20 @@ def test_basequery_where_eq_null(unary_helper_function): unary_helper_function(None, op_enum) +@pytest.mark.parametrize( + "unary_helper_function", + [ + (_where_unary_helper), + (_where_unary_helper_field_filter), + ], +) +def test_basequery_where_neq_null(unary_helper_function): + from google.cloud.firestore_v1.types import StructuredQuery + + op_enum = StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL + unary_helper_function(None, op_enum, op_string="!=") + + @pytest.mark.parametrize( "unary_helper_function", [ @@ -309,11 +323,11 @@ def test_basequery_where_eq_null(unary_helper_function): ], ) def test_basequery_where_gt_null(unary_helper_function): - from google.cloud.firestore_v1.base_query import _BAD_OP_NAN_NULL + from google.cloud.firestore_v1.base_query import _BAD_OP_NULL with pytest.raises(ValueError) as exc: unary_helper_function(None, 0, op_string=">") - assert str(exc.value) == _BAD_OP_NAN_NULL + assert str(exc.value) == _BAD_OP_NULL @pytest.mark.parametrize( @@ -338,11 +352,11 @@ def test_basequery_where_eq_nan(unary_helper_function): ], ) def test_basequery_where_le_nan(unary_helper_function): - from google.cloud.firestore_v1.base_query import _BAD_OP_NAN_NULL + from google.cloud.firestore_v1.base_query import _BAD_OP_NAN with pytest.raises(ValueError) as exc: unary_helper_function(float("nan"), 0, op_string="<=") - assert str(exc.value) == _BAD_OP_NAN_NULL + assert str(exc.value) == _BAD_OP_NAN @pytest.mark.parametrize(