Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TST: use one-class pattern for SparseArray #56513

Merged
merged 2 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion pandas/tests/extension/base/setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,13 @@ def skip_if_immutable(self, dtype, request):
# This fixture is auto-used, but we want to not-skip
# test_is_immutable.
return
pytest.skip(f"__setitem__ test not applicable with immutable dtype {dtype}")

# When BaseSetitemTests is mixed into ExtensionTests, we only
# want this fixture to operate on the tests defined in this
# class/file.
defined_in = node.function.__qualname__.split(".")[0]
if defined_in == "BaseSetitemTests":
pytest.skip("__setitem__ test not applicable with immutable dtype")

def test_is_immutable(self, data):
if data.dtype._is_immutable:
Expand Down
163 changes: 104 additions & 59 deletions pandas/tests/extension/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,26 +98,64 @@ def data_for_compare(request):
return SparseArray([0, 0, np.nan, -2, -1, 4, 2, 3, 0, 0], fill_value=request.param)


class BaseSparseTests:
class TestSparseArray(base.ExtensionTests):
def _supports_reduction(self, obj, op_name: str) -> bool:
return True

@pytest.mark.parametrize("skipna", [True, False])
def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna, request):
if all_numeric_reductions in [
"prod",
"median",
"var",
"std",
"sem",
"skew",
"kurt",
]:
mark = pytest.mark.xfail(
reason="This should be viable but is not implemented"
)
request.node.add_marker(mark)
elif (
all_numeric_reductions in ["sum", "max", "min", "mean"]
and data.dtype.kind == "f"
and not skipna
):
mark = pytest.mark.xfail(reason="getting a non-nan float")
request.node.add_marker(mark)

super().test_reduce_series_numeric(data, all_numeric_reductions, skipna)

@pytest.mark.parametrize("skipna", [True, False])
def test_reduce_frame(self, data, all_numeric_reductions, skipna, request):
if all_numeric_reductions in [
"prod",
"median",
"var",
"std",
"sem",
"skew",
"kurt",
]:
mark = pytest.mark.xfail(
reason="This should be viable but is not implemented"
)
request.node.add_marker(mark)
elif (
all_numeric_reductions in ["sum", "max", "min", "mean"]
and data.dtype.kind == "f"
and not skipna
):
mark = pytest.mark.xfail(reason="ExtensionArray NA mask are different")
request.node.add_marker(mark)

super().test_reduce_frame(data, all_numeric_reductions, skipna)

def _check_unsupported(self, data):
if data.dtype == SparseDtype(int, 0):
pytest.skip("Can't store nan in int array.")


class TestDtype(BaseSparseTests, base.BaseDtypeTests):
def test_array_type_with_arg(self, data, dtype):
assert dtype.construct_array_type() is SparseArray


class TestInterface(BaseSparseTests, base.BaseInterfaceTests):
pass


class TestConstructors(BaseSparseTests, base.BaseConstructorsTests):
pass


class TestReshaping(BaseSparseTests, base.BaseReshapingTests):
def test_concat_mixed_dtypes(self, data):
# https://github.com/pandas-dev/pandas/issues/20762
# This should be the same, aside from concat([sparse, float])
Expand Down Expand Up @@ -173,8 +211,6 @@ def test_merge(self, data, na_value):
self._check_unsupported(data)
super().test_merge(data, na_value)


class TestGetitem(BaseSparseTests, base.BaseGetitemTests):
def test_get(self, data):
ser = pd.Series(data, index=[2 * i for i in range(len(data))])
if np.isnan(ser.values.fill_value):
Expand All @@ -187,16 +223,6 @@ def test_reindex(self, data, na_value):
self._check_unsupported(data)
super().test_reindex(data, na_value)


class TestSetitem(BaseSparseTests, base.BaseSetitemTests):
pass


class TestIndex(base.BaseIndexTests):
pass


class TestMissing(BaseSparseTests, base.BaseMissingTests):
def test_isna(self, data_missing):
sarr = SparseArray(data_missing)
expected_dtype = SparseDtype(bool, pd.isna(data_missing.dtype.fill_value))
Expand Down Expand Up @@ -249,8 +275,6 @@ def test_fillna_frame(self, data_missing):

tm.assert_frame_equal(result, expected)


class TestMethods(BaseSparseTests, base.BaseMethodsTests):
_combine_le_expected_dtype = "Sparse[bool]"

def test_fillna_copy_frame(self, data_missing, using_copy_on_write):
Expand Down Expand Up @@ -351,16 +375,12 @@ def test_map_raises(self, data, na_action):
with pytest.raises(ValueError, match=msg):
data.map(lambda x: np.nan, na_action=na_action)


