Skip to content

Commit

Permalink
Backport PR pandas-dev#56167: [ENH]: Expand types allowed in Series.s…
Browse files Browse the repository at this point in the history
…truct.field
  • Loading branch information
TomAugspurger authored and meeseeksmachine committed Jan 2, 2024
1 parent 0d0c792 commit 4fbd41c
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 16 deletions.
8 changes: 8 additions & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,14 @@ DataFrame. (:issue:`54938`)
)
series.struct.explode()
Use :meth:`Series.struct.field` to index into a (possible nested)
struct field.


.. ipython:: python
series.struct.field("project")
.. _whatsnew_220.enhancements.list_accessor:

Series.list accessor for PyArrow list data
Expand Down
125 changes: 110 additions & 15 deletions pandas/core/arrays/arrow/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@
ABCMeta,
abstractmethod,
)
from typing import TYPE_CHECKING
from typing import (
TYPE_CHECKING,
cast,
)

from pandas.compat import (
pa_version_under10p1,
pa_version_under11p0,
)

from pandas.core.dtypes.common import is_list_like

if not pa_version_under10p1:
import pyarrow as pa
import pyarrow.compute as pc
Expand Down Expand Up @@ -267,15 +272,27 @@ def dtypes(self) -> Series:
names = [struct.name for struct in pa_type]
return Series(types, index=Index(names))

def field(self, name_or_index: str | int) -> Series:
def field(
self,
name_or_index: list[str]
| list[bytes]
| list[int]
| pc.Expression
| bytes
| str
| int,
) -> Series:
"""
Extract a child field of a struct as a Series.
Parameters
----------
name_or_index : str | int
name_or_index : str | bytes | int | expression | list
Name or index of the child field to extract.
For list-like inputs, this will index into a nested
struct.
Returns
-------
pandas.Series
Expand All @@ -285,6 +302,19 @@ def field(self, name_or_index: str | int) -> Series:
--------
Series.struct.explode : Return all child fields as a DataFrame.
Notes
-----
The name of the resulting Series will be set using the following
rules:
- For string, bytes, or integer `name_or_index` (or a list of these, for
a nested selection), the Series name is set to the selected
field's name.
- For a :class:`pyarrow.compute.Expression`, this is set to
the string form of the expression.
- For list-like `name_or_index`, the name will be set to the
name of the final field selected.
Examples
--------
>>> import pyarrow as pa
Expand Down Expand Up @@ -314,27 +344,92 @@ def field(self, name_or_index: str | int) -> Series:
1 2
2 1
Name: version, dtype: int64[pyarrow]
Or an expression
>>> import pyarrow.compute as pc
>>> s.struct.field(pc.field("project"))
0 pandas
1 pandas
2 numpy
Name: project, dtype: string[pyarrow]
For nested struct types, you can pass a list of values to index
multiple levels:
>>> version_type = pa.struct([
... ("major", pa.int64()),
... ("minor", pa.int64()),
... ])
>>> s = pd.Series(
... [
... {"version": {"major": 1, "minor": 5}, "project": "pandas"},
... {"version": {"major": 2, "minor": 1}, "project": "pandas"},
... {"version": {"major": 1, "minor": 26}, "project": "numpy"},
... ],
... dtype=pd.ArrowDtype(pa.struct(
... [("version", version_type), ("project", pa.string())]
... ))
... )
>>> s.struct.field(["version", "minor"])
0 5
1 1
2 26
Name: minor, dtype: int64[pyarrow]
>>> s.struct.field([0, 0])
0 1
1 2
2 1
Name: major, dtype: int64[pyarrow]
"""
from pandas import Series

