Skip to content

Commit

Permalink
ENH: add ExtensionArray._explode method; adjust pyarrow extension for use of new interface (#54834)
Browse files Browse the repository at this point in the history

* add ExtensionArray._explode method; adjust pyarrow extension for use

* black

* add to whatsnew 2.1.0

* pre-commit fix

* add _explode to docs

* Update pandas/core/arrays/arrow/array.py

Co-authored-by: Matthew Roeschke <[email protected]>

* switch whatsnew files

* adjust docstring

* fix docstring

---------

Co-authored-by: Matthew Roeschke <[email protected]>
  • Loading branch information
douglasdavis and mroeschke authored Sep 18, 2023
1 parent 6826845 commit 6b3e66b
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 5 deletions.
1 change: 1 addition & 0 deletions doc/source/reference/extensions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ objects.
api.extensions.ExtensionArray._accumulate
api.extensions.ExtensionArray._concat_same_type
api.extensions.ExtensionArray._explode
api.extensions.ExtensionArray._formatter
api.extensions.ExtensionArray._from_factorized
api.extensions.ExtensionArray._from_sequence
Expand Down
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ enhancement2

Other enhancements
^^^^^^^^^^^^^^^^^^
- :meth:`ExtensionArray._explode` interface method added to allow extension type implementations of the ``explode`` method (:issue:`54833`)
- DataFrame.apply now allows the usage of numba (via ``engine="numba"``) to JIT compile the passed function, allowing for potential speedups (:issue:`54666`)
-

Expand Down
4 changes: 4 additions & 0 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1609,6 +1609,10 @@ def _explode(self):
"""
See Series.explode.__doc__.
"""
# child class explode method supports only list types; return
# default implementation for non list types.
if not pa.types.is_list(self.dtype.pyarrow_dtype):
return super()._explode()
values = self
counts = pa.compute.list_value_length(values._pa_array)
counts = counts.fill_null(1).to_numpy()
Expand Down
36 changes: 36 additions & 0 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class ExtensionArray:
view
_accumulate
_concat_same_type
_explode
_formatter
_from_factorized
_from_sequence
Expand Down Expand Up @@ -1924,6 +1925,41 @@ def _hash_pandas_object(
values, encoding=encoding, hash_key=hash_key, categorize=categorize
)

def _explode(self) -> tuple[Self, npt.NDArray[np.uint64]]:
    """
    Transform each element of list-like to a row.

    This base implementation targets arrays whose elements are not
    list-like: nothing is expanded, so it returns a copy of the array
    paired with an array of ones (which leaves the index unchanged).

    Returns
    -------
    ExtensionArray
        Array with the exploded values.
    np.ndarray[uint64]
        The original lengths of each list-like for determining the
        resulting index.

    See Also
    --------
    Series.explode : The method on the ``Series`` object that this
        extension array method is meant to support.

    Examples
    --------
    >>> import pyarrow as pa
    >>> a = pd.array([[1, 2, 3], [4], [5, 6]],
    ...              dtype=pd.ArrowDtype(pa.list_(pa.int64())))
    >>> a._explode()
    (<ArrowExtensionArray>
    [1, 2, 3, 4, 5, 6]
    Length: 6, dtype: int64[pyarrow], array([3, 1, 2], dtype=int32))
    """
    # Copy first so the caller never receives an alias of ``self``.
    unchanged = self.copy()
    # Every element maps to exactly one output row, hence a count of 1 each.
    return unchanged, np.ones(len(self), dtype=np.uint64)

def tolist(self) -> list:
"""
Return a list of the values.
Expand Down
7 changes: 2 additions & 5 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,7 @@
pandas_dtype,
validate_all_hashable,
)
from pandas.core.dtypes.dtypes import (
ArrowDtype,
ExtensionDtype,
)
from pandas.core.dtypes.dtypes import ExtensionDtype
from pandas.core.dtypes.generic import ABCDataFrame
from pandas.core.dtypes.inference import is_hashable
from pandas.core.dtypes.missing import (
Expand Down Expand Up @@ -4390,7 +4387,7 @@ def explode(self, ignore_index: bool = False) -> Series:
3 4
dtype: object
"""
if isinstance(self.dtype, ArrowDtype) and self.dtype.type == list:
if isinstance(self.dtype, ExtensionDtype):
values, counts = self._values._explode()
elif len(self) and is_object_dtype(self.dtype):
values, counts = reshape.explode(np.asarray(self._values))
Expand Down
10 changes: 10 additions & 0 deletions pandas/tests/series/methods/test_explode.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,13 @@ def test_explode_pyarrow_list_type(ignore_index):
dtype=pd.ArrowDtype(pa.int64()),
)
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("ignore_index", [True, False])
def test_explode_pyarrow_non_list_type(ignore_index):
    # A non-list Arrow dtype has nothing to expand, so the expected result
    # keeps the same values and the 0..2 index for either ``ignore_index``.
    pa = pytest.importorskip("pyarrow")
    int_dtype = pd.ArrowDtype(pa.int64())
    exploded = pd.Series([1, 2, 3], dtype=int_dtype).explode(
        ignore_index=ignore_index
    )
    expected = pd.Series([1, 2, 3], dtype="int64[pyarrow]", index=[0, 1, 2])
    tm.assert_series_equal(exploded, expected)

0 comments on commit 6b3e66b

Please sign in to comment.