Skip to content

Commit

Permalink
find offset fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Rohan Jain committed Dec 8, 2023
1 parent 46c8da3 commit 5c901a3
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 1 deletion.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@ Strings
- Bug in :func:`pandas.api.types.is_string_dtype` while checking object array with no elements is of the string dtype (:issue:`54661`)
- Bug in :meth:`DataFrame.apply` failing when ``engine="numba"`` and columns or index have ``StringDtype`` (:issue:`56189`)
- Bug in :meth:`Series.__mul__` for :class:`ArrowDtype` with ``pyarrow.string`` dtype and ``string[pyarrow]`` for the pyarrow backend (:issue:`51970`)
- Bug in :meth:`Series.str.find` when ``start < 0`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56411`)
- Bug in :meth:`Series.str.replace` when ``n < 0`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56404`)
- Bug in :meth:`Series.str.startswith` and :meth:`Series.str.endswith` with arguments of type ``tuple[str, ...]`` for ``string[pyarrow]`` (:issue:`54942`)

Expand Down
3 changes: 2 additions & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2193,7 +2193,8 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None):
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
result = pc.find_substring(slices, sub)
not_found = pc.equal(result, -1)
offset_result = pc.add(result, end - start)
start_offset = max(0, start)
offset_result = pc.add(result, start_offset)
result = pc.if_else(not_found, result, offset_result)
elif start == 0 and end is None:
slices = self._pa_array
Expand Down
8 changes: 8 additions & 0 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1844,6 +1844,14 @@ def test_str_find(sub, start, end, exp, exp_typ):
tm.assert_series_equal(result, expected)


def test_str_find_negative_start():
# GH 56411
ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
result = ser.str.find(sub="b", start=-1000, end=3)
expected = pd.Series([1, None], dtype=ArrowDtype(pa.int64()))
tm.assert_series_equal(result, expected)


def test_str_find_notimplemented():
ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
with pytest.raises(NotImplementedError, match="find not implemented"):
Expand Down

0 comments on commit 5c901a3

Please sign in to comment.