-
-
Notifications
You must be signed in to change notification settings - Fork 18.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ENH: add Series.struct accessor for ArrowDtype[struct]
Features: * Series.struct.dtypes -- see dtypes and field names * Series.struct.field(name_or_index) -- extract a field as a Series * Series.struct.to_frame() -- convert all fields into a DataFrame
- Loading branch information
Showing
6 changed files
with
356 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from pandas.core.arrays.arrow.accessors import StructAccessor | ||
from pandas.core.arrays.arrow.array import ArrowExtensionArray | ||
|
||
__all__ = ["ArrowExtensionArray"] | ||
__all__ = ["ArrowExtensionArray", "StructAccessor"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
"""Accessors for arrow-backed data.""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
from pandas.compat import pa_version_under7p0 | ||
|
||
if not pa_version_under7p0: | ||
import pyarrow as pa | ||
import pyarrow.compute as pc | ||
|
||
from pandas.core.dtypes.dtypes import ArrowDtype | ||
|
||
if TYPE_CHECKING: | ||
from pandas import ( | ||
DataFrame, | ||
Series, | ||
) | ||
|
||
|
||
class StructAccessor: | ||
""" | ||
Accessor object for structured data properties of the Series values. | ||
Parameters | ||
---------- | ||
data : Series | ||
Series containing Arrow struct data. | ||
""" | ||
|
||
_validation_msg = "Can only use the '.struct' accessor with 'struct[pyarrow]' data." | ||
|
||
def __init__(self, data=None) -> None: | ||
self._parent = data | ||
self._validate(data) | ||
|
||
def _validate(self, data): | ||
dtype = data.dtype | ||
if not isinstance(dtype, ArrowDtype): | ||
raise AttributeError(self._validation_msg) | ||
|
||
if not pa.types.is_struct(dtype.pyarrow_dtype): | ||
raise AttributeError(self._validation_msg) | ||
|
||
@property | ||
def dtypes(self) -> Series: | ||
""" | ||
Return the dtype object of each child field of the struct. | ||
Returns | ||
------- | ||
pandas.Series | ||
The data type of each child field. | ||
Examples | ||
-------- | ||
>>> import pyarrow as pa | ||
>>> s = pd.Series( | ||
... [ | ||
... {"version": 1, "project": "pandas"}, | ||
... {"version": 2, "project": "pandas"}, | ||
... {"version": 1, "project": "numpy"}, | ||
... ], | ||
... dtype=pd.ArrowDtype(pa.struct( | ||
... [("version", pa.int64()), ("project", pa.string())] | ||
... )) | ||
... ) | ||
>>> s.struct.dtypes | ||
version int64[pyarrow] | ||
project string[pyarrow] | ||
dtype: object | ||
""" | ||
from pandas import ( | ||
Index, | ||
Series, | ||
) | ||
|
||
pa_type = self._parent.dtype.pyarrow_dtype | ||
types = [ArrowDtype(pa_type[i].type) for i in range(pa_type.num_fields)] | ||
names = [pa_type[i].name for i in range(pa_type.num_fields)] | ||
return Series(types, index=Index(names)) | ||
|
||
def field(self, name_or_index: str | int) -> Series: | ||
""" | ||
Extract a child field of a struct as a Series. | ||
Parameters | ||
---------- | ||
name_or_index : str | int | ||
Name or index of the child field to extract. | ||
Returns | ||
------- | ||
pandas.Series | ||
The data corresponding to the selected child field. | ||
See Also | ||
-------- | ||
Series.struct.to_frame : Return all child fields as a DataFrame. | ||
Examples | ||
-------- | ||
>>> import pyarrow as pa | ||
>>> s = pd.Series( | ||
... [ | ||
... {"version": 1, "project": "pandas"}, | ||
... {"version": 2, "project": "pandas"}, | ||
... {"version": 1, "project": "numpy"}, | ||
... ], | ||
... dtype=pd.ArrowDtype(pa.struct( | ||
... [("version", pa.int64()), ("project", pa.string())] | ||
... )) | ||
... ) | ||
Extract by field name. | ||
>>> s.struct.field("project") | ||
0 pandas | ||
1 pandas | ||
2 numpy | ||
Name: project, dtype: string[pyarrow] | ||
Extract by field index. | ||
>>> s.struct.field(0) | ||
0 1 | ||
1 2 | ||
2 1 | ||
Name: version, dtype: int64[pyarrow] | ||
""" | ||
from pandas import Series | ||
|
||
pa_arr = self._parent.array._pa_array | ||
if isinstance(name_or_index, int): | ||
index = name_or_index | ||
else: | ||
index = pa_arr.type.get_field_index(name_or_index) | ||
|
||
pa_field = pa_arr.type[index] | ||
field_arr = pc.struct_field(pa_arr, [index]) | ||
return Series( | ||
field_arr, | ||
dtype=ArrowDtype(field_arr.type), | ||
index=self._parent.index, | ||
name=pa_field.name, | ||
) | ||
|
||
def to_frame(self) -> DataFrame: | ||
""" | ||
Extract all child fields of a struct as a DataFrame. | ||
Returns | ||
------- | ||
pandas.DataFrame | ||
The data corresponding to all child fields. | ||
See Also | ||
-------- | ||
Series.struct.field : Return a single child field as a Series. | ||
Examples | ||
-------- | ||
>>> import pyarrow as pa | ||
>>> s = pd.Series( | ||
... [ | ||
... {"version": 1, "project": "pandas"}, | ||
... {"version": 2, "project": "pandas"}, | ||
... {"version": 1, "project": "numpy"}, | ||
... ], | ||
... dtype=pd.ArrowDtype(pa.struct( | ||
... [("version", pa.int64()), ("project", pa.string())] | ||
... )) | ||
... ) | ||
>>> s.struct.to_frame() | ||
version project | ||
0 1 pandas | ||
1 2 pandas | ||
2 1 numpy | ||
""" | ||
from pandas import concat | ||
|
||
pa_type = self._parent.dtype.pyarrow_dtype | ||
return concat( | ||
[self.field(i) for i in range(pa_type.num_fields)], axis="columns" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import re | ||
|
||
import pytest | ||
|
||
from pandas import ( | ||
ArrowDtype, | ||
DataFrame, | ||
Index, | ||
Series, | ||
) | ||
import pandas._testing as tm | ||
|
||
pa = pytest.importorskip("pyarrow") | ||
|
||
|
||
class TestStructAccessor: | ||
def test_struct_accessor_dtypes(self): | ||
ser = Series( | ||
[], | ||
dtype=ArrowDtype( | ||
pa.struct([("int_col", pa.int64()), ("string_col", pa.string())]) | ||
), | ||
) | ||
actual = ser.struct.dtypes | ||
expected = Series( | ||
[ArrowDtype(pa.int64()), ArrowDtype(pa.string())], | ||
index=Index(["int_col", "string_col"]), | ||
) | ||
tm.assert_series_equal(actual, expected) | ||
|
||
def test_struct_accessor_field(self): | ||
index = Index([-100, 42, 123]) | ||
ser = Series( | ||
[ | ||
{"rice": 1.0, "maize": -1, "wheat": "a"}, | ||
{"rice": 2.0, "maize": 0, "wheat": "b"}, | ||
{"rice": 3.0, "maize": 1, "wheat": "c"}, | ||
], | ||
dtype=ArrowDtype( | ||
pa.struct( | ||
[ | ||
("rice", pa.float64()), | ||
("maize", pa.int64()), | ||
("wheat", pa.string()), | ||
] | ||
) | ||
), | ||
index=index, | ||
) | ||
by_name = ser.struct.field("maize") | ||
by_name_expected = Series( | ||
[-1, 0, 1], | ||
dtype=ArrowDtype(pa.int64()), | ||
index=index, | ||
name="maize", | ||
) | ||
tm.assert_series_equal(by_name, by_name_expected) | ||
|
||
by_index = ser.struct.field(2) | ||
by_index_expected = Series( | ||
["a", "b", "c"], | ||
dtype=ArrowDtype(pa.string()), | ||
index=index, | ||
name="wheat", | ||
) | ||
tm.assert_series_equal(by_index, by_index_expected) | ||
|
||
def test_struct_accessor_to_frame(self): | ||
index = Index([-100, 42, 123]) | ||
ser = Series( | ||
[ | ||
{"painted": 1, "snapping": {"sea": "green"}}, | ||
{"painted": 2, "snapping": {"sea": "leatherback"}}, | ||
{"painted": 3, "snapping": {"sea": "hawksbill"}}, | ||
], | ||
dtype=ArrowDtype( | ||
pa.struct( | ||
[ | ||
("painted", pa.int64()), | ||
("snapping", pa.struct([("sea", pa.string())])), | ||
] | ||
) | ||
), | ||
index=index, | ||
) | ||
actual = ser.struct.to_frame() | ||
expected = DataFrame( | ||
{ | ||
"painted": Series([1, 2, 3], index=index, dtype=ArrowDtype(pa.int64())), | ||
"snapping": Series( | ||
[{"sea": "green"}, {"sea": "leatherback"}, {"sea": "hawksbill"}], | ||
index=index, | ||
dtype=ArrowDtype(pa.struct([("sea", pa.string())])), | ||
), | ||
}, | ||
) | ||
tm.assert_frame_equal(actual, expected) | ||
|
||
@pytest.mark.parametrize( | ||
"invalid", | ||
[ | ||
pytest.param(Series([1, 2, 3], dtype="int64"), id="int64"), | ||
pytest.param( | ||
Series(["a", "b", "c"], dtype="string[pyarrow]"), id="string-pyarrow" | ||
), | ||
], | ||
) | ||
def test_struct_accessor_api_for_invalid(self, invalid): | ||
msg = re.escape( | ||
"Can only use the '.struct' accessor with 'struct[pyarrow]' data." | ||
) | ||
|
||
with pytest.raises(AttributeError, match=msg): | ||
invalid.struct |