diff --git a/doc/source/whatsnew/v2.1.2.rst b/doc/source/whatsnew/v2.1.2.rst index 28b271f9cf446..97a718dd496e9 100644 --- a/doc/source/whatsnew/v2.1.2.rst +++ b/doc/source/whatsnew/v2.1.2.rst @@ -31,6 +31,7 @@ Bug fixes - Fixed bug in :meth:`Series.all` and :meth:`Series.any` not treating missing values correctly for ``dtype="string[pyarrow_numpy]"`` (:issue:`55367`) - Fixed bug in :meth:`Series.floordiv` for :class:`ArrowDtype` (:issue:`55561`) - Fixed bug in :meth:`Series.rank` for ``string[pyarrow_numpy]`` dtype (:issue:`55362`) +- Fixed bug in :meth:`Series.str.extractall` for :class:`ArrowDtype` dtype being converted to object (:issue:`53846`) - Silence ``Period[B]`` warnings introduced by :issue:`53446` during normal plotting activity (:issue:`55138`) .. --------------------------------------------------------------------------- diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index 71d6f9c58e2c2..58b904fd31b6a 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -3449,10 +3449,9 @@ def _result_dtype(arr): # when the list of values is empty. from pandas.core.arrays.string_ import StringDtype - if isinstance(arr.dtype, StringDtype): + if isinstance(arr.dtype, (ArrowDtype, StringDtype)): return arr.dtype - else: - return object + return object def _get_single_group_name(regex: re.Pattern) -> Hashable: diff --git a/pandas/tests/strings/test_extract.py b/pandas/tests/strings/test_extract.py index 4773c13aad376..b8319e90e09a8 100644 --- a/pandas/tests/strings/test_extract.py +++ b/pandas/tests/strings/test_extract.py @@ -4,6 +4,8 @@ import numpy as np import pytest +from pandas.core.dtypes.dtypes import ArrowDtype + from pandas import ( DataFrame, Index, @@ -706,3 +708,12 @@ def test_extractall_same_as_extract_subject_index(any_string_dtype): has_match_index = s.str.extractall(pattern_one_noname) no_match_index = has_match_index.xs(0, level="match") tm.assert_frame_equal(extract_one_noname, no_match_index) + + +def test_extractall_preserves_dtype(): + # Ensure that when extractall is called on a series with specific dtypes set, that + # the dtype is preserved in the resulting DataFrame's column. + pa = pytest.importorskip("pyarrow") + + result = Series(["abc", "ab"], dtype=ArrowDtype(pa.string())).str.extractall("(ab)") + assert result.dtypes[0] == "string[pyarrow]"