Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

REF: Move methods that can be shared with new string dtype #54534

Merged
merged 2 commits into from
Aug 13, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions pandas/core/arrays/_arrow_string_mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from __future__ import annotations

from typing import Literal

import numpy as np

from pandas.compat import pa_version_under7p0

if not pa_version_under7p0:
import pyarrow as pa
import pyarrow.compute as pc


class ArrowStringArrayMixin:
    """
    String-method implementations built on ``pyarrow.compute`` kernels.

    Every method reads ``self._pa_array`` and returns ``type(self)(result)``,
    so the host class must expose that attribute and accept a pyarrow array
    as its sole constructor argument.
    """

    def _str_pad(
        self,
        width: int,
        side: Literal["left", "right", "both"] = "left",
        fillchar: str = " ",
    ):
        """
        Pad each string to ``width`` using ``fillchar``.

        Parameters
        ----------
        width : int
            Forwarded to the pyarrow kernel as ``width``.
        side : {"left", "right", "both"}, default "left"
            Selects ``utf8_lpad``, ``utf8_rpad`` or ``utf8_center``.
        fillchar : str, default " "
            Forwarded to the pyarrow kernel as ``padding``.

        Raises
        ------
        ValueError
            If ``side`` is not one of the three accepted values.
        """
        if side == "left":
            pa_pad = pc.utf8_lpad
        elif side == "right":
            pa_pad = pc.utf8_rpad
        elif side == "both":
            pa_pad = pc.utf8_center
        else:
            raise ValueError(
                f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'"
            )
        return type(self)(pa_pad(self._pa_array, width=width, padding=fillchar))

    def _str_get(self, i: int):
        """
        Extract the element at position ``i`` from each string.

        Positions outside a string, and null inputs, yield null in the result.
        Negative ``i`` is handled by slicing backwards with ``step=-1``.
        """
        lengths = pc.utf8_length(self._pa_array)
        if i >= 0:
            out_of_bounds = pc.greater_equal(i, lengths)
            start = i
            stop = i + 1
            step = 1
        else:
            out_of_bounds = pc.greater(-i, lengths)
            start = i
            stop = i - 1
            step = -1
        # A null length (null input) is treated as out of bounds so that the
        # mask below is never null itself.
        not_out_of_bounds = pc.invert(out_of_bounds.fill_null(True))
        selected = pc.utf8_slice_codeunits(
            self._pa_array, start=start, stop=stop, step=step
        )
        null_value = pa.scalar(None, type=self._pa_array.type)
        result = pc.if_else(not_out_of_bounds, selected, null_value)
        return type(self)(result)

    def _str_slice_replace(
        self, start: int | None = None, stop: int | None = None, repl: str | None = None
    ):
        """
        Replace the slice ``[start, stop)`` of each string with ``repl``.

        ``None`` arguments are replaced by concrete defaults before calling
        ``pc.utf8_replace_slice``: ``repl`` becomes ``""``, ``start`` becomes
        ``0`` and ``stop`` becomes the largest int64 value.
        """
        if repl is None:
            repl = ""
        if start is None:
            start = 0
        if stop is None:
            # Presumably "through the end of every string": the kernel is
            # handed the largest representable int64 as the stop position.
            stop = np.iinfo(np.int64).max
        return type(self)(pc.utf8_replace_slice(self._pa_array, start, stop, repl))

    def _str_capitalize(self):
        """Apply ``pc.utf8_capitalize`` to the backing array."""
        return type(self)(pc.utf8_capitalize(self._pa_array))

    def _str_title(self):
        """Apply ``pc.utf8_title`` to the backing array."""
        return type(self)(pc.utf8_title(self._pa_array))

    def _str_swapcase(self):
        """Apply ``pc.utf8_swapcase`` to the backing array."""
        return type(self)(pc.utf8_swapcase(self._pa_array))

    def _str_removesuffix(self, suffix: str):
        """
        Strip ``suffix`` from the end of each string that ends with it.

        Strings that do not end with ``suffix`` are returned unchanged, and an
        empty ``suffix`` leaves every string unchanged (matching
        ``str.removesuffix``).
        """
        if not suffix:
            # ``stop=-len(suffix)`` would be ``stop=0`` here, slicing every
            # selected value down to an empty string. Nothing needs removing,
            # so wrap the data as-is.
            return type(self)(self._pa_array)
        ends_with = pc.ends_with(self._pa_array, pattern=suffix)
        removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
        result = pc.if_else(ends_with, removed, self._pa_array)
        return type(self)(result)
68 changes: 5 additions & 63 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

