From 702449db1331960857741d2169650d9b16a471df Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Mon, 4 Dec 2023 15:25:12 -0800 Subject: [PATCH] BUG: str.split for ArrowDtype with pat=None --- doc/source/whatsnew/v2.1.4.rst | 3 ++- pandas/core/arrays/arrow/array.py | 22 +++++++++++++++------- pandas/tests/extension/test_arrow.py | 9 +++++++++ 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/doc/source/whatsnew/v2.1.4.rst b/doc/source/whatsnew/v2.1.4.rst index 4ef6a2463ee16..927c0ee4f532d 100644 --- a/doc/source/whatsnew/v2.1.4.rst +++ b/doc/source/whatsnew/v2.1.4.rst @@ -24,11 +24,12 @@ Bug fixes - Bug in :class:`Series` constructor raising DeprecationWarning when ``index`` is a list of :class:`Series` (:issue:`55228`) - Bug in :meth:`Index.__getitem__` returning wrong result for Arrow dtypes and negative stepsize (:issue:`55832`) - Fixed bug in :func:`to_numeric` converting to extension dtype for ``string[pyarrow_numpy]`` dtype (:issue:`56179`) -- Fixed bug in :meth:`.DataFrameGroupBy.min()` and :meth:`.DataFrameGroupBy.max()` not preserving extension dtype for empty object (:issue:`55619`) +- Fixed bug in :meth:`.DataFrameGroupBy.min` and :meth:`.DataFrameGroupBy.max` not preserving extension dtype for empty object (:issue:`55619`) - Fixed bug in :meth:`DataFrame.__setitem__` casting :class:`Index` with object-dtype to PyArrow backed strings when ``infer_string`` option is set (:issue:`55638`) - Fixed bug in :meth:`DataFrame.to_hdf` raising when columns have ``StringDtype`` (:issue:`55088`) - Fixed bug in :meth:`Index.insert` casting object-dtype to PyArrow backed strings when ``infer_string`` option is set (:issue:`55638`) - Fixed bug in :meth:`Series.mode` not keeping object dtype when ``infer_string`` is set (:issue:`56183`) +- Fixed bug in :meth:`Series.str.split` and :meth:`Series.str.rsplit` when ``pat=None`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56271`) - Fixed bug in :meth:`Series.str.translate` losing object dtype when string option is set (:issue:`56152`) .. --------------------------------------------------------------------------- diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index d162b66e5d369..ee950ed463893 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1,5 +1,6 @@ from __future__ import annotations +import functools import operator import re import textwrap @@ -2351,18 +2352,25 @@ def _str_split( ): if n in {-1, 0}: n = None - if regex: - split_func = pc.split_pattern_regex + if pat is None: + split_func = pc.utf8_split_whitespace + elif regex: + split_func = functools.partial(pc.split_pattern_regex, pattern=pat) else: - split_func = pc.split_pattern - return type(self)(split_func(self._pa_array, pat, max_splits=n)) + split_func = functools.partial(pc.split_pattern, pattern=pat) + return type(self)(split_func(self._pa_array, max_splits=n)) def _str_rsplit(self, pat: str | None = None, n: int | None = -1): if n in {-1, 0}: n = None - return type(self)( - pc.split_pattern(self._pa_array, pat, max_splits=n, reverse=True) - ) + if pat is None: + return type(self)( + pc.utf8_split_whitespace(self._pa_array, max_splits=n, reverse=True) + ) + else: + return type(self)( + pc.split_pattern(self._pa_array, pat, max_splits=n, reverse=True) + ) def _str_translate(self, table: dict[int, str]): predicate = lambda val: val.translate(table) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 7131a50956a7d..b1f989802bf59 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -2076,6 +2076,15 @@ def test_str_partition(): tm.assert_series_equal(result, expected) +@pytest.mark.parametrize("method", ["rsplit", "split"]) +def test_str_split_pat_none(method): + # GH 56271 + ser = pd.Series(["a1 cbc\nb", None], dtype=ArrowDtype(pa.string())) + result = getattr(ser.str, method)() + expected = pd.Series(ArrowExtensionArray(pa.array([["a1", "cbc", "b"], None]))) + tm.assert_series_equal(result, expected) + + def test_str_split(): # GH 52401 ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string()))