From b9f1bc6a773aee3f3bde6a226c7c43ccc0f04de4 Mon Sep 17 00:00:00 2001 From: "Lumberbot (aka Jack)" <39504233+meeseeksmachine@users.noreply.github.com> Date: Tue, 5 Nov 2024 12:29:32 -0800 Subject: [PATCH] Backport PR #60195 on branch 2.3.x (BUG (string dtype): fix where() for string dtype with python storage) (#60202) Backport PR #60195: BUG (string dtype): fix where() for string dtype with python storage Co-authored-by: Joris Van den Bossche --- pandas/core/arrays/string_.py | 6 ++++++ pandas/tests/frame/indexing/test_where.py | 18 ++++++------------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index f5c5cb2a45034..92c274453b9d1 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -757,6 +757,12 @@ def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None: # base class implementation that uses __setitem__ ExtensionArray._putmask(self, mask, value) + def _where(self, mask: npt.NDArray[np.bool_], value) -> Self: + # the super() method NDArrayBackedExtensionArray._where uses + # np.putmask which doesn't properly handle None/pd.NA, so using the + # base class implementation that uses __setitem__ + return ExtensionArray._where(self, mask, value) + def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]: if isinstance(values, BaseStringArray) or ( isinstance(values, ExtensionArray) and is_string_dtype(values.dtype) diff --git a/pandas/tests/frame/indexing/test_where.py b/pandas/tests/frame/indexing/test_where.py index f0d868a4cb583..40506c90f3295 100644 --- a/pandas/tests/frame/indexing/test_where.py +++ b/pandas/tests/frame/indexing/test_where.py @@ -6,8 +6,6 @@ from pandas._config import using_string_dtype -from pandas.compat import HAS_PYARROW - from pandas.core.dtypes.common import is_scalar import pandas as pd @@ -985,9 +983,6 @@ def test_where_nullable_invalid_na(frame_or_series, any_numeric_ea_dtype): obj.mask(mask, null) -@pytest.mark.xfail( - using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)" -) @given(data=OPTIONAL_ONE_OF_ALL) def test_where_inplace_casting(data): # GH 22051 @@ -1084,19 +1079,18 @@ def test_where_producing_ea_cond_for_np_dtype(): tm.assert_frame_equal(result, expected) -@pytest.mark.xfail( - using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)", strict=False -) @pytest.mark.parametrize( "replacement", [0.001, True, "snake", None, datetime(2022, 5, 4)] ) -def test_where_int_overflow(replacement, using_infer_string, request): +def test_where_int_overflow(replacement, using_infer_string): # GH 31687 df = DataFrame([[1.0, 2e25, "nine"], [np.nan, 0.1, None]]) if using_infer_string and replacement not in (None, "snake"): - request.node.add_marker( - pytest.mark.xfail(reason="Can't set non-string into string column") - ) + with pytest.raises( + TypeError, match="Cannot set non-string value|Scalar must be NA or str" + ): + df.where(pd.notnull(df), replacement) + return result = df.where(pd.notnull(df), replacement) expected = DataFrame([[1.0, 2e25, "nine"], [replacement, 0.1, replacement]])