from pandas.core import roperator
from pandas.core.arraylike import OpsMixin
from pandas.core.arrays._arrow_string_mixins import ArrowStringArrayMixin
from pandas.core.arrays.base import (
ExtensionArray,
ExtensionArraySupportsAnyAll,
Expand Down Expand Up @@ -184,7 +185,10 @@ def to_pyarrow_type(


class ArrowExtensionArray(
OpsMixin, ExtensionArraySupportsAnyAll, BaseStringArrayMethods
OpsMixin,
ExtensionArraySupportsAnyAll,
ArrowStringArrayMixin,
BaseStringArrayMethods,
):
"""
Pandas ExtensionArray backed by a PyArrow ChunkedArray.
Expand Down Expand Up @@ -1987,24 +1991,6 @@ def _str_count(self, pat: str, flags: int = 0):
raise NotImplementedError(f"count not implemented with {flags=}")
return type(self)(pc.count_substring_regex(self._pa_array, pat))

def _str_pad(
self,
width: int,
side: Literal["left", "right", "both"] = "left",
fillchar: str = " ",
):
if side == "left":
pa_pad = pc.utf8_lpad
elif side == "right":
pa_pad = pc.utf8_rpad
elif side == "both":
pa_pad = pc.utf8_center
else:
raise ValueError(
f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'"
)
return type(self)(pa_pad(self._pa_array, width=width, padding=fillchar))

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
):
Expand Down Expand Up @@ -2089,26 +2075,6 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None):
)
return type(self)(result)

def _str_get(self, i: int):
    """
    Extract the element at position ``i`` from each string.

    Positions outside a string, and null inputs, yield null in the result.
    """
    data = self._pa_array
    n_chars = pc.utf8_length(data)
    forward = i >= 0
    if forward:
        too_far = pc.greater_equal(i, n_chars)
    else:
        too_far = pc.greater(-i, n_chars)
    # Negative positions are sliced backwards: start at i, stop one before it.
    step = 1 if forward else -1
    # Null lengths count as out of bounds so the mask never contains nulls.
    in_range = pc.invert(too_far.fill_null(True))
    picked = pc.utf8_slice_codeunits(data, start=i, stop=i + step, step=step)
    missing = pa.scalar(None, type=data.type)
    return type(self)(pc.if_else(in_range, picked, missing))

def _str_join(self, sep: str):
if pa.types.is_string(self._pa_array.type):
result = self._apply_elementwise(list)
Expand Down Expand Up @@ -2138,15 +2104,6 @@ def _str_slice(
pc.utf8_slice_codeunits(self._pa_array, start=start, stop=stop, step=step)
)

def _str_slice_replace(
    self, start: int | None = None, stop: int | None = None, repl: str | None = None
):
    """
    Replace the slice ``[start, stop)`` of each string with ``repl``.

    ``None`` arguments are replaced by concrete defaults before calling
    ``pc.utf8_replace_slice``: ``repl`` becomes ``""``, ``start`` becomes
    ``0`` and ``stop`` becomes the largest int64 value.
    """
    if repl is None:
        repl = ""
    if start is None:
        start = 0
    if stop is None:
        # Do not forward None to the kernel; substitute the largest int64
        # (equal to np.iinfo(np.int64).max), presumably reaching the end of
        # every string.
        stop = 2**63 - 1
    return type(self)(pc.utf8_replace_slice(self._pa_array, start, stop, repl))

def _str_isalnum(self):
    """Apply ``pc.utf8_is_alnum`` to the backing array and re-wrap the result."""
    flags = pc.utf8_is_alnum(self._pa_array)
    return type(self)(flags)

Expand All @@ -2171,18 +2128,9 @@ def _str_isspace(self):
def _str_istitle(self):
    """Apply ``pc.utf8_is_title`` to the backing array and re-wrap the result."""
    flags = pc.utf8_is_title(self._pa_array)
    return type(self)(flags)

def _str_capitalize(self):
    """Apply ``pc.utf8_capitalize`` to the backing array and re-wrap the result."""
    converted = pc.utf8_capitalize(self._pa_array)
    return type(self)(converted)

def _str_title(self):
    """Apply ``pc.utf8_title`` to the backing array and re-wrap the result."""
    converted = pc.utf8_title(self._pa_array)
    return type(self)(converted)

def _str_isupper(self):
    """Apply ``pc.utf8_is_upper`` to the backing array and re-wrap the result."""
    flags = pc.utf8_is_upper(self._pa_array)
    return type(self)(flags)

def _str_swapcase(self):
    """Apply ``pc.utf8_swapcase`` to the backing array and re-wrap the result."""
    converted = pc.utf8_swapcase(self._pa_array)
    return type(self)(converted)

def _str_len(self):
    """Apply ``pc.utf8_length`` to the backing array and re-wrap the result."""
    sizes = pc.utf8_length(self._pa_array)
    return type(self)(sizes)

Expand Down Expand Up @@ -2223,12 +2171,6 @@ def _str_removeprefix(self, prefix: str):
result = self._apply_elementwise(predicate)
return type(self)(pa.chunked_array(result))

def _str_removesuffix(self, suffix: str):
ends_with = pc.ends_with(self._pa_array, pattern=suffix)
removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
result = pc.if_else(ends_with, removed, self._pa_array)
return type(self)(result)

def _str_casefold(self):
predicate = lambda val: val.casefold()
result = self._apply_elementwise(predicate)
Expand Down