Skip to content

Commit

Permalink
ENH: Implement str.extract for ArrowDtype (#56334)
Browse files Browse the repository at this point in the history
  • Loading branch information
mroeschke authored Dec 5, 2023
1 parent 2718b4e commit df7498f
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 8 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ Other enhancements
- Allow passing ``read_only``, ``data_only`` and ``keep_links`` arguments to openpyxl using ``engine_kwargs`` of :func:`read_excel` (:issue:`55027`)
- DataFrame.apply now allows the usage of numba (via ``engine="numba"``) to JIT compile the passed function, allowing for potential speedups (:issue:`54666`)
- Implement masked algorithms for :meth:`Series.value_counts` (:issue:`54984`)
- Implemented :meth:`Series.str.extract` for :class:`ArrowDtype` (:issue:`56268`)
- Improved error message that appears in :meth:`DatetimeIndex.to_period` with frequencies which are not supported as period frequencies, such as "BMS" (:issue:`56243`)
- Improved error message when constructing :class:`Period` with invalid offsets such as "QS" (:issue:`55785`)

Expand Down
16 changes: 13 additions & 3 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2297,9 +2297,19 @@ def _str_encode(self, encoding: str, errors: str = "strict"):
return type(self)(pa.chunked_array(result))

def _str_extract(self, pat: str, flags: int = 0, expand: bool = True):
raise NotImplementedError(
"str.extract not supported with pd.ArrowDtype(pa.string())."
)
if flags:
raise NotImplementedError("Only flags=0 is implemented.")
groups = re.compile(pat).groupindex.keys()
if len(groups) == 0:
raise ValueError(f"{pat=} must contain a symbolic group name.")
result = pc.extract_regex(self._pa_array, pat)
if expand:
return {
col: type(self)(pc.struct_field(result, [i]))
for col, i in zip(groups, range(result.type.num_fields))
}
else:
return type(self)(pc.struct_field(result, [0]))

def _str_findall(self, pat: str, flags: int = 0):
regex = re.compile(pat, flags=flags)
Expand Down
36 changes: 31 additions & 5 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2159,14 +2159,40 @@ def test_str_rsplit():
tm.assert_frame_equal(result, expected)


def test_str_unsupported_extract():
ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
with pytest.raises(
NotImplementedError, match="str.extract not supported with pd.ArrowDtype"
):
def test_str_extract_non_symbolic():
ser = pd.Series(["a1", "b2", "c3"], dtype=ArrowDtype(pa.string()))
with pytest.raises(ValueError, match="pat=.* must contain a symbolic group name."):
ser.str.extract(r"[ab](\d)")


@pytest.mark.parametrize("expand", [True, False])
def test_str_extract(expand):
ser = pd.Series(["a1", "b2", "c3"], dtype=ArrowDtype(pa.string()))
result = ser.str.extract(r"(?P<letter>[ab])(?P<digit>\d)", expand=expand)
expected = pd.DataFrame(
{
"letter": ArrowExtensionArray(pa.array(["a", "b", None])),
"digit": ArrowExtensionArray(pa.array(["1", "2", None])),
}
)
tm.assert_frame_equal(result, expected)


def test_str_extract_expand():
ser = pd.Series(["a1", "b2", "c3"], dtype=ArrowDtype(pa.string()))
result = ser.str.extract(r"[ab](?P<digit>\d)", expand=True)
expected = pd.DataFrame(
{
"digit": ArrowExtensionArray(pa.array(["1", "2", None])),
}
)
tm.assert_frame_equal(result, expected)

result = ser.str.extract(r"[ab](?P<digit>\d)", expand=False)
expected = pd.Series(ArrowExtensionArray(pa.array(["1", "2", None])), name="digit")
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("unit", ["ns", "us", "ms", "s"])
def test_duration_from_strings_with_nat(unit):
# GH51175
Expand Down

0 comments on commit df7498f

Please sign in to comment.