Skip to content

Commit

Permalink
TYP: EA.isin (#56423)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Dec 9, 2023
1 parent 4f8bb2b commit 2dcb963
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 28 deletions.
2 changes: 1 addition & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,7 +1023,7 @@ def fillna(

return super().fillna(value=value, method=method, limit=limit, copy=copy)

def isin(self, values) -> npt.NDArray[np.bool_]:
def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
# short-circuit to return all False array.
if not len(values):
return np.zeros(len(self), dtype=bool)
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,15 +1355,15 @@ def equals(self, other: object) -> bool:
equal_na = self.isna() & other.isna() # type: ignore[operator]
return bool((equal_values | equal_na).all())

def isin(self, values) -> npt.NDArray[np.bool_]:
def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
"""
Pointwise comparison for set containment in the given values.
Roughly equivalent to `np.array([x in values for x in self])`
Parameters
----------
values : Sequence
values : np.ndarray or ExtensionArray
Returns
-------
Expand Down
11 changes: 2 additions & 9 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2570,7 +2570,7 @@ def describe(self) -> DataFrame:

return result

def isin(self, values) -> npt.NDArray[np.bool_]:
def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
"""
Check whether `values` are contained in Categorical.
Expand All @@ -2580,7 +2580,7 @@ def isin(self, values) -> npt.NDArray[np.bool_]:
Parameters
----------
values : set or list-like
values : np.ndarray or ExtensionArray
The sequence of values to test. Passing in a single string will
raise a ``TypeError``. Instead, turn a single string into a
list of one element.
Expand Down Expand Up @@ -2611,13 +2611,6 @@ def isin(self, values) -> npt.NDArray[np.bool_]:
>>> s.isin(['lama'])
array([ True, False, True, False, True, False])
"""
if not is_list_like(values):
values_type = type(values).__name__
raise TypeError(
"only list-like objects are allowed to be passed "
f"to isin(), you passed a `{values_type}`"
)
values = sanitize_array(values, None, None)
null_mask = np.asarray(isna(values))
code_values = self.categories.get_indexer_for(values)
code_values = code_values[null_mask | (code_values >= 0)]
Expand Down
20 changes: 12 additions & 8 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,22 +734,19 @@ def map(self, mapper, na_action=None):
else:
return result.array

def isin(self, values) -> npt.NDArray[np.bool_]:
def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
"""
Compute boolean array of whether each value is found in the
passed set of values.
Parameters
----------
values : set or sequence of values
values : np.ndarray or ExtensionArray
Returns
-------
ndarray[bool]
"""
if not hasattr(values, "dtype"):
values = np.asarray(values)

if values.dtype.kind in "fiuc":
# TODO: de-duplicate with equals, validate_comparison_value
return np.zeros(self.shape, dtype=bool)
Expand Down Expand Up @@ -781,15 +778,22 @@ def isin(self, values) -> npt.NDArray[np.bool_]:

if self.dtype.kind in "mM":
self = cast("DatetimeArray | TimedeltaArray", self)
values = values.as_unit(self.unit)
# error: Item "ExtensionArray" of "ExtensionArray | ndarray[Any, Any]"
# has no attribute "as_unit"
values = values.as_unit(self.unit) # type: ignore[union-attr]

try:
self._check_compatible_with(values)
# error: Argument 1 to "_check_compatible_with" of "DatetimeLikeArrayMixin"
# has incompatible type "ExtensionArray | ndarray[Any, Any]"; expected
# "Period | Timestamp | Timedelta | NaTType"
self._check_compatible_with(values) # type: ignore[arg-type]
except (TypeError, ValueError):
# Includes tzawareness mismatch and IncompatibleFrequencyError
return np.zeros(self.shape, dtype=bool)

return isin(self.asi8, values.asi8)
# error: Item "ExtensionArray" of "ExtensionArray | ndarray[Any, Any]"
# has no attribute "asi8"
return isin(self.asi8, values.asi8) # type: ignore[union-attr]

# ------------------------------------------------------------------
# Null Handling
Expand Down
8 changes: 2 additions & 6 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1789,12 +1789,8 @@ def contains(self, other):
other < self._right if self.open_right else other <= self._right
)

def isin(self, values) -> npt.NDArray[np.bool_]:
if not hasattr(values, "dtype"):
values = np.array(values)
values = extract_array(values, extract_numpy=True)

if isinstance(values.dtype, IntervalDtype):
def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
if isinstance(values, IntervalArray):
if self.closed != values.closed:
# not comparable -> no overlap
return np.zeros(self.shape, dtype=bool)
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ def take(

# error: Return type "BooleanArray" of "isin" incompatible with return type
# "ndarray" in supertype "ExtensionArray"
def isin(self, values) -> BooleanArray: # type: ignore[override]
def isin(self, values: ArrayLike) -> BooleanArray: # type: ignore[override]
from pandas.core.arrays import BooleanArray

# algorithms.isin will eventually convert values to an ndarray, so no extra
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from collections.abc import Sequence

from pandas._typing import (
ArrayLike,
AxisInt,
Dtype,
Scalar,
Expand Down Expand Up @@ -212,7 +213,7 @@ def _maybe_convert_setitem_value(self, value):
raise TypeError("Scalar must be NA or str")
return super()._maybe_convert_setitem_value(value)

def isin(self, values) -> npt.NDArray[np.bool_]:
def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
value_set = [
pa_scalar.as_py()
for pa_scalar in [pa.scalar(value, from_pandas=True) for value in values]
Expand Down

0 comments on commit 2dcb963

Please sign in to comment.