Skip to content

Commit

Permalink
BUG: assert_series_equal not properly respecting check-dtype (pandas-…
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl authored and cbpygit committed Jan 2, 2024
1 parent ab666f4 commit a53a9e8
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 25 deletions.
10 changes: 8 additions & 2 deletions pandas/_testing/asserters.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,9 +949,15 @@ def assert_series_equal(
obj=str(obj),
)
else:
# convert both to NumPy if not, check_dtype would raise earlier
lv, rv = left_values, right_values
if isinstance(left_values, ExtensionArray):
lv = left_values.to_numpy()
if isinstance(right_values, ExtensionArray):
rv = right_values.to_numpy()
assert_numpy_array_equal(
left_values,
right_values,
lv,
rv,
check_dtype=check_dtype,
obj=str(obj),
index_values=left.index,
Expand Down
10 changes: 0 additions & 10 deletions pandas/tests/extension/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,16 +421,6 @@ def test_index_from_listlike_with_dtype(self, data):
def test_EA_types(self, engine, data, request):
super().test_EA_types(engine, data, request)

@pytest.mark.xfail(reason="Expect NumpyEA, get np.ndarray")
def test_compare_array(self, data, comparison_op):
super().test_compare_array(data, comparison_op)

def test_compare_scalar(self, data, comparison_op, request):
if data.dtype.kind == "f" or comparison_op.__name__ in ["eq", "ne"]:
mark = pytest.mark.xfail(reason="Expect NumpyEA, get np.ndarray")
request.applymarker(mark)
super().test_compare_scalar(data, comparison_op)


class Test2DCompat(base.NDArrayBacked2DTests):
pass
10 changes: 2 additions & 8 deletions pandas/tests/util/test_assert_frame_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,7 @@ def test_assert_frame_equal_extension_dtype_mismatch():
"\\[right\\]: int[32|64]"
)

# TODO: this shouldn't raise (or should raise a better error message)
# https://github.com/pandas-dev/pandas/issues/56131
with pytest.raises(AssertionError, match="classes are different"):
tm.assert_frame_equal(left, right, check_dtype=False)
tm.assert_frame_equal(left, right, check_dtype=False)

with pytest.raises(AssertionError, match=msg):
tm.assert_frame_equal(left, right, check_dtype=True)
Expand Down Expand Up @@ -246,7 +243,6 @@ def test_assert_frame_equal_ignore_extension_dtype_mismatch():
tm.assert_frame_equal(left, right, check_dtype=False)


@pytest.mark.xfail(reason="https://github.com/pandas-dev/pandas/issues/56131")
def test_assert_frame_equal_ignore_extension_dtype_mismatch_cross_class():
# https://github.com/pandas-dev/pandas/issues/35715
left = DataFrame({"a": [1, 2, 3]}, dtype="Int64")
Expand Down Expand Up @@ -300,9 +296,7 @@ def test_frame_equal_mixed_dtypes(frame_or_series, any_numeric_ea_dtype, indexer
dtypes = (any_numeric_ea_dtype, "int64")
obj1 = frame_or_series([1, 2], dtype=dtypes[indexer[0]])
obj2 = frame_or_series([1, 2], dtype=dtypes[indexer[1]])
msg = r'(Series|DataFrame.iloc\[:, 0\] \(column name="0"\) classes) are different'
with pytest.raises(AssertionError, match=msg):
tm.assert_equal(obj1, obj2, check_exact=True, check_dtype=False)
tm.assert_equal(obj1, obj2, check_exact=True, check_dtype=False)


def test_assert_frame_equal_check_like_different_indexes():
Expand Down
16 changes: 11 additions & 5 deletions pandas/tests/util/test_assert_series_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,10 +290,7 @@ def test_assert_series_equal_extension_dtype_mismatch():
\\[left\\]: Int64
\\[right\\]: int[32|64]"""

# TODO: this shouldn't raise (or should raise a better error message)
# https://github.com/pandas-dev/pandas/issues/56131
with pytest.raises(AssertionError, match="Series classes are different"):
tm.assert_series_equal(left, right, check_dtype=False)
tm.assert_series_equal(left, right, check_dtype=False)

with pytest.raises(AssertionError, match=msg):
tm.assert_series_equal(left, right, check_dtype=True)
Expand Down Expand Up @@ -372,7 +369,6 @@ def test_assert_series_equal_ignore_extension_dtype_mismatch():
tm.assert_series_equal(left, right, check_dtype=False)


@pytest.mark.xfail(reason="https://github.com/pandas-dev/pandas/issues/56131")
def test_assert_series_equal_ignore_extension_dtype_mismatch_cross_class():
# https://github.com/pandas-dev/pandas/issues/35715
left = Series([1, 2, 3], dtype="Int64")
Expand Down Expand Up @@ -456,3 +452,13 @@ def test_large_unequal_ints(dtype):
right = Series([1577840521123543], dtype=dtype)
with pytest.raises(AssertionError, match="Series are different"):
tm.assert_series_equal(left, right)


@pytest.mark.parametrize("dtype", [None, object])
@pytest.mark.parametrize("check_exact", [True, False])
@pytest.mark.parametrize("val", [3, 3.5])
def test_ea_and_numpy_no_dtype_check(val, check_exact, dtype):
# GH#56651
left = Series([1, 2, val], dtype=dtype)
right = Series(pd.array([1, 2, val]))
tm.assert_series_equal(left, right, check_dtype=False, check_exact=check_exact)

0 comments on commit a53a9e8

Please sign in to comment.