class TestCasting(BaseSparseTests, base.BaseCastingTests):
@pytest.mark.xfail(raises=TypeError, reason="no sparse StringDtype")
def test_astype_string(self, data, nullable_string_dtype):
# TODO: this fails bc we do not pass through nullable_string_dtype;
# If we did, the 0-cases would xpass
super().test_astype_string(data)


class TestArithmeticOps(BaseSparseTests, base.BaseArithmeticOpsTests):
series_scalar_exc = None
frame_scalar_exc = None
divmod_exc = None
Expand Down Expand Up @@ -397,17 +417,27 @@ def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
request.applymarker(mark)
super().test_arith_frame_with_scalar(data, all_arithmetic_operators)


class TestComparisonOps(BaseSparseTests):
def _compare_other(self, data_for_compare: SparseArray, comparison_op, other):
def _compare_other(
self, ser: pd.Series, data_for_compare: SparseArray, comparison_op, other
):
op = comparison_op

result = op(data_for_compare, other)
assert isinstance(result, SparseArray)
if isinstance(other, pd.Series):
assert isinstance(result, pd.Series)
assert isinstance(result.dtype, SparseDtype)
else:
assert isinstance(result, SparseArray)
assert result.dtype.subtype == np.bool_

if isinstance(other, SparseArray):
fill_value = op(data_for_compare.fill_value, other.fill_value)
if isinstance(other, pd.Series):
fill_value = op(data_for_compare.fill_value, other._values.fill_value)
expected = SparseArray(
op(data_for_compare.to_dense(), np.asarray(other)),
fill_value=fill_value,
dtype=np.bool_,
)

else:
fill_value = np.all(
op(np.asarray(data_for_compare.fill_value), np.asarray(other))
Expand All @@ -418,36 +448,51 @@ def _compare_other(self, data_for_compare: SparseArray, comparison_op, other):
fill_value=fill_value,
dtype=np.bool_,
)
tm.assert_sp_array_equal(result, expected)
if isinstance(other, pd.Series):
# error: Incompatible types in assignment
expected = pd.Series(expected) # type: ignore[assignment]
tm.assert_equal(result, expected)

def test_scalar(self, data_for_compare: SparseArray, comparison_op):
self._compare_other(data_for_compare, comparison_op, 0)
self._compare_other(data_for_compare, comparison_op, 1)
self._compare_other(data_for_compare, comparison_op, -1)
self._compare_other(data_for_compare, comparison_op, np.nan)
ser = pd.Series(data_for_compare)
self._compare_other(ser, data_for_compare, comparison_op, 0)
self._compare_other(ser, data_for_compare, comparison_op, 1)
self._compare_other(ser, data_for_compare, comparison_op, -1)
self._compare_other(ser, data_for_compare, comparison_op, np.nan)

def test_array(self, data_for_compare: SparseArray, comparison_op, request):
if data_for_compare.dtype.fill_value == 0 and comparison_op.__name__ in [
"eq",
"ge",
"le",
]:
mark = pytest.mark.xfail(reason="Wrong fill_value")
request.applymarker(mark)

@pytest.mark.xfail(reason="Wrong indices")
def test_array(self, data_for_compare: SparseArray, comparison_op):
arr = np.linspace(-4, 5, 10)
self._compare_other(data_for_compare, comparison_op, arr)
ser = pd.Series(data_for_compare)
self._compare_other(ser, data_for_compare, comparison_op, arr)

@pytest.mark.xfail(reason="Wrong indices")
def test_sparse_array(self, data_for_compare: SparseArray, comparison_op):
def test_sparse_array(self, data_for_compare: SparseArray, comparison_op, request):
if data_for_compare.dtype.fill_value == 0 and comparison_op.__name__ != "gt":
mark = pytest.mark.xfail(reason="Wrong fill_value")
request.applymarker(mark)

ser = pd.Series(data_for_compare)
arr = data_for_compare + 1
self._compare_other(data_for_compare, comparison_op, arr)
self._compare_other(ser, data_for_compare, comparison_op, arr)
arr = data_for_compare * 2
self._compare_other(data_for_compare, comparison_op, arr)
self._compare_other(ser, data_for_compare, comparison_op, arr)


class TestPrinting(BaseSparseTests, base.BasePrintingTests):
@pytest.mark.xfail(reason="Different repr")
def test_array_repr(self, data, size):
super().test_array_repr(data, size)


class TestParsing(BaseSparseTests, base.BaseParsingTests):
pass
@pytest.mark.xfail(reason="result does not match expected")
@pytest.mark.parametrize("as_index", [True, False])
def test_groupby_extension_agg(self, as_index, data_for_grouping):
super().test_groupby_extension_agg(as_index, data_for_grouping)


class TestNoNumericAccumulations(base.BaseAccumulateTests):
pass
def test_array_type_with_arg(dtype):
assert dtype.construct_array_type() is SparseArray