diff --git a/doc/source/reference/series.rst b/doc/source/reference/series.rst index 58351bab07b225..9acbab7a42800b 100644 --- a/doc/source/reference/series.rst +++ b/doc/source/reference/series.rst @@ -525,6 +525,29 @@ Sparse-dtype specific methods and attributes are provided under the Series.sparse.from_coo Series.sparse.to_coo + +.. _api.series.struct: + +Struct accessor +~~~~~~~~~~~~~~~ + +Arrow struct-dtype specific methods and attributes are provided under the +``Series.struct`` accessor. + +.. autosummary:: + :toctree: api/ + :template: autosummary/accessor_attribute.rst + + Series.struct.dtypes + +.. autosummary:: + :toctree: api/ + :template: autosummary/accessor_method.rst + + Series.struct.field + Series.struct.explode + + .. _api.series.flags: Flags diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index a795514aa31f8e..098632bdd99778 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -14,10 +14,35 @@ including other versions of pandas. Enhancements ~~~~~~~~~~~~ -.. _whatsnew_220.enhancements.enhancement1: +.. _whatsnew_220.enhancements.struct_accessor: -enhancement1 -^^^^^^^^^^^^ +Series.struct accessor to with PyArrow structured data +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The ``Series.struct`` accessor provides attributes and methods for processing +data with ``struct[pyarrow]`` dtype Series. For example, +:meth:`Series.struct.explode` converts PyArrow structured data to a pandas +DataFrame. (:issue:`54938`) + +.. code-block:: ipython + + In [1]: import pyarrow as pa + ...: struct_type = pa.struct([ + ...: ("int_col", pa.int64()), + ...: ("string_col", pa.string()), + ...: ]) + ...: struct_array = pa.array([ + ...: {"int_col": 1, "string_col": "a"}, + ...: {"int_col": 2, "string_col": "b"}, + ...: {"int_col": 3, "string_col": "c"}, + ...: ], type=struct_type) + ...: series = pd.Series(struct_array, dtype=pd.ArrowDtype(struct_type)) + In [2]: series.struct.explode() + Out[2]: + int_col string_col + 0 1 a + 1 2 b + 2 3 c .. _whatsnew_220.enhancements.enhancement2: diff --git a/pandas/core/arrays/arrow/__init__.py b/pandas/core/arrays/arrow/__init__.py index 58b268cbdd2217..a3d33f91f597d0 100644 --- a/pandas/core/arrays/arrow/__init__.py +++ b/pandas/core/arrays/arrow/__init__.py @@ -1,3 +1,4 @@ +from pandas.core.arrays.arrow.accessors import StructAccessor from pandas.core.arrays.arrow.array import ArrowExtensionArray -__all__ = ["ArrowExtensionArray"] +__all__ = ["ArrowExtensionArray", "StructAccessor"] diff --git a/pandas/core/arrays/arrow/accessors.py b/pandas/core/arrays/arrow/accessors.py new file mode 100644 index 00000000000000..1e3d36c95782c4 --- /dev/null +++ b/pandas/core/arrays/arrow/accessors.py @@ -0,0 +1,195 @@ +"""Accessors for arrow-backed data.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pandas.compat import pa_version_under7p0 + +if not pa_version_under7p0: + import pyarrow as pa + import pyarrow.compute as pc + + from pandas.core.dtypes.dtypes import ArrowDtype + +if TYPE_CHECKING: + from pandas import ( + DataFrame, + Series, + ) + + +class StructAccessor: + """ + Accessor object for structured data properties of the Series values. + + Parameters + ---------- + data : Series + Series containing Arrow struct data. + """ + + _validation_msg = ( + "Can only use the '.struct' accessor with 'struct[pyarrow]' dtype, not {dtype}." + ) + + def __init__(self, data=None) -> None: + self._parent = data + self._validate(data) + + def _validate(self, data): + dtype = data.dtype + if not isinstance(dtype, ArrowDtype): + # Raise AttributeError so that inspect can handle non-struct Series. + raise AttributeError(self._validation_msg.format(dtype=dtype)) + + if not pa.types.is_struct(dtype.pyarrow_dtype): + # Raise AttributeError so that inspect can handle non-struct Series. + raise AttributeError(self._validation_msg.format(dtype=dtype)) + + @property + def dtypes(self) -> Series: + """ + Return the dtype object of each child field of the struct. + + Returns + ------- + pandas.Series + The data type of each child field. + + Examples + -------- + >>> import pyarrow as pa + >>> s = pd.Series( + ... [ + ... {"version": 1, "project": "pandas"}, + ... {"version": 2, "project": "pandas"}, + ... {"version": 1, "project": "numpy"}, + ... ], + ... dtype=pd.ArrowDtype(pa.struct( + ... [("version", pa.int64()), ("project", pa.string())] + ... )) + ... ) + >>> s.struct.dtypes + version int64[pyarrow] + project string[pyarrow] + dtype: object + """ + from pandas import ( + Index, + Series, + ) + + pa_type = self._parent.dtype.pyarrow_dtype + types = [ArrowDtype(struct.type) for struct in pa_type] + names = [struct.name for struct in pa_type] + return Series(types, index=Index(names)) + + def field(self, name_or_index: str | int) -> Series: + """ + Extract a child field of a struct as a Series. + + Parameters + ---------- + name_or_index : str | int + Name or index of the child field to extract. + + Returns + ------- + pandas.Series + The data corresponding to the selected child field. + + See Also + -------- + Series.struct.explode : Return all child fields as a DataFrame. + + Examples + -------- + >>> import pyarrow as pa + >>> s = pd.Series( + ... [ + ... {"version": 1, "project": "pandas"}, + ... {"version": 2, "project": "pandas"}, + ... {"version": 1, "project": "numpy"}, + ... ], + ... dtype=pd.ArrowDtype(pa.struct( + ... [("version", pa.int64()), ("project", pa.string())] + ... )) + ... ) + + Extract by field name. + + >>> s.struct.field("project") + 0 pandas + 1 pandas + 2 numpy + Name: project, dtype: string[pyarrow] + + Extract by field index. + + >>> s.struct.field(0) + 0 1 + 1 2 + 2 1 + Name: version, dtype: int64[pyarrow] + """ + from pandas import Series + + pa_arr = self._parent.array._pa_array + if isinstance(name_or_index, int): + index = name_or_index + elif isinstance(name_or_index, str): + index = pa_arr.type.get_field_index(name_or_index) + else: + raise ValueError( + f"name_or_index must be an int or str, got {type(name_or_index)}" + ) + + pa_field = pa_arr.type[index] + field_arr = pc.struct_field(pa_arr, [index]) + return Series( + field_arr, + dtype=ArrowDtype(field_arr.type), + index=self._parent.index, + name=pa_field.name, + ) + + def explode(self) -> DataFrame: + """ + Extract all child fields of a struct as a DataFrame. + + Returns + ------- + pandas.DataFrame + The data corresponding to all child fields. + + See Also + -------- + Series.struct.field : Return a single child field as a Series. + + Examples + -------- + >>> import pyarrow as pa + >>> s = pd.Series( + ... [ + ... {"version": 1, "project": "pandas"}, + ... {"version": 2, "project": "pandas"}, + ... {"version": 1, "project": "numpy"}, + ... ], + ... dtype=pd.ArrowDtype(pa.struct( + ... [("version", pa.int64()), ("project", pa.string())] + ... )) + ... ) + + >>> s.struct.explode() + version project + 0 1 pandas + 1 2 pandas + 2 1 numpy + """ + from pandas import concat + + pa_type = self._parent.dtype.pyarrow_dtype + return concat( + [self.field(i) for i in range(pa_type.num_fields)], axis="columns" + ) diff --git a/pandas/core/series.py b/pandas/core/series.py index 9b5c8829fd5ff2..e0e27581ef7e26 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -101,6 +101,7 @@ from pandas.core.accessor import CachedAccessor from pandas.core.apply import SeriesApply from pandas.core.arrays import ExtensionArray +from pandas.core.arrays.arrow import StructAccessor from pandas.core.arrays.categorical import CategoricalAccessor from pandas.core.arrays.sparse import SparseAccessor from pandas.core.construction import ( @@ -5787,6 +5788,7 @@ def to_period(self, freq: str | None = None, copy: bool | None = None) -> Series cat = CachedAccessor("cat", CategoricalAccessor) plot = CachedAccessor("plot", pandas.plotting.PlotAccessor) sparse = CachedAccessor("sparse", SparseAccessor) + struct = CachedAccessor("struct", StructAccessor) # ---------------------------------------------------------------------- # Add plotting methods to Series diff --git a/pandas/tests/series/accessors/test_struct_accessor.py b/pandas/tests/series/accessors/test_struct_accessor.py new file mode 100644 index 00000000000000..c645bb68070521 --- /dev/null +++ b/pandas/tests/series/accessors/test_struct_accessor.py @@ -0,0 +1,147 @@ +import re + +import pytest + +from pandas import ( + ArrowDtype, + DataFrame, + Index, + Series, +) +import pandas._testing as tm +from pandas.core.arrays.arrow.accessors import StructAccessor + +pa = pytest.importorskip("pyarrow") + + +def test_struct_accessor_dtypes(): + ser = Series( + [], + dtype=ArrowDtype( + pa.struct( + [ + ("int_col", pa.int64()), + ("string_col", pa.string()), + ( + "struct_col", + pa.struct( + [ + ("int_col", pa.int64()), + ("float_col", pa.float64()), + ] + ), + ), + ] + ) + ), + ) + actual = ser.struct.dtypes + expected = Series( + [ + ArrowDtype(pa.int64()), + ArrowDtype(pa.string()), + ArrowDtype( + pa.struct( + [ + ("int_col", pa.int64()), + ("float_col", pa.float64()), + ] + ) + ), + ], + index=Index(["int_col", "string_col", "struct_col"]), + ) + tm.assert_series_equal(actual, expected) + + +def test_struct_accessor_field(): + index = Index([-100, 42, 123]) + ser = Series( + [ + {"rice": 1.0, "maize": -1, "wheat": "a"}, + {"rice": 2.0, "maize": 0, "wheat": "b"}, + {"rice": 3.0, "maize": 1, "wheat": "c"}, + ], + dtype=ArrowDtype( + pa.struct( + [ + ("rice", pa.float64()), + ("maize", pa.int64()), + ("wheat", pa.string()), + ] + ) + ), + index=index, + ) + by_name = ser.struct.field("maize") + by_name_expected = Series( + [-1, 0, 1], + dtype=ArrowDtype(pa.int64()), + index=index, + name="maize", + ) + tm.assert_series_equal(by_name, by_name_expected) + + by_index = ser.struct.field(2) + by_index_expected = Series( + ["a", "b", "c"], + dtype=ArrowDtype(pa.string()), + index=index, + name="wheat", + ) + tm.assert_series_equal(by_index, by_index_expected) + + +def test_struct_accessor_field_with_invalid_name_or_index(): + ser = Series([], dtype=ArrowDtype(pa.struct([("field", pa.int64())]))) + + with pytest.raises(ValueError, match="name_or_index must be an int or str"): + ser.struct.field(1.1) + + +def test_struct_accessor_explode(): + index = Index([-100, 42, 123]) + ser = Series( + [ + {"painted": 1, "snapping": {"sea": "green"}}, + {"painted": 2, "snapping": {"sea": "leatherback"}}, + {"painted": 3, "snapping": {"sea": "hawksbill"}}, + ], + dtype=ArrowDtype( + pa.struct( + [ + ("painted", pa.int64()), + ("snapping", pa.struct([("sea", pa.string())])), + ] + ) + ), + index=index, + ) + actual = ser.struct.explode() + expected = DataFrame( + { + "painted": Series([1, 2, 3], index=index, dtype=ArrowDtype(pa.int64())), + "snapping": Series( + [{"sea": "green"}, {"sea": "leatherback"}, {"sea": "hawksbill"}], + index=index, + dtype=ArrowDtype(pa.struct([("sea", pa.string())])), + ), + }, + ) + tm.assert_frame_equal(actual, expected) + + +@pytest.mark.parametrize( + "invalid", + [ + pytest.param(Series([1, 2, 3], dtype="int64"), id="int64"), + pytest.param( + Series(["a", "b", "c"], dtype="string[pyarrow]"), id="string-pyarrow" + ), + ], +) +def test_struct_accessor_api_for_invalid(invalid): + msg = re.escape(StructAccessor._validation_msg.format(dtype=invalid.dtype)) + + with pytest.raises(AttributeError, match=msg): + invalid.struct