Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: assert_series_equal not properly respecting check-dtype #56654

Merged
merged 4 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions pandas/_testing/asserters.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,9 +949,15 @@ def assert_series_equal(
obj=str(obj),
)
else:
# convert both to NumPy if not, check_dtype would raise earlier
lv, rv = left_values, right_values
if isinstance(left_values, ExtensionArray):
lv = left_values.to_numpy()
if isinstance(right_values, ExtensionArray):
rv = right_values.to_numpy()
assert_numpy_array_equal(
left_values,
right_values,
lv,
rv,
check_dtype=check_dtype,
obj=str(obj),
index_values=left.index,
Expand Down
10 changes: 0 additions & 10 deletions pandas/tests/extension/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,16 +421,6 @@ def test_index_from_listlike_with_dtype(self, data):
def test_EA_types(self, engine, data, request):
super().test_EA_types(engine, data, request)

@pytest.mark.xfail(reason="Expect NumpyEA, get np.ndarray")
def test_compare_array(self, data, comparison_op):
super().test_compare_array(data, comparison_op)

def test_compare_scalar(self, data, comparison_op, request):
if data.dtype.kind == "f" or comparison_op.__name__ in ["eq", "ne"]:
mark = pytest.mark.xfail(reason="Expect NumpyEA, get np.ndarray")
request.applymarker(mark)
super().test_compare_scalar(data, comparison_op)


class Test2DCompat(base.NDArrayBacked2DTests):
pass
10 changes: 2 additions & 8 deletions pandas/tests/util/test_assert_frame_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,7 @@ def test_assert_frame_equal_extension_dtype_mismatch():
"\\[right\\]: int[32|64]"
)

# TODO: this shouldn't raise (or should raise a better error message)
# https://github.com/pandas-dev/pandas/issues/56131
with pytest.raises(AssertionError, match="classes are different"):
tm.assert_frame_equal(left, right, check_dtype=False)
tm.assert_frame_equal(left, right, check_dtype=False)

with pytest.raises(AssertionError, match=msg):
tm.assert_frame_equal(left, right, check_dtype=True)
Expand Down Expand Up @@ -246,7 +243,6 @@ def test_assert_frame_equal_ignore_extension_dtype_mismatch():
tm.assert_frame_equal(left, right, check_dtype=False)


@pytest.mark.xfail(reason="https://github.com/pandas-dev/pandas/issues/56131")
def test_assert_frame_equal_ignore_extension_dtype_mismatch_cross_class():
# https://github.com/pandas-dev/pandas/issues/35715
left = DataFrame({"a": [1, 2, 3]}, dtype="Int64")
Expand Down Expand Up @@ -300,9 +296,7 @@ def test_frame_equal_mixed_dtypes(frame_or_series, any_numeric_ea_dtype, indexer
dtypes = (any_numeric_ea_dtype, "int64")
obj1 = frame_or_series([1, 2], dtype=dtypes[indexer[0]])
obj2 = frame_or_series([1, 2], dtype=dtypes[indexer[1]])
msg = r'(Series|DataFrame.iloc\[:, 0\] \(column name="0"\) classes) are different'
with pytest.raises(AssertionError, match=msg):
tm.assert_equal(obj1, obj2, check_exact=True, check_dtype=False)
tm.assert_equal(obj1, obj2, check_exact=True, check_dtype=False)


def test_assert_frame_equal_check_like_different_indexes():
Expand Down
16 changes: 11 additions & 5 deletions pandas/tests/util/test_assert_series_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,10 +290,7 @@ def test_assert_series_equal_extension_dtype_mismatch():
\\[left\\]: Int64
\\[right\\]: int[32|64]"""

# TODO: this shouldn't raise (or should raise a better error message)
# https://github.com/pandas-dev/pandas/issues/56131
with pytest.raises(AssertionError, match="Series classes are different"):
tm.assert_series_equal(left, right, check_dtype=False)
tm.assert_series_equal(left, right, check_dtype=False)

with pytest.raises(AssertionError, match=msg):
tm.assert_series_equal(left, right, check_dtype=True)
Expand Down Expand Up @@ -372,7 +369,6 @@ def test_assert_series_equal_ignore_extension_dtype_mismatch():
tm.assert_series_equal(left, right, check_dtype=False)


@pytest.mark.xfail(reason="https://github.com/pandas-dev/pandas/issues/56131")
def test_assert_series_equal_ignore_extension_dtype_mismatch_cross_class():
# https://github.com/pandas-dev/pandas/issues/35715
left = Series([1, 2, 3], dtype="Int64")
Expand Down Expand Up @@ -456,3 +452,13 @@ def test_large_unequal_ints(dtype):
right = Series([1577840521123543], dtype=dtype)
with pytest.raises(AssertionError, match="Series are different"):
tm.assert_series_equal(left, right)


@pytest.mark.parametrize("dtype", [None, object])
@pytest.mark.parametrize("check_exact", [True, False])
@pytest.mark.parametrize("val", [3, 3.5])
def test_ea_and_numpy_no_dtype_check(val, check_exact, dtype):
# GH#56651
left = Series([1, 2, val], dtype=dtype)
right = Series(pd.array([1, 2, val]))
tm.assert_series_equal(left, right, check_dtype=False, check_exact=check_exact)
Loading