Skip to content

Commit

Permalink
BUG: __eq__ raising for new arrow string dtype for incompatible objects
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl committed Nov 29, 2023
1 parent 0e8174f commit b951690
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 1 deletion.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,7 @@ Strings
- Bug in :func:`pandas.api.types.is_string_dtype` while checking object array with no elements is of the string dtype (:issue:`54661`)
- Bug in :meth:`DataFrame.apply` failing when ``engine="numba"`` and columns or index have ``StringDtype`` (:issue:`56189`)
- Bug in :meth:`Series.str.startswith` and :meth:`Series.str.endswith` with arguments of type ``tuple[str, ...]`` for ``string[pyarrow]`` (:issue:`54942`)
- Bug in comparison operations for ``dtype="string[pyarrow_numpy]"`` raising if dtypes can't be compared (:issue:`56008`)

Interval
^^^^^^^^
Expand Down
6 changes: 5 additions & 1 deletion pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
BaseStringArray,
StringDtype,
)
from pandas.core.ops import invalid_comparison
from pandas.core.strings.object_array import ObjectStringArrayMixin

if not pa_version_under10p1:
Expand Down Expand Up @@ -662,7 +663,10 @@ def _convert_int_dtype(self, result):
return result

def _cmp_method(self, other, op):
result = super()._cmp_method(other, op)
try:
result = super()._cmp_method(other, op)
except pa.lib.ArrowNotImplementedError:
return invalid_comparison(self, other, op)
return result.to_numpy(np.bool_, na_value=False)

def value_counts(self, dropna: bool = True) -> Series:
Expand Down
15 changes: 15 additions & 0 deletions pandas/tests/series/test_logical_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,3 +530,18 @@ def test_int_dtype_different_index_not_bool(self):

result = ser1 ^ ser2
tm.assert_series_equal(result, expected)

def test_pyarrow_numpy_string_invalid(self):
# GH#56008
ser = Series([False, True])
ser2 = Series(["a", "b"], dtype="string[pyarrow_numpy]")
result = ser == ser2
expected = Series(False, index=ser.index)
tm.assert_series_equal(result, expected)

result = ser != ser2
expected = Series(True, index=ser.index)
tm.assert_series_equal(result, expected)

with pytest.raises(TypeError, match="Invalid comparison"):
ser > ser2

0 comments on commit b951690

Please sign in to comment.