Skip to content

Commit

Permalink
String dtype: fix isin() values handling for python storage (#59759)
Browse files Browse the repository at this point in the history
* String dtype: fix isin() values handling for python storage

* address feedback
  • Loading branch information
jorisvandenbossche authored Sep 12, 2024
1 parent 2c49f55 commit 0d2505d
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 6 deletions.
9 changes: 8 additions & 1 deletion pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1338,7 +1338,13 @@ def string_storage(request):
pytest.param(("pyarrow", pd.NA), marks=td.skip_if_no("pyarrow")),
pytest.param(("pyarrow", np.nan), marks=td.skip_if_no("pyarrow")),
("python", np.nan),
]
],
ids=[
"string=string[python]",
"string=string[pyarrow]",
"string=str[pyarrow]",
"string=str[python]",
],
)
def string_dtype_arguments(request):
"""
Expand Down Expand Up @@ -1369,6 +1375,7 @@ def dtype_backend(request):

# Alias so we can test with cartesian product of string_storage
string_storage2 = string_storage
string_dtype_arguments2 = string_dtype_arguments


@pytest.fixture(params=tm.BYTES_DTYPES)
Expand Down
20 changes: 20 additions & 0 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
nanops,
ops,
)
from pandas.core.algorithms import isin
from pandas.core.array_algos import masked_reductions
from pandas.core.arrays.base import ExtensionArray
from pandas.core.arrays.floating import (
Expand All @@ -65,6 +66,7 @@
import pyarrow

from pandas._typing import (
ArrayLike,
AxisInt,
Dtype,
DtypeObj,
Expand Down Expand Up @@ -735,6 +737,24 @@ def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
# base class implementation that uses __setitem__
ExtensionArray._putmask(self, mask, value)

def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
if isinstance(values, BaseStringArray) or (
isinstance(values, ExtensionArray) and is_string_dtype(values.dtype)
):
values = values.astype(self.dtype, copy=False)
else:
if not lib.is_string_array(np.asarray(values), skipna=True):
values = np.array(
[val for val in values if isinstance(val, str) or isna(val)],
dtype=object,
)
if not len(values):
return np.zeros(self.shape, dtype=bool)

values = self._from_sequence(values, dtype=self.dtype)

return isin(np.asarray(self), np.asarray(values))

def astype(self, dtype, copy: bool = True):
dtype = pandas_dtype(dtype)

Expand Down
41 changes: 36 additions & 5 deletions pandas/tests/arrays/string_/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ def dtype(string_dtype_arguments):
return pd.StringDtype(storage=storage, na_value=na_value)


@pytest.fixture
def dtype2(string_dtype_arguments2):
storage, na_value = string_dtype_arguments2
return pd.StringDtype(storage=storage, na_value=na_value)


@pytest.fixture
def cls(dtype):
"""Fixture giving array type from parametrized 'dtype'"""
Expand Down Expand Up @@ -662,11 +668,7 @@ def test_isin(dtype, fixed_now_ts):
tm.assert_series_equal(result, expected)

result = s.isin(["a", pd.NA])
if dtype.storage == "python" and dtype.na_value is np.nan:
# TODO(infer_string) we should make this consistent
expected = pd.Series([True, False, False])
else:
expected = pd.Series([True, False, True])
expected = pd.Series([True, False, True])
tm.assert_series_equal(result, expected)

result = s.isin([])
Expand All @@ -677,6 +679,35 @@ def test_isin(dtype, fixed_now_ts):
expected = pd.Series([True, False, False])
tm.assert_series_equal(result, expected)

result = s.isin([fixed_now_ts])
expected = pd.Series([False, False, False])
tm.assert_series_equal(result, expected)


def test_isin_string_array(dtype, dtype2):
s = pd.Series(["a", "b", None], dtype=dtype)

result = s.isin(pd.array(["a", "c"], dtype=dtype2))
expected = pd.Series([True, False, False])
tm.assert_series_equal(result, expected)

result = s.isin(pd.array(["a", None], dtype=dtype2))
expected = pd.Series([True, False, True])
tm.assert_series_equal(result, expected)


def test_isin_arrow_string_array(dtype):
pa = pytest.importorskip("pyarrow")
s = pd.Series(["a", "b", None], dtype=dtype)

result = s.isin(pd.array(["a", "c"], dtype=pd.ArrowDtype(pa.string())))
expected = pd.Series([True, False, False])
tm.assert_series_equal(result, expected)

result = s.isin(pd.array(["a", None], dtype=pd.ArrowDtype(pa.string())))
expected = pd.Series([True, False, True])
tm.assert_series_equal(result, expected)


def test_setitem_scalar_with_mask_validation(dtype):
# https://github.com/pandas-dev/pandas/issues/47628
Expand Down

0 comments on commit 0d2505d

Please sign in to comment.