Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Arrow String Array that is compatible with NumPy semantics #54533

Merged
merged 24 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,6 +1321,7 @@ def nullable_string_dtype(request):
params=[
"python",
pytest.param("pyarrow", marks=td.skip_if_no("pyarrow")),
pytest.param("pyarrow_numpy", marks=td.skip_if_no("pyarrow")),
]
)
def string_storage(request):
phofl marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -1380,6 +1381,7 @@ def object_dtype(request):
"object",
"string[python]",
pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")),
pytest.param("string[pyarrow_numpy]", marks=td.skip_if_no("pyarrow")),
]
)
def any_string_dtype(request):
Expand Down
77 changes: 77 additions & 0 deletions pandas/core/arrays/_arrow_string_mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from __future__ import annotations

from typing import Literal

import numpy as np

from pandas.compat import pa_version_under7p0

if not pa_version_under7p0:
import pyarrow as pa
import pyarrow.compute as pc


class ArrowStringArrayMixin:
def _str_pad(
self,
width: int,
side: Literal["left", "right", "both"] = "left",
fillchar: str = " ",
):
if side == "left":
pa_pad = pc.utf8_lpad
elif side == "right":
pa_pad = pc.utf8_rpad
elif side == "both":
pa_pad = pc.utf8_center
else:
raise ValueError(
f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'"
)
return type(self)(pa_pad(self._pa_array, width=width, padding=fillchar))

def _str_get(self, i: int):
lengths = pc.utf8_length(self._pa_array)
if i >= 0:
out_of_bounds = pc.greater_equal(i, lengths)
start = i
stop = i + 1
step = 1
else:
out_of_bounds = pc.greater(-i, lengths)
start = i
stop = i - 1
step = -1
not_out_of_bounds = pc.invert(out_of_bounds.fill_null(True))
selected = pc.utf8_slice_codeunits(
self._pa_array, start=start, stop=stop, step=step
)
null_value = pa.scalar(None, type=self._pa_array.type)
result = pc.if_else(not_out_of_bounds, selected, null_value)
return type(self)(result)

def _str_slice_replace(
self, start: int | None = None, stop: int | None = None, repl: str | None = None
):
if repl is None:
repl = ""
if start is None:
start = 0
if stop is None:
stop = np.iinfo(np.int64).max
return type(self)(pc.utf8_replace_slice(self._pa_array, start, stop, repl))

def _str_capitalize(self):
return type(self)(pc.utf8_capitalize(self._pa_array))

def _str_title(self):
return type(self)(pc.utf8_title(self._pa_array))

def _str_swapcase(self):
return type(self)(pc.utf8_swapcase(self._pa_array))

def _str_removesuffix(self, suffix: str):
ends_with = pc.ends_with(self._pa_array, pattern=suffix)
removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
result = pc.if_else(ends_with, removed, self._pa_array)
return type(self)(result)
79 changes: 15 additions & 64 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

from pandas.core import roperator
from pandas.core.arraylike import OpsMixin
from pandas.core.arrays._arrow_string_mixins import ArrowStringArrayMixin
from pandas.core.arrays.base import (
ExtensionArray,
ExtensionArraySupportsAnyAll,
Expand Down Expand Up @@ -184,7 +185,10 @@ def to_pyarrow_type(


class ArrowExtensionArray(
OpsMixin, ExtensionArraySupportsAnyAll, BaseStringArrayMethods
OpsMixin,
ExtensionArraySupportsAnyAll,
ArrowStringArrayMixin,
BaseStringArrayMethods,
):
"""
Pandas ExtensionArray backed by a PyArrow ChunkedArray.
Expand Down Expand Up @@ -246,6 +250,12 @@ def __init__(self, values: pa.Array | pa.ChunkedArray) -> None:
)
self._dtype = ArrowDtype(self._pa_array.type)

def __dir__(self):
o = set(dir(type(self)))
o.update(self.__dict__)
o.update(set(dir(ArrowStringArrayMixin)))
return list(o)

@classmethod
def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = False):
"""
Expand Down Expand Up @@ -502,7 +512,10 @@ def __getitem__(self, item: PositionalIndexer):
if isinstance(item, np.ndarray):
if not len(item):
# Removable once we migrate StringDtype[pyarrow] to ArrowDtype[string]
if self._dtype.name == "string" and self._dtype.storage == "pyarrow":
if self._dtype.name == "string" and self._dtype.storage in (
"pyarrow",
"pyarrow_numpy",
):
pa_dtype = pa.string()
else:
pa_dtype = self._dtype.pyarrow_dtype
Expand Down Expand Up @@ -1987,24 +2000,6 @@ def _str_count(self, pat: str, flags: int = 0):
raise NotImplementedError(f"count not implemented with {flags=}")
return type(self)(pc.count_substring_regex(self._pa_array, pat))

def _str_pad(
self,
width: int,
side: Literal["left", "right", "both"] = "left",
fillchar: str = " ",
):
if side == "left":
pa_pad = pc.utf8_lpad
elif side == "right":
pa_pad = pc.utf8_rpad
elif side == "both":
pa_pad = pc.utf8_center
else:
raise ValueError(
f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'"
)
return type(self)(pa_pad(self._pa_array, width=width, padding=fillchar))

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
):
Expand Down Expand Up @@ -2089,26 +2084,6 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None):
)
return type(self)(result)

