Skip to content

Commit

Permalink
BUG (string dtype): fix where() for string dtype with python storage (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jorisvandenbossche authored Nov 5, 2024
1 parent 169b00e commit cf52dec
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
6 changes: 6 additions & 0 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,12 @@ def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
# base class implementation that uses __setitem__
ExtensionArray._putmask(self, mask, value)

def _where(self, mask: npt.NDArray[np.bool_], value) -> Self:
# the super() method NDArrayBackedExtensionArray._where uses
# np.putmask which doesn't properly handle None/pd.NA, so using the
# base class implementation that uses __setitem__
return ExtensionArray._where(self, mask, value)

def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
if isinstance(values, BaseStringArray) or (
isinstance(values, ExtensionArray) and is_string_dtype(values.dtype)
Expand Down
18 changes: 6 additions & 12 deletions pandas/tests/frame/indexing/test_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

from pandas._config import using_string_dtype

from pandas.compat import HAS_PYARROW

from pandas.core.dtypes.common import is_scalar

import pandas as pd
Expand Down Expand Up @@ -940,9 +938,6 @@ def test_where_nullable_invalid_na(frame_or_series, any_numeric_ea_dtype):
obj.mask(mask, null)


@pytest.mark.xfail(
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
)
@given(data=OPTIONAL_ONE_OF_ALL)
def test_where_inplace_casting(data):
# GH 22051
Expand Down Expand Up @@ -1023,19 +1018,18 @@ def test_where_producing_ea_cond_for_np_dtype():
tm.assert_frame_equal(result, expected)


@pytest.mark.xfail(
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)", strict=False
)
@pytest.mark.parametrize(
"replacement", [0.001, True, "snake", None, datetime(2022, 5, 4)]
)
def test_where_int_overflow(replacement, using_infer_string, request):
def test_where_int_overflow(replacement, using_infer_string):
# GH 31687
df = DataFrame([[1.0, 2e25, "nine"], [np.nan, 0.1, None]])
if using_infer_string and replacement not in (None, "snake"):
request.node.add_marker(
pytest.mark.xfail(reason="Can't set non-string into string column")
)
with pytest.raises(
TypeError, match="Cannot set non-string value|Scalar must be NA or str"
):
df.where(pd.notnull(df), replacement)
return
result = df.where(pd.notnull(df), replacement)
expected = DataFrame([[1.0, 2e25, "nine"], [replacement, 0.1, replacement]])

Expand Down

0 comments on commit cf52dec

Please sign in to comment.