From cf52dec71329797b2af84053d091bd7cfc787486 Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche
Date: Tue, 5 Nov 2024 19:55:24 +0100
Subject: [PATCH] BUG (string dtype): fix where() for string dtype with python storage (#60195)

---
 pandas/core/arrays/string_.py             |  6 ++++++
 pandas/tests/frame/indexing/test_where.py | 18 ++++++------------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py
index c9e53abc31182..f54a5260bd699 100644
--- a/pandas/core/arrays/string_.py
+++ b/pandas/core/arrays/string_.py
@@ -757,6 +757,12 @@ def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
         # base class implementation that uses __setitem__
         ExtensionArray._putmask(self, mask, value)
 
+    def _where(self, mask: npt.NDArray[np.bool_], value) -> Self:
+        # the super() method NDArrayBackedExtensionArray._where uses
+        # np.putmask which doesn't properly handle None/pd.NA, so using the
+        # base class implementation that uses __setitem__
+        return ExtensionArray._where(self, mask, value)
+
     def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
         if isinstance(values, BaseStringArray) or (
             isinstance(values, ExtensionArray) and is_string_dtype(values.dtype)
diff --git a/pandas/tests/frame/indexing/test_where.py b/pandas/tests/frame/indexing/test_where.py
index 32a827c25c77a..ff66ea491e308 100644
--- a/pandas/tests/frame/indexing/test_where.py
+++ b/pandas/tests/frame/indexing/test_where.py
@@ -6,8 +6,6 @@
 
 from pandas._config import using_string_dtype
 
-from pandas.compat import HAS_PYARROW
-
 from pandas.core.dtypes.common import is_scalar
 
 import pandas as pd
@@ -940,9 +938,6 @@ def test_where_nullable_invalid_na(frame_or_series, any_numeric_ea_dtype):
             obj.mask(mask, null)
 
 
-@pytest.mark.xfail(
-    using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
-)
 @given(data=OPTIONAL_ONE_OF_ALL)
 def test_where_inplace_casting(data):
     # GH 22051
@@ -1023,19 +1018,18 @@ def test_where_producing_ea_cond_for_np_dtype():
     tm.assert_frame_equal(result, expected)
 
 
-@pytest.mark.xfail(
-    using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)", strict=False
-)
 @pytest.mark.parametrize(
     "replacement", [0.001, True, "snake", None, datetime(2022, 5, 4)]
 )
-def test_where_int_overflow(replacement, using_infer_string, request):
+def test_where_int_overflow(replacement, using_infer_string):
     # GH 31687
     df = DataFrame([[1.0, 2e25, "nine"], [np.nan, 0.1, None]])
     if using_infer_string and replacement not in (None, "snake"):
-        request.node.add_marker(
-            pytest.mark.xfail(reason="Can't set non-string into string column")
-        )
+        with pytest.raises(
+            TypeError, match="Cannot set non-string value|Scalar must be NA or str"
+        ):
+            df.where(pd.notnull(df), replacement)
+        return
     result = df.where(pd.notnull(df), replacement)
     expected = DataFrame([[1.0, 2e25, "nine"], [replacement, 0.1, replacement]])
 