def _str_get(self, i: int):
lengths = pc.utf8_length(self._pa_array)
if i >= 0:
out_of_bounds = pc.greater_equal(i, lengths)
start = i
stop = i + 1
step = 1
else:
out_of_bounds = pc.greater(-i, lengths)
start = i
stop = i - 1
step = -1
not_out_of_bounds = pc.invert(out_of_bounds.fill_null(True))
selected = pc.utf8_slice_codeunits(
self._pa_array, start=start, stop=stop, step=step
)
null_value = pa.scalar(None, type=self._pa_array.type)
result = pc.if_else(not_out_of_bounds, selected, null_value)
return type(self)(result)

def _str_join(self, sep: str):
if pa.types.is_string(self._pa_array.type):
result = self._apply_elementwise(list)
Expand Down Expand Up @@ -2138,15 +2113,6 @@ def _str_slice(
pc.utf8_slice_codeunits(self._pa_array, start=start, stop=stop, step=step)
)

def _str_slice_replace(
self, start: int | None = None, stop: int | None = None, repl: str | None = None
):
if repl is None:
repl = ""
if start is None:
start = 0
return type(self)(pc.utf8_replace_slice(self._pa_array, start, stop, repl))

def _str_isalnum(self):
return type(self)(pc.utf8_is_alnum(self._pa_array))

Expand All @@ -2171,18 +2137,9 @@ def _str_isspace(self):
def _str_istitle(self):
return type(self)(pc.utf8_is_title(self._pa_array))

def _str_capitalize(self):
return type(self)(pc.utf8_capitalize(self._pa_array))

def _str_title(self):
return type(self)(pc.utf8_title(self._pa_array))

def _str_isupper(self):
return type(self)(pc.utf8_is_upper(self._pa_array))

def _str_swapcase(self):
return type(self)(pc.utf8_swapcase(self._pa_array))

def _str_len(self):
return type(self)(pc.utf8_length(self._pa_array))

Expand Down Expand Up @@ -2223,12 +2180,6 @@ def _str_removeprefix(self, prefix: str):
result = self._apply_elementwise(predicate)
return type(self)(pa.chunked_array(result))

def _str_removesuffix(self, suffix: str):
ends_with = pc.ends_with(self._pa_array, pattern=suffix)
removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
result = pc.if_else(ends_with, removed, self._pa_array)
return type(self)(result)

def _str_casefold(self):
predicate = lambda val: val.casefold()
result = self._apply_elementwise(predicate)
Expand Down
21 changes: 16 additions & 5 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class StringDtype(StorageExtensionDtype):

Parameters
----------
storage : {"python", "pyarrow"}, optional
storage : {"python", "pyarrow", "pyarrow_numpy"}, optional
If not given, the value of ``pd.options.mode.string_storage``.

Attributes
Expand Down Expand Up @@ -108,11 +108,11 @@ def na_value(self) -> libmissing.NAType:
def __init__(self, storage=None) -> None:
if storage is None:
storage = get_option("mode.string_storage")
if storage not in {"python", "pyarrow"}:
if storage not in {"python", "pyarrow", "pyarrow_numpy"}:
raise ValueError(
f"Storage must be 'python' or 'pyarrow'. Got {storage} instead."
)
if storage == "pyarrow" and pa_version_under7p0:
if storage in ("pyarrow", "pyarrow_numpy") and pa_version_under7p0:
raise ImportError(
"pyarrow>=7.0.0 is required for PyArrow backed StringArray."
)
Expand Down Expand Up @@ -160,6 +160,8 @@ def construct_from_string(cls, string):
return cls(storage="python")
elif string == "string[pyarrow]":
return cls(storage="pyarrow")
elif string == "string[pyarrow_numpy]":
return cls(storage="pyarrow_numpy")
else:
raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")

Expand All @@ -176,12 +178,17 @@ def construct_array_type( # type: ignore[override]
-------
type
"""
from pandas.core.arrays.string_arrow import ArrowStringArray
from pandas.core.arrays.string_arrow import (
ArrowStringArray,
ArrowStringArrayNumpySemantics,
)

if self.storage == "python":
return StringArray
else:
elif self.storage == "pyarrow":
return ArrowStringArray
else:
return ArrowStringArrayNumpySemantics

def __from_arrow__(
self, array: pyarrow.Array | pyarrow.ChunkedArray
Expand All @@ -193,6 +200,10 @@ def __from_arrow__(
from pandas.core.arrays.string_arrow import ArrowStringArray

return ArrowStringArray(array)
elif self.storage == "pyarrow_numpy":
from pandas.core.arrays.string_arrow import ArrowStringArrayNumpySemantics

return ArrowStringArrayNumpySemantics(array)
else:
import pyarrow

Expand Down
Loading