Skip to content

Commit

Permalink
Backport PR #60195 on branch 2.3.x (BUG (string dtype): fix where() for string dtype with python storage) (#60202)
Browse files Browse the repository at this point in the history

Backport PR #60195: BUG (string dtype): fix where() for string dtype with python storage

Co-authored-by: Joris Van den Bossche <[email protected]>
  • Loading branch information
meeseeksmachine and jorisvandenbossche authored Nov 5, 2024
1 parent 70e8a3b commit b9f1bc6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
6 changes: 6 additions & 0 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,12 @@ def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
# base class implementation that uses __setitem__
ExtensionArray._putmask(self, mask, value)

def _where(self, mask: npt.NDArray[np.bool_], value) -> Self:
    """
    Compute ``where`` via ``ExtensionArray._where`` rather than ``super()``.

    Parameters
    ----------
    mask : np.ndarray[bool]
        Boolean mask, forwarded unchanged to ``ExtensionArray._where``.
    value : object
        Replacement value(s), forwarded unchanged to
        ``ExtensionArray._where``.

    Returns
    -------
    Self
        Whatever ``ExtensionArray._where(self, mask, value)`` returns.
    """
    # the super() method NDArrayBackedExtensionArray._where uses
    # np.putmask which doesn't properly handle None/pd.NA, so using the
    # base class implementation that uses __setitem__
    # NOTE: the explicit ``ExtensionArray._where(self, ...)`` call is
    # deliberate; switching to ``super()._where(...)`` would pick the
    # NDArrayBackedExtensionArray implementation again.
    return ExtensionArray._where(self, mask, value)

def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
if isinstance(values, BaseStringArray) or (
isinstance(values, ExtensionArray) and is_string_dtype(values.dtype)
Expand Down
18 changes: 6 additions & 12 deletions pandas/tests/frame/indexing/test_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

from pandas._config import using_string_dtype

from pandas.compat import HAS_PYARROW

from pandas.core.dtypes.common import is_scalar

import pandas as pd
Expand Down Expand Up @@ -985,9 +983,6 @@ def test_where_nullable_invalid_na(frame_or_series, any_numeric_ea_dtype):
obj.mask(mask, null)


@pytest.mark.xfail(
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
)
@given(data=OPTIONAL_ONE_OF_ALL)
def test_where_inplace_casting(data):
# GH 22051
Expand Down Expand Up @@ -1084,19 +1079,18 @@ def test_where_producing_ea_cond_for_np_dtype():
tm.assert_frame_equal(result, expected)


@pytest.mark.xfail(
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)", strict=False
)
@pytest.mark.parametrize(
"replacement", [0.001, True, "snake", None, datetime(2022, 5, 4)]
)
def test_where_int_overflow(replacement, using_infer_string, request):
def test_where_int_overflow(replacement, using_infer_string):
# GH 31687
df = DataFrame([[1.0, 2e25, "nine"], [np.nan, 0.1, None]])
if using_infer_string and replacement not in (None, "snake"):
request.node.add_marker(
pytest.mark.xfail(reason="Can't set non-string into string column")
)
with pytest.raises(
TypeError, match="Cannot set non-string value|Scalar must be NA or str"
):
df.where(pd.notnull(df), replacement)
return
result = df.where(pd.notnull(df), replacement)
expected = DataFrame([[1.0, 2e25, "nine"], [replacement, 0.1, replacement]])

Expand Down

0 comments on commit b9f1bc6

Please sign in to comment.