From 6b3e66b53d10e19e83b51f14a66a335d5ff3394b Mon Sep 17 00:00:00 2001 From: Doug Davis Date: Mon, 18 Sep 2023 16:37:33 -0500 Subject: [PATCH] ENH: add `ExtensionArray._explode` method; adjust pyarrow extension for use of new interface (#54834) * add ExtensionArray._explode method; adjust pyarrow extension for use * black * add to whatsnew 2.1.0 * pre-commit fix * add _explode to docs * Update pandas/core/arrays/arrow/array.py Co-authored-by: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> * switch whatsnew files * adjust docstring * fix docstring --------- Co-authored-by: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> --- doc/source/reference/extensions.rst | 1 + doc/source/whatsnew/v2.2.0.rst | 1 + pandas/core/arrays/arrow/array.py | 4 +++ pandas/core/arrays/base.py | 36 +++++++++++++++++++++ pandas/core/series.py | 7 ++-- pandas/tests/series/methods/test_explode.py | 10 ++++++ 6 files changed, 54 insertions(+), 5 deletions(-) diff --git a/doc/source/reference/extensions.rst b/doc/source/reference/extensions.rst index e177e2b1d87d5..83f830bb11198 100644 --- a/doc/source/reference/extensions.rst +++ b/doc/source/reference/extensions.rst @@ -34,6 +34,7 @@ objects. api.extensions.ExtensionArray._accumulate api.extensions.ExtensionArray._concat_same_type + api.extensions.ExtensionArray._explode api.extensions.ExtensionArray._formatter api.extensions.ExtensionArray._from_factorized api.extensions.ExtensionArray._from_sequence diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index 55a3419e95703..0fc4afc95a2ce 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -73,6 +73,7 @@ enhancement2 Other enhancements ^^^^^^^^^^^^^^^^^^ +- :meth:`ExtensionArray._explode` interface method added to allow extension type implementations of the ``explode`` method (:issue:`54833`) - DataFrame.apply now allows the usage of numba (via ``engine="numba"``) to JIT compile the passed function, allowing for potential speedups (:issue:`54666`) - diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index a329c37c77449..e67b7035822cc 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1609,6 +1609,10 @@ def _explode(self): """ See Series.explode.__doc__. """ + # child class explode method supports only list types; return + # default implementation for non list types. + if not pa.types.is_list(self.dtype.pyarrow_dtype): + return super()._explode() values = self counts = pa.compute.list_value_length(values._pa_array) counts = counts.fill_null(1).to_numpy() diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index f3bb7323c7d5f..933944dbd4632 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -142,6 +142,7 @@ class ExtensionArray: view _accumulate _concat_same_type + _explode _formatter _from_factorized _from_sequence @@ -1924,6 +1925,41 @@ def _hash_pandas_object( values, encoding=encoding, hash_key=hash_key, categorize=categorize ) + def _explode(self) -> tuple[Self, npt.NDArray[np.uint64]]: + """ + Transform each element of list-like to a row. + + For arrays that do not contain list-like elements the default + implementation of this method just returns a copy and an array + of ones (unchanged index). + + Returns + ------- + ExtensionArray + Array with the exploded values. + np.ndarray[uint64] + The original lengths of each list-like for determining the + resulting index. + + See Also + -------- + Series.explode : The method on the ``Series`` object that this + extension array method is meant to support. + + Examples + -------- + >>> import pyarrow as pa + >>> a = pd.array([[1, 2, 3], [4], [5, 6]], + ... dtype=pd.ArrowDtype(pa.list_(pa.int64()))) + >>> a._explode() + ( + [1, 2, 3, 4, 5, 6] + Length: 6, dtype: int64[pyarrow], array([3, 1, 2], dtype=int32)) + """ + values = self.copy() + counts = np.ones(shape=(len(self),), dtype=np.uint64) + return values, counts + def tolist(self) -> list: """ Return a list of the values. diff --git a/pandas/core/series.py b/pandas/core/series.py index e0e27581ef7e2..78ec1554198df 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -76,10 +76,7 @@ pandas_dtype, validate_all_hashable, ) -from pandas.core.dtypes.dtypes import ( - ArrowDtype, - ExtensionDtype, -) +from pandas.core.dtypes.dtypes import ExtensionDtype from pandas.core.dtypes.generic import ABCDataFrame from pandas.core.dtypes.inference import is_hashable from pandas.core.dtypes.missing import ( @@ -4390,7 +4387,7 @@ def explode(self, ignore_index: bool = False) -> Series: 3 4 dtype: object """ - if isinstance(self.dtype, ArrowDtype) and self.dtype.type == list: + if isinstance(self.dtype, ExtensionDtype): values, counts = self._values._explode() elif len(self) and is_object_dtype(self.dtype): values, counts = reshape.explode(np.asarray(self._values)) diff --git a/pandas/tests/series/methods/test_explode.py b/pandas/tests/series/methods/test_explode.py index c8a9eb6f89fde..5a0188585ef30 100644 --- a/pandas/tests/series/methods/test_explode.py +++ b/pandas/tests/series/methods/test_explode.py @@ -163,3 +163,13 @@ def test_explode_pyarrow_list_type(ignore_index): dtype=pd.ArrowDtype(pa.int64()), ) tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("ignore_index", [True, False]) +def test_explode_pyarrow_non_list_type(ignore_index): + pa = pytest.importorskip("pyarrow") + data = [1, 2, 3] + ser = pd.Series(data, dtype=pd.ArrowDtype(pa.int64())) + result = ser.explode(ignore_index=ignore_index) + expected = pd.Series([1, 2, 3], dtype="int64[pyarrow]", index=[0, 1, 2]) + tm.assert_series_equal(result, expected)