def get_name(
level_name_or_index: list[str]
| list[bytes]
| list[int]
| pc.Expression
| bytes
| str
| int,
data: pa.ChunkedArray,
):
if isinstance(level_name_or_index, int):
name = data.type.field(level_name_or_index).name
elif isinstance(level_name_or_index, (str, bytes)):
name = level_name_or_index
elif isinstance(level_name_or_index, pc.Expression):
name = str(level_name_or_index)
elif is_list_like(level_name_or_index):
# For nested input like [2, 1, 2]
# iteratively get the struct and field name. The last
# one is used for the name of the index.
level_name_or_index = list(reversed(level_name_or_index))
selected = data
while level_name_or_index:
# we need the cast, otherwise mypy complains about
# getting ints, bytes, or str here, which isn't possible.
level_name_or_index = cast(list, level_name_or_index)
name_or_index = level_name_or_index.pop()
name = get_name(name_or_index, selected)
selected = selected.type.field(selected.type.get_field_index(name))
name = selected.name
else:
raise ValueError(
"name_or_index must be an int, str, bytes, "
"pyarrow.compute.Expression, or list of those"
)
return name

pa_arr = self._data.array._pa_array
if isinstance(name_or_index, int):
index = name_or_index
elif isinstance(name_or_index, str):
index = pa_arr.type.get_field_index(name_or_index)
else:
raise ValueError(
"name_or_index must be an int or str, "
f"got {type(name_or_index).__name__}"
)
name = get_name(name_or_index, pa_arr)
field_arr = pc.struct_field(pa_arr, name_or_index)

pa_field = pa_arr.type[index]
field_arr = pc.struct_field(pa_arr, [index])
return Series(
field_arr,
dtype=ArrowDtype(field_arr.type),
index=self._data.index,
name=pa_field.name,
name=name,
)

def explode(self) -> DataFrame:
Expand Down
48 changes: 47 additions & 1 deletion pandas/tests/series/accessors/test_struct_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

import pytest

from pandas.compat.pyarrow import (
pa_version_under11p0,
pa_version_under13p0,
)

from pandas import (
ArrowDtype,
DataFrame,
Expand All @@ -11,6 +16,7 @@
import pandas._testing as tm

pa = pytest.importorskip("pyarrow")
pc = pytest.importorskip("pyarrow.compute")


def test_struct_accessor_dtypes():
Expand Down Expand Up @@ -53,6 +59,7 @@ def test_struct_accessor_dtypes():
tm.assert_series_equal(actual, expected)


@pytest.mark.skipif(pa_version_under13p0, reason="pyarrow>=13.0.0 required")
def test_struct_accessor_field():
index = Index([-100, 42, 123])
ser = Series(
Expand Down Expand Up @@ -94,10 +101,11 @@ def test_struct_accessor_field():
def test_struct_accessor_field_with_invalid_name_or_index():
ser = Series([], dtype=ArrowDtype(pa.struct([("field", pa.int64())])))

with pytest.raises(ValueError, match="name_or_index must be an int or str"):
with pytest.raises(ValueError, match="name_or_index must be an int, str,"):
ser.struct.field(1.1)


@pytest.mark.skipif(pa_version_under11p0, reason="pyarrow>=11.0.0 required")
def test_struct_accessor_explode():
index = Index([-100, 42, 123])
ser = Series(
Expand Down Expand Up @@ -148,3 +156,41 @@ def test_struct_accessor_api_for_invalid(invalid):
),
):
invalid.struct


@pytest.mark.parametrize(
["indices", "name"],
[
(0, "int_col"),
([1, 2], "str_col"),
(pc.field("int_col"), "int_col"),
("int_col", "int_col"),
(b"string_col", b"string_col"),
([b"string_col"], "string_col"),
],
)
@pytest.mark.skipif(pa_version_under13p0, reason="pyarrow>=13.0.0 required")
def test_struct_accessor_field_expanded(indices, name):
arrow_type = pa.struct(
[
("int_col", pa.int64()),
(
"struct_col",
pa.struct(
[
("int_col", pa.int64()),
("float_col", pa.float64()),
("str_col", pa.string()),
]
),
),
(b"string_col", pa.string()),
]
)

data = pa.array([], type=arrow_type)
ser = Series(data, dtype=ArrowDtype(arrow_type))
expected = pc.struct_field(data, indices)
result = ser.struct.field(indices)
tm.assert_equal(result.array._pa_array.combine_chunks(), expected)
assert result.name == name

0 comments on commit 4fbd41c

Please sign in to comment.