Skip to content

Commit

Permalink
BUG: Ensure "string[pyarrow]" type is preserved when calling extracta…
Browse files Browse the repository at this point in the history
…ll (#55534)

* Ensure "string[pyarrow]" type is preserved when calling extractall

* Add whatsnew note

* Add test and fix whatsnew entry

* Fix test case and move import

* Add pyarrow requirement to actions

* Add pyarrow requirement as importorskip
  • Loading branch information
ABizzinotto authored Oct 19, 2023
1 parent 192aec7 commit c9dc91d
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 3 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.1.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Bug fixes
- Fixed bug in :meth:`Series.all` and :meth:`Series.any` not treating missing values correctly for ``dtype="string[pyarrow_numpy]"`` (:issue:`55367`)
- Fixed bug in :meth:`Series.floordiv` for :class:`ArrowDtype` (:issue:`55561`)
- Fixed bug in :meth:`Series.rank` for ``string[pyarrow_numpy]`` dtype (:issue:`55362`)
- Fixed bug in :meth:`Series.str.extractall` for :class:`ArrowDtype` dtype being converted to object (:issue:`53846`)
- Silence ``Period[B]`` warnings introduced by :issue:`53446` during normal plotting activity (:issue:`55138`)

.. ---------------------------------------------------------------------------
Expand Down
5 changes: 2 additions & 3 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3449,10 +3449,9 @@ def _result_dtype(arr):
# when the list of values is empty.
from pandas.core.arrays.string_ import StringDtype

if isinstance(arr.dtype, StringDtype):
if isinstance(arr.dtype, (ArrowDtype, StringDtype)):
return arr.dtype
else:
return object
return object


def _get_single_group_name(regex: re.Pattern) -> Hashable:
Expand Down
11 changes: 11 additions & 0 deletions pandas/tests/strings/test_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import numpy as np
import pytest

from pandas.core.dtypes.dtypes import ArrowDtype

from pandas import (
DataFrame,
Index,
Expand Down Expand Up @@ -706,3 +708,12 @@ def test_extractall_same_as_extract_subject_index(any_string_dtype):
has_match_index = s.str.extractall(pattern_one_noname)
no_match_index = has_match_index.xs(0, level="match")
tm.assert_frame_equal(extract_one_noname, no_match_index)


def test_extractall_preserves_dtype():
# Ensure that when extractall is called on a series with specific dtypes set, that
# the dtype is preserved in the resulting DataFrame's column.
pa = pytest.importorskip("pyarrow")

result = Series(["abc", "ab"], dtype=ArrowDtype(pa.string())).str.extractall("(ab)")
assert result.dtypes[0] == "string[pyarrow]"

0 comments on commit c9dc91d

Please sign in to comment.