Skip to content

Commit

Permalink
Backport PR #54533 on branch 2.1.x (Implement Arrow String Array that…
Browse files Browse the repository at this point in the history
… is compatible with NumPy semantics) (#54713)

Backport PR #54533: Implement Arrow String Array that is compatible with NumPy semantics

Co-authored-by: Patrick Hoefler <[email protected]>
  • Loading branch information
meeseeksmachine and phofl authored Aug 23, 2023
1 parent 5c9b63c commit 2ad36cc
Show file tree
Hide file tree
Showing 14 changed files with 273 additions and 53 deletions.
5 changes: 4 additions & 1 deletion pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,6 +1321,7 @@ def nullable_string_dtype(request):
params=[
"python",
pytest.param("pyarrow", marks=td.skip_if_no("pyarrow")),
pytest.param("pyarrow_numpy", marks=td.skip_if_no("pyarrow")),
]
)
def string_storage(request):
Expand All @@ -1329,6 +1330,7 @@ def string_storage(request):
* 'python'
* 'pyarrow'
* 'pyarrow_numpy'
"""
return request.param

Expand Down Expand Up @@ -1380,6 +1382,7 @@ def object_dtype(request):
"object",
"string[python]",
pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")),
pytest.param("string[pyarrow_numpy]", marks=td.skip_if_no("pyarrow")),
]
)
def any_string_dtype(request):
Expand Down Expand Up @@ -2000,4 +2003,4 @@ def warsaw(request) -> str:

@pytest.fixture()
def arrow_string_storage():
return ("pyarrow",)
return ("pyarrow", "pyarrow_numpy")
5 changes: 4 additions & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,10 @@ def __getitem__(self, item: PositionalIndexer):
if isinstance(item, np.ndarray):
if not len(item):
# Removable once we migrate StringDtype[pyarrow] to ArrowDtype[string]
if self._dtype.name == "string" and self._dtype.storage == "pyarrow":
if self._dtype.name == "string" and self._dtype.storage in (
"pyarrow",
"pyarrow_numpy",
):
pa_dtype = pa.string()
else:
pa_dtype = self._dtype.pyarrow_dtype
Expand Down
21 changes: 16 additions & 5 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class StringDtype(StorageExtensionDtype):
Parameters
----------
storage : {"python", "pyarrow"}, optional
storage : {"python", "pyarrow", "pyarrow_numpy"}, optional
If not given, the value of ``pd.options.mode.string_storage``.
Attributes
Expand Down Expand Up @@ -108,11 +108,11 @@ def na_value(self) -> libmissing.NAType:
def __init__(self, storage=None) -> None:
if storage is None:
storage = get_option("mode.string_storage")
if storage not in {"python", "pyarrow"}:
if storage not in {"python", "pyarrow", "pyarrow_numpy"}:
raise ValueError(
f"Storage must be 'python' or 'pyarrow'. Got {storage} instead."
)
if storage == "pyarrow" and pa_version_under7p0:
if storage in ("pyarrow", "pyarrow_numpy") and pa_version_under7p0:
raise ImportError(
"pyarrow>=7.0.0 is required for PyArrow backed StringArray."
)
Expand Down Expand Up @@ -160,6 +160,8 @@ def construct_from_string(cls, string):
return cls(storage="python")
elif string == "string[pyarrow]":
return cls(storage="pyarrow")
elif string == "string[pyarrow_numpy]":
return cls(storage="pyarrow_numpy")
else:
raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")

Expand All @@ -176,12 +178,17 @@ def construct_array_type( # type: ignore[override]
-------
type
"""
from pandas.core.arrays.string_arrow import ArrowStringArray
from pandas.core.arrays.string_arrow import (
ArrowStringArray,
ArrowStringArrayNumpySemantics,
)

if self.storage == "python":
return StringArray
else:
elif self.storage == "pyarrow":
return ArrowStringArray
else:
return ArrowStringArrayNumpySemantics

def __from_arrow__(
self, array: pyarrow.Array | pyarrow.ChunkedArray
Expand All @@ -193,6 +200,10 @@ def __from_arrow__(
from pandas.core.arrays.string_arrow import ArrowStringArray

return ArrowStringArray(array)
elif self.storage == "pyarrow_numpy":
from pandas.core.arrays.string_arrow import ArrowStringArrayNumpySemantics

return ArrowStringArrayNumpySemantics(array)
else:
import pyarrow

Expand Down
149 changes: 135 additions & 14 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import partial
import re
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -27,6 +28,7 @@
)
from pandas.core.dtypes.missing import isna

from pandas.core.arrays._arrow_string_mixins import ArrowStringArrayMixin
from pandas.core.arrays.arrow import ArrowExtensionArray
from pandas.core.arrays.boolean import BooleanDtype
from pandas.core.arrays.integer import Int64Dtype
Expand Down Expand Up @@ -113,10 +115,11 @@ class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringAr
# error: Incompatible types in assignment (expression has type "StringDtype",
# base class "ArrowExtensionArray" defined the type as "ArrowDtype")
_dtype: StringDtype # type: ignore[assignment]
_storage = "pyarrow"

def __init__(self, values) -> None:
super().__init__(values)
self._dtype = StringDtype(storage="pyarrow")
self._dtype = StringDtype(storage=self._storage)

if not pa.types.is_string(self._pa_array.type) and not (
pa.types.is_dictionary(self._pa_array.type)
Expand Down Expand Up @@ -144,7 +147,10 @@ def _from_sequence(cls, scalars, dtype: Dtype | None = None, copy: bool = False)

if dtype and not (isinstance(dtype, str) and dtype == "string"):
dtype = pandas_dtype(dtype)
assert isinstance(dtype, StringDtype) and dtype.storage == "pyarrow"
assert isinstance(dtype, StringDtype) and dtype.storage in (
"pyarrow",
"pyarrow_numpy",
)

if isinstance(scalars, BaseMaskedArray):
# avoid costly conversion to object dtype in ensure_string_array and
Expand Down Expand Up @@ -178,6 +184,10 @@ def insert(self, loc: int, item) -> ArrowStringArray:
raise TypeError("Scalar must be NA or str")
return super().insert(loc, item)

@classmethod
def _result_converter(cls, values, na=None):
# Converts the result of a pyarrow compute call (`values`) into the array
# handed back to the caller by delegating to BooleanDtype().__from_arrow__.
# `na` is accepted but not used by this implementation.
return BooleanDtype().__from_arrow__(values)

def _maybe_convert_setitem_value(self, value):
"""Maybe convert value to be pyarrow compatible."""
if is_scalar(value):
Expand Down Expand Up @@ -313,7 +323,7 @@ def _str_contains(
result = pc.match_substring_regex(self._pa_array, pat, ignore_case=not case)
else:
result = pc.match_substring(self._pa_array, pat, ignore_case=not case)
result = BooleanDtype().__from_arrow__(result)
result = self._result_converter(result, na=na)
if not isna(na):
result[isna(result)] = bool(na)
return result
Expand All @@ -322,7 +332,7 @@ def _str_startswith(self, pat: str, na=None):
result = pc.starts_with(self._pa_array, pattern=pat)
if not isna(na):
result = result.fill_null(na)
result = BooleanDtype().__from_arrow__(result)
result = self._result_converter(result)
if not isna(na):
result[isna(result)] = bool(na)
return result
Expand All @@ -331,7 +341,7 @@ def _str_endswith(self, pat: str, na=None):
result = pc.ends_with(self._pa_array, pattern=pat)
if not isna(na):
result = result.fill_null(na)
result = BooleanDtype().__from_arrow__(result)
result = self._result_converter(result)
if not isna(na):
result[isna(result)] = bool(na)
return result
Expand Down Expand Up @@ -369,39 +379,39 @@ def _str_fullmatch(

def _str_isalnum(self):
result = pc.utf8_is_alnum(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_isalpha(self):
result = pc.utf8_is_alpha(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_isdecimal(self):
result = pc.utf8_is_decimal(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_isdigit(self):
result = pc.utf8_is_digit(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_islower(self):
result = pc.utf8_is_lower(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_isnumeric(self):
result = pc.utf8_is_numeric(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_isspace(self):
result = pc.utf8_is_space(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_istitle(self):
result = pc.utf8_is_title(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_isupper(self):
result = pc.utf8_is_upper(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_len(self):
result = pc.utf8_length(self._pa_array)
Expand Down Expand Up @@ -433,3 +443,114 @@ def _str_rstrip(self, to_strip=None):
else:
result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
return type(self)(result)


class ArrowStringArrayNumpySemantics(ArrowStringArray):
    """
    ``ArrowStringArray`` variant used for ``StringDtype("pyarrow_numpy")``.

    The overrides below convert the results of string predicates, counts,
    lengths, searches and comparisons to NumPy arrays via ``to_numpy``
    instead of returning them through ``BooleanDtype().__from_arrow__``.
    """

    _storage = "pyarrow_numpy"

    @classmethod
    def _result_converter(cls, values, na=None):
        """
        Convert a pyarrow compute result to a NumPy array.

        Parameters
        ----------
        values : pyarrow array
            Result of a ``pyarrow.compute`` call.
        na : scalar, optional
            When not missing, nulls in ``values`` are filled with
            ``bool(na)`` before conversion.

        Returns
        -------
        np.ndarray
            Remaining nulls are represented as ``np.nan``.
        """
        if not isna(na):
            values = values.fill_null(bool(na))
        return ArrowExtensionArray(values).to_numpy(na_value=np.nan)

    def __getattribute__(self, item):
        # ArrowStringArray and we both inherit from ArrowExtensionArray, which
        # creates inheritance problems (Diamond inheritance).
        # Anything defined directly on ArrowStringArrayMixin (except
        # ``_pa_array``) is therefore looked up on the mixin and bound to
        # ``self`` by hand.
        if item in ArrowStringArrayMixin.__dict__ and item != "_pa_array":
            return partial(getattr(ArrowStringArrayMixin, item), self)
        return super().__getattribute__(item)

    def _str_map(
        self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
    ):
        """
        Apply ``f`` element-wise, skipping missing values.

        Parameters
        ----------
        f : callable
            Function applied to every non-missing element.
        na_value : scalar, optional
            Value used for missing elements; defaults to ``self.dtype.na_value``.
            For integer and boolean ``dtype`` it is replaced by ``np.nan`` and
            ``False`` respectively.
        dtype : Dtype, optional
            Requested result dtype; defaults to ``self.dtype``.
        convert : bool, default True
            In the integer/boolean fallback path, whether an object result
            is passed through ``lib.maybe_convert_objects``.

        Returns
        -------
        np.ndarray or ArrowStringArrayNumpySemantics
            An array of this type when ``dtype`` is a (non-object) string
            dtype, otherwise the array produced by ``lib.map_infer_mask``.
        """
        if dtype is None:
            dtype = self.dtype
        if na_value is None:
            na_value = self.dtype.na_value

        mask = isna(self)
        arr = np.asarray(self)

        if is_integer_dtype(dtype) or is_bool_dtype(dtype):
            na_value = np.nan if is_integer_dtype(dtype) else False
            try:
                result = lib.map_infer_mask(
                    arr,
                    f,
                    mask.view("uint8"),
                    convert=False,
                    na_value=na_value,
                    dtype=np.dtype(dtype),  # type: ignore[arg-type]
                )
                return result

            except ValueError:
                # The typed attempt raised; redo the mapping without a
                # fixed result dtype and optionally infer one afterwards.
                result = lib.map_infer_mask(
                    arr,
                    f,
                    mask.view("uint8"),
                    convert=False,
                    na_value=na_value,
                )
                if convert and result.dtype == object:
                    result = lib.maybe_convert_objects(result)
                return result

        elif is_string_dtype(dtype) and not is_object_dtype(dtype):
            # i.e. StringDtype
            result = lib.map_infer_mask(
                arr, f, mask.view("uint8"), convert=False, na_value=na_value
            )
            result = pa.array(result, mask=mask, type=pa.string(), from_pandas=True)
            return type(self)(result)
        else:
            # This is when the result type is object. We reach this when
            # -> We know the result type is truly object (e.g. .encode returns bytes
            #    or .findall returns a list).
            # -> We don't know the result type.  E.g. `.get` can return anything.
            return lib.map_infer_mask(arr, f, mask.view("uint8"))

    def _convert_int_dtype(self, result):
        """Upcast an ``int32`` NumPy result to ``int64``; return others as-is."""
        if result.dtype == np.int32:
            result = result.astype(np.int64)
        return result

    def _str_count(self, pat: str, flags: int = 0):
        """Count regex matches of ``pat``; with ``flags`` set, defer to the parent."""
        if flags:
            return super()._str_count(pat, flags)
        result = pc.count_substring_regex(self._pa_array, pat).to_numpy()
        return self._convert_int_dtype(result)

    def _str_len(self):
        """Return the result of ``pc.utf8_length`` as a NumPy array."""
        result = pc.utf8_length(self._pa_array).to_numpy()
        return self._convert_int_dtype(result)

    def _str_find(self, sub: str, start: int = 0, end: int | None = None):
        """
        Return the lowest index of ``sub`` within ``[start:end]``, or -1.

        The pyarrow path is used when searching the whole string, or when
        ``start`` is positive, ``end`` is given and ``sub`` is non-empty.
        Every other combination is delegated to the parent implementation.
        """
        if start == 0 and end is None:
            result = pc.find_substring(self._pa_array, sub)
        elif start > 0 and end is not None and sub:
            slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
            result = pc.find_substring(slices, sub)
            not_found = pc.equal(result, -1)
            # The match position is relative to the slice, which begins at
            # ``start``; shift it back into the coordinates of the full
            # string (``"abcabc".find("c", 1, 6) == 2``).
            offset_result = pc.add(result, start)
            result = pc.if_else(not_found, result, offset_result)
        else:
            # Negative ``start`` (offset depends on each string's length),
            # an empty ``sub`` (``"abc".find("", 5, 10) == -1``), or only one
            # of ``start``/``end`` given: defer to the parent implementation.
            return super()._str_find(sub, start, end)
        return self._convert_int_dtype(result.to_numpy())

    def _cmp_method(self, other, op):
        """
        Compare with ``other`` and return a NumPy boolean array.

        Missing results are filled NaN-style: ``True`` for ``!=`` and
        ``False`` for every other operator.
        """
        import operator

        result = super()._cmp_method(other, op)
        na_result = op == operator.ne
        return result.to_numpy(np.bool_, na_value=na_result)

    def value_counts(self, dropna: bool = True):
        """
        Return value counts as a Series whose values went through ``to_numpy``.

        The index and name of the parent's result are kept unchanged.
        """
        from pandas import Series

        result = super().value_counts(dropna)
        return Series(
            result._values.to_numpy(), index=result.index, name=result.name, copy=False
        )
2 changes: 1 addition & 1 deletion pandas/core/config_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ def use_inf_as_na_cb(key) -> None:
"string_storage",
"python",
string_storage_doc,
validator=is_one_of_factory(["python", "pyarrow"]),
validator=is_one_of_factory(["python", "pyarrow", "pyarrow_numpy"]),
)


Expand Down
4 changes: 3 additions & 1 deletion pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ def _map_and_wrap(name: str | None, docstring: str | None):
@forbid_nonstring_types(["bytes"], name=name)
def wrapper(self):
result = getattr(self._data.array, f"_str_{name}")()
return self._wrap_result(result)
return self._wrap_result(
result, returns_string=name not in ("isnumeric", "isdecimal")
)

wrapper.__doc__ = docstring
return wrapper
Expand Down
Loading

0 comments on commit 2ad36cc

Please sign in to comment.