Skip to content

Commit

Permalink
TST: use one-class pattern in test_numpy (#56512)
Browse files Browse the repository at this point in the history
* TST: use one-class pattern in test_numpy

* revert accidentally-commited
  • Loading branch information
jbrockmendel authored Dec 15, 2023
1 parent ee4ceec commit bb14870
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 65 deletions.
4 changes: 4 additions & 0 deletions pandas/tests/extension/base/interface.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest

from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
from pandas.core.dtypes.common import is_extension_array_dtype
from pandas.core.dtypes.dtypes import ExtensionDtype

Expand Down Expand Up @@ -65,6 +66,9 @@ def test_array_interface(self, data):

result = np.array(data, dtype=object)
expected = np.array(list(data), dtype=object)
if expected.ndim > 1:
# nested data, explicitly construct as 1D
expected = construct_1d_object_array_from_listlike(list(data))
tm.assert_numpy_array_equal(result, expected)

def test_is_extension_array_dtype(self, data):
Expand Down
125 changes: 60 additions & 65 deletions pandas/tests/extension/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,14 @@
import numpy as np
import pytest

from pandas.core.dtypes.cast import can_hold_element
from pandas.core.dtypes.dtypes import NumpyEADtype

import pandas as pd
import pandas._testing as tm
from pandas.api.types import is_object_dtype
from pandas.core.arrays.numpy_ import NumpyExtensionArray
from pandas.core.internals import blocks
from pandas.tests.extension import base


def _can_hold_element_patched(obj, element) -> bool:
if isinstance(element, NumpyExtensionArray):
element = element.to_numpy()
return can_hold_element(obj, element)


orig_assert_attr_equal = tm.assert_attr_equal


Expand Down Expand Up @@ -78,7 +69,6 @@ def allow_in_pandas(monkeypatch):
"""
with monkeypatch.context() as m:
m.setattr(NumpyExtensionArray, "_typ", "extension")
m.setattr(blocks, "can_hold_element", _can_hold_element_patched)
m.setattr(tm.asserters, "assert_attr_equal", _assert_attr_equal)
yield

Expand Down Expand Up @@ -175,15 +165,7 @@ def skip_numpy_object(dtype, request):
skip_nested = pytest.mark.usefixtures("skip_numpy_object")


class BaseNumPyTests:
pass


class TestCasting(BaseNumPyTests, base.BaseCastingTests):
pass


class TestConstructors(BaseNumPyTests, base.BaseConstructorsTests):
class TestNumpyExtensionArray(base.ExtensionTests):
@pytest.mark.skip(reason="We don't register our dtype")
# We don't want to register. This test should probably be split in two.
def test_from_dtype(self, data):
Expand All @@ -194,8 +176,6 @@ def test_series_constructor_scalar_with_index(self, data, dtype):
# ValueError: Length of passed values is 1, index implies 3.
super().test_series_constructor_scalar_with_index(data, dtype)


class TestDtype(BaseNumPyTests, base.BaseDtypeTests):
def test_check_dtype(self, data, request, using_infer_string):
if data.dtype.numpy_dtype == "object":
request.applymarker(
Expand All @@ -214,26 +194,11 @@ def test_is_not_object_type(self, dtype, request):
else:
super().test_is_not_object_type(dtype)


class TestGetitem(BaseNumPyTests, base.BaseGetitemTests):
@skip_nested
def test_getitem_scalar(self, data):
# AssertionError
super().test_getitem_scalar(data)


class TestGroupby(BaseNumPyTests, base.BaseGroupbyTests):
pass


class TestInterface(BaseNumPyTests, base.BaseInterfaceTests):
@skip_nested
def test_array_interface(self, data):
# NumPy array shape inference
super().test_array_interface(data)


class TestMethods(BaseNumPyTests, base.BaseMethodsTests):
@skip_nested
def test_shift_fill_value(self, data):
# np.array shape inference. Shift implementation fails.
Expand All @@ -251,7 +216,9 @@ def test_fillna_copy_series(self, data_missing):

@skip_nested
def test_searchsorted(self, data_for_sorting, as_series):
# Test setup fails.
# TODO: NumpyExtensionArray.searchsorted calls ndarray.searchsorted which
# isn't quite what we want in nested data cases. Instead we need to
# adapt something like libindex._bin_search.
super().test_searchsorted(data_for_sorting, as_series)

@pytest.mark.xfail(reason="NumpyExtensionArray.diff may fail on dtype")
Expand All @@ -270,38 +237,60 @@ def test_insert_invalid(self, data, invalid_scalar):
# NumpyExtensionArray[object] can hold anything, so skip
super().test_insert_invalid(data, invalid_scalar)


class TestArithmetics(BaseNumPyTests, base.BaseArithmeticOpsTests):
divmod_exc = None
series_scalar_exc = None
frame_scalar_exc = None
series_array_exc = None

@skip_nested
def test_divmod(self, data):
divmod_exc = None
if data.dtype.kind == "O":
divmod_exc = TypeError
self.divmod_exc = divmod_exc
super().test_divmod(data)

@skip_nested
def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
def test_divmod_series_array(self, data):
ser = pd.Series(data)
exc = None
if data.dtype.kind == "O":
exc = TypeError
self.divmod_exc = exc
self._check_divmod_op(ser, divmod, data)

def test_arith_series_with_scalar(self, data, all_arithmetic_operators, request):
opname = all_arithmetic_operators
series_scalar_exc = None
if data.dtype.numpy_dtype == object:
if opname in ["__mul__", "__rmul__"]:
mark = pytest.mark.xfail(
reason="the Series.combine step raises but not the Series method."
)
request.node.add_marker(mark)
series_scalar_exc = TypeError
self.series_scalar_exc = series_scalar_exc
super().test_arith_series_with_scalar(data, all_arithmetic_operators)

def test_arith_series_with_array(self, data, all_arithmetic_operators, request):
def test_arith_series_with_array(self, data, all_arithmetic_operators):
opname = all_arithmetic_operators
series_array_exc = None
if data.dtype.numpy_dtype == object and opname not in ["__add__", "__radd__"]:
mark = pytest.mark.xfail(reason="Fails for object dtype")
request.applymarker(mark)
series_array_exc = TypeError
self.series_array_exc = series_array_exc
super().test_arith_series_with_array(data, all_arithmetic_operators)

@skip_nested
def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
opname = all_arithmetic_operators
frame_scalar_exc = None
if data.dtype.numpy_dtype == object:
if opname in ["__mul__", "__rmul__"]:
mark = pytest.mark.xfail(
reason="the Series.combine step raises but not the Series method."
)
request.node.add_marker(mark)
frame_scalar_exc = TypeError
self.frame_scalar_exc = frame_scalar_exc
super().test_arith_frame_with_scalar(data, all_arithmetic_operators)


class TestPrinting(BaseNumPyTests, base.BasePrintingTests):
pass


class TestReduce(BaseNumPyTests, base.BaseReduceTests):
def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
if ser.dtype.kind == "O":
return op_name in ["sum", "min", "max", "any", "all"]
Expand All @@ -328,8 +317,6 @@ def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool):
def test_reduce_frame(self, data, all_numeric_reductions, skipna):
pass


class TestMissing(BaseNumPyTests, base.BaseMissingTests):
@skip_nested
def test_fillna_series(self, data_missing):
# Non-scalar "scalar" values.
Expand All @@ -340,12 +327,6 @@ def test_fillna_frame(self, data_missing):
# Non-scalar "scalar" values.
super().test_fillna_frame(data_missing)


class TestReshaping(BaseNumPyTests, base.BaseReshapingTests):
pass


class TestSetitem(BaseNumPyTests, base.BaseSetitemTests):
@skip_nested
def test_setitem_invalid(self, data, invalid_scalar):
# object dtype can hold anything, so doesn't raise
Expand Down Expand Up @@ -431,11 +412,25 @@ def test_setitem_with_expansion_dataframe_column(self, data, full_indexer):
expected = pd.DataFrame({"data": data.to_numpy()})
tm.assert_frame_equal(result, expected, check_column_type=False)

@pytest.mark.xfail(reason="NumpyEADtype is unpacked")
def test_index_from_listlike_with_dtype(self, data):
super().test_index_from_listlike_with_dtype(data)

@skip_nested
class TestParsing(BaseNumPyTests, base.BaseParsingTests):
pass
@skip_nested
@pytest.mark.parametrize("engine", ["c", "python"])
def test_EA_types(self, engine, data, request):
super().test_EA_types(engine, data, request)

@pytest.mark.xfail(reason="Expect NumpyEA, get np.ndarray")
def test_compare_array(self, data, comparison_op):
super().test_compare_array(data, comparison_op)

def test_compare_scalar(self, data, comparison_op, request):
if data.dtype.kind == "f" or comparison_op.__name__ in ["eq", "ne"]:
mark = pytest.mark.xfail(reason="Expect NumpyEA, get np.ndarray")
request.applymarker(mark)
super().test_compare_scalar(data, comparison_op)


class Test2DCompat(BaseNumPyTests, base.NDArrayBacked2DTests):
class Test2DCompat(base.NDArrayBacked2DTests):
pass

0 comments on commit bb14870

Please sign in to comment.