From 0c17c961b18da868fe14a096f046be8f55cd72b3 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Sun, 15 Oct 2023 20:54:22 +0200 Subject: [PATCH] Backport PR #55384 on branch 2.1.x (BUG: idxmax raising for arrow strings) (#55531) BUG: idxmax raising for arrow strings (#55384) (cherry picked from commit 68e3c4b2f855e6e9a8469aeca6eb73ae60327160) --- pandas/core/arrays/arrow/array.py | 11 ++++++++++- pandas/core/arrays/string_arrow.py | 11 +++++++++++ pandas/tests/frame/test_reductions.py | 9 +++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 85351945cf29c..636a22fcffe3d 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1596,6 +1596,15 @@ def _reduce( ------ TypeError : subclass does not define reductions """ + result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs) + if isinstance(result, pa.Array): + return type(self)(result) + else: + return result + + def _reduce_calc( + self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs + ): pa_result = self._reduce_pyarrow(name, skipna=skipna, **kwargs) if keepdims: @@ -1606,7 +1615,7 @@ def _reduce( [pa_result], type=to_pyarrow_type(infer_dtype_from_scalar(pa_result)[0]), ) - return type(self)(result) + return result if pc.is_null(pa_result).as_py(): return self.dtype.na_value diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 9b3245183bd57..2eef240af53f8 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -445,6 +445,17 @@ def _str_rstrip(self, to_strip=None): result = pc.utf8_rtrim(self._pa_array, characters=to_strip) return type(self)(result) + def _reduce( + self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs + ): + result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs) + if name in ("argmin", "argmax") and isinstance(result, pa.Array): + return self._convert_int_dtype(result) + elif isinstance(result, pa.Array): + return type(self)(result) + else: + return result + def _convert_int_dtype(self, result): return Int64Dtype().__from_arrow__(result) diff --git a/pandas/tests/frame/test_reductions.py b/pandas/tests/frame/test_reductions.py index 4c83c810e8cec..0b501903ad71f 100644 --- a/pandas/tests/frame/test_reductions.py +++ b/pandas/tests/frame/test_reductions.py @@ -1073,6 +1073,15 @@ def test_idxmax_arrow_types(self): expected = Series([2, 1], index=["a", "b"]) tm.assert_series_equal(result, expected) + df = DataFrame({"a": ["b", "c", "a"]}, dtype="string[pyarrow]") + result = df.idxmax(numeric_only=False) + expected = Series([1], index=["a"]) + tm.assert_series_equal(result, expected) + + result = df.idxmin(numeric_only=False) + expected = Series([2], index=["a"]) + tm.assert_series_equal(result, expected) + def test_idxmax_axis_2(self, float_frame): frame = float_frame msg = "No axis named 2 for object type DataFrame"