diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index 70039cc697b8a..cda4da9d76c42 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -604,6 +604,7 @@ Strings - Bug in :meth:`Series.str.find` when ``start < 0`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56411`) - Bug in :meth:`Series.str.replace` when ``n < 0`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56404`) - Bug in :meth:`Series.str.startswith` and :meth:`Series.str.endswith` with arguments of type ``tuple[str, ...]`` for ``string[pyarrow]`` (:issue:`54942`) +- Bug in comparison operations for ``dtype="string[pyarrow_numpy]"`` raising if dtypes can't be compared (:issue:`56008`) Interval ^^^^^^^^ diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 32ab3054c0f51..50cd052f80abd 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -41,6 +41,7 @@ BaseStringArray, StringDtype, ) +from pandas.core.ops import invalid_comparison from pandas.core.strings.object_array import ObjectStringArrayMixin if not pa_version_under10p1: @@ -662,7 +663,10 @@ def _convert_int_dtype(self, result): return result def _cmp_method(self, other, op): - result = super()._cmp_method(other, op) + try: + result = super()._cmp_method(other, op) + except pa.ArrowNotImplementedError: + return invalid_comparison(self, other, op) if op == operator.ne: return result.to_numpy(np.bool_, na_value=True) else: diff --git a/pandas/tests/series/test_logical_ops.py b/pandas/tests/series/test_logical_ops.py index 153b4bfaaf444..d9c94e871bd4b 100644 --- a/pandas/tests/series/test_logical_ops.py +++ b/pandas/tests/series/test_logical_ops.py @@ -530,3 +530,19 @@ def test_int_dtype_different_index_not_bool(self): result = ser1 ^ ser2 tm.assert_series_equal(result, expected) + + def test_pyarrow_numpy_string_invalid(self): + # GH#56008 + pytest.importorskip("pyarrow") + ser = Series([False, True]) + ser2 = Series(["a", "b"], dtype="string[pyarrow_numpy]") + result = ser == ser2 + expected = Series(False, index=ser.index) + tm.assert_series_equal(result, expected) + + result = ser != ser2 + expected = Series(True, index=ser.index) + tm.assert_series_equal(result, expected) + + with pytest.raises(TypeError, match="Invalid comparison"): + ser > ser2