Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Ensure "string[pyarrow]" type is preserved when calling extractall #55534

1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.1.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Bug fixes
- Fixed bug in :meth:`Index.insert` raising when inserting ``None`` into :class:`Index` with ``dtype="string[pyarrow_numpy]"`` (:issue:`55365`)
- Fixed bug in :meth:`Series.all` and :meth:`Series.any` not treating missing values correctly for ``dtype="string[pyarrow_numpy]"`` (:issue:`55367`)
- Fixed bug in :meth:`Series.rank` for ``string[pyarrow_numpy]`` dtype (:issue:`55362`)
- Fixed bug in :meth:`Series.str.extractall` for :class:`ArrowDtype` dtype being converted to object (:issue:`53846`)
- Silence ``Period[B]`` warnings introduced by :issue:`53446` during normal plotting activity (:issue:`55138`)
-

Expand Down
5 changes: 2 additions & 3 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3449,10 +3449,9 @@ def _result_dtype(arr):
# when the list of values is empty.
from pandas.core.arrays.string_ import StringDtype

if isinstance(arr.dtype, StringDtype):
if isinstance(arr.dtype, (ArrowDtype, StringDtype)):
return arr.dtype
else:
return object
return object


def _get_single_group_name(regex: re.Pattern) -> Hashable:
Expand Down
17 changes: 17 additions & 0 deletions pandas/tests/strings/test_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import re

import numpy as np
import pyarrow as pa
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'll need to import this in test_extractall_preserves_dtype

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mroeschke thanks, moved the import. I also added the requirement to the pypy and numpydev actions yaml files but some tests are still failing, so before making any additional changes, I thought I'd ask what else might need to be changed so ci can run. Looks like the meta.yaml file needs it added as a requirement as well?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of adding pyarrow to the dependency files you can structure the test like

def test_whatever():
    pa = pytest.importorskip("pyarrow")
    series = ...

So the test will be skipped if pyarrow is not installed or you will have pyarrow accessible as pa

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, fair enough 👍 Thank you. There's still one spec failing but it's also failing in main, it's due to a Numpy deprecation warning as far as I can tell.

import pytest

from pandas import (
Expand All @@ -11,6 +12,7 @@
Series,
_testing as tm,
)
from pandas.core.dtypes.dtypes import ArrowDtype


def test_extract_expand_kwarg_wrong_type_raises(any_string_dtype):
Expand Down Expand Up @@ -706,3 +708,18 @@ def test_extractall_same_as_extract_subject_index(any_string_dtype):
has_match_index = s.str.extractall(pattern_one_noname)
no_match_index = has_match_index.xs(0, level="match")
tm.assert_frame_equal(extract_one_noname, no_match_index)


@pytest.mark.parametrize(
"data, expected_dtype",
[
(Series(["abc", "ab"], dtype=ArrowDtype(pa.string())), "string[pyarrow]"),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just test this case since I think the other cases are tested

(Series(["abc", "ab"], dtype="string"), "string[python]"),
(Series(["abc", "ab"]), "object"),
]
)
def test_extractall_preserves_dtype(data, expected_dtype):
# Ensure that when extractall is called on a series with specific dtypes set, that
# the dtype is preserved in the resulting DataFrame's column.
result = data.str.extractall("(ab)")
assert result.dtypes[0] == expected_dtype
Loading