diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index 8cb4b3f24d435..b39f30538bf56 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -502,6 +502,7 @@ Strings - Bug in :func:`pandas.api.types.is_string_dtype` while checking object array with no elements is of the string dtype (:issue:`54661`) - Bug in :meth:`DataFrame.apply` failing when ``engine="numba"`` and columns or index have ``StringDtype`` (:issue:`56189`) - Bug in :meth:`Series.str.startswith` and :meth:`Series.str.endswith` with arguments of type ``tuple[str, ...]`` for ``string[pyarrow]`` (:issue:`54942`) +- Bug in comparison operations for ``dtype="string[pyarrow_numpy]"`` raising if dtypes can't be compared (:issue:`56008`) Interval ^^^^^^^^ diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 96ebb4901f797..c4047cb4fa31a 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -40,6 +40,7 @@ BaseStringArray, StringDtype, ) +from pandas.core.ops import invalid_comparison from pandas.core.strings.object_array import ObjectStringArrayMixin if not pa_version_under10p1: @@ -662,7 +663,10 @@ def _convert_int_dtype(self, result): return result def _cmp_method(self, other, op): - result = super()._cmp_method(other, op) + try: + result = super()._cmp_method(other, op) + except pa.lib.ArrowNotImplementedError: + return invalid_comparison(self, other, op) return result.to_numpy(np.bool_, na_value=False) def value_counts(self, dropna: bool = True) -> Series: diff --git a/pandas/tests/series/test_logical_ops.py b/pandas/tests/series/test_logical_ops.py index 166f52181fed4..2cab33a31434f 100644 --- a/pandas/tests/series/test_logical_ops.py +++ b/pandas/tests/series/test_logical_ops.py @@ -530,3 +530,18 @@ def test_int_dtype_different_index_not_bool(self): result = ser1 ^ ser2 tm.assert_series_equal(result, expected) + + def test_pyarrow_numpy_string_invalid(self): + # GH#56008 + ser = Series([False, True]) + ser2 = Series(["a", "b"], dtype="string[pyarrow_numpy]") + result = ser == ser2 + expected = Series(False, index=ser.index) + tm.assert_series_equal(result, expected) + + result = ser != ser2 + expected = Series(True, index=ser.index) + tm.assert_series_equal(result, expected) + + with pytest.raises(TypeError, match="Invalid comparison"): + ser > ser2