Skip to content

Commit

Permalink
ENH: add Series.struct accessor for ArrowDtype[struct]
Browse files Browse the repository at this point in the history
Features:

* Series.struct.dtypes -- see dtypes and field names
* Series.struct.field(name_or_index) -- extract a field as a Series
* Series.struct.to_frame() -- convert all fields into a DataFrame
  • Loading branch information
tswast committed Sep 3, 2023
1 parent 53243e8 commit 66ff669
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pandas/core/arrays/arrow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pandas.core.arrays.arrow.accessors import StructAccessor
from pandas.core.arrays.arrow.array import ArrowExtensionArray

__all__ = ["ArrowExtensionArray"]
__all__ = ["ArrowExtensionArray", "StructAccessor"]
66 changes: 66 additions & 0 deletions pandas/core/arrays/arrow/accessors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from pandas.compat import pa_version_under7p0

if not pa_version_under7p0:
import pyarrow as pa
import pyarrow.compute as pc

from pandas.core.dtypes.dtypes import ArrowDtype

if TYPE_CHECKING:
from pandas import (
DataFrame,
Series,
)


class StructAccessor:
_validation_msg = "Can only use the '.struct' accessor with 'struct[pyarrow]' data."

def __init__(self, data=None) -> None:
self._parent = data
self._validate(data)

def _validate(self, data):
dtype = data.dtype
if not isinstance(dtype, ArrowDtype):
raise AttributeError(self._validation_message)

if not pa.types.is_struct(dtype.pyarrow_dtype):
raise AttributeError(self._validation_message)

@property
def dtypes(self) -> Series:
from pandas import (
Index,
Series,
)

pa_type = self._parent.dtype.pyarrow_dtype
types = [ArrowDtype(pa_type[i].type) for i in range(pa_type.num_fields)]
names = [pa_type[i].name for i in range(pa_type.num_fields)]
return Series(types, index=Index(names))

def field(self, name_or_index: str | int) -> Series:
from pandas import Series

pa_arr = self._parent.array._pa_array
if isinstance(name_or_index, int):
index = name_or_index
else:
index = pa_arr.type.get_field_index(name_or_index)

pa_field = pa_arr.type[index]
field_arr = pc.struct_field(pa_arr, [index])
return Series(field_arr, dtype=ArrowDtype(field_arr.type), name=pa_field.name)

def to_frame(self) -> DataFrame:
from pandas import concat

pa_type = self._parent.dtype.pyarrow_dtype
return concat(
[self.field(i) for i in range(pa_type.num_fields)], axis="columns"
)
2 changes: 2 additions & 0 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
from pandas.core.accessor import CachedAccessor
from pandas.core.apply import SeriesApply
from pandas.core.arrays import ExtensionArray
from pandas.core.arrays.arrow import StructAccessor
from pandas.core.arrays.categorical import CategoricalAccessor
from pandas.core.arrays.sparse import SparseAccessor
from pandas.core.construction import (
Expand Down Expand Up @@ -5787,6 +5788,7 @@ def to_period(self, freq: str | None = None, copy: bool | None = None) -> Series
cat = CachedAccessor("cat", CategoricalAccessor)
plot = CachedAccessor("plot", pandas.plotting.PlotAccessor)
sparse = CachedAccessor("sparse", SparseAccessor)
struct = CachedAccessor("struct", StructAccessor)

# ----------------------------------------------------------------------
# Add plotting methods to Series
Expand Down

0 comments on commit 66ff669

Please sign in to comment.