Skip to content

Commit

Permalink
BUG: Fix ListAccessor methods to preserve original name (#60527)
Browse files Browse the repository at this point in the history
* fix: preserve series name in ListAccessor

* formatting

* add whatsnew v3.0.0 entry
  • Loading branch information
FBruzzesi authored Dec 9, 2024
1 parent 59f947f commit 05f7ef9
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 7 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,7 @@ Other
- Bug in :meth:`read_csv` where chained fsspec TAR file and ``compression="infer"`` fails with ``tarfile.ReadError`` (:issue:`60028`)
- Bug in Dataframe Interchange Protocol implementation was returning incorrect results for data buffers' associated dtype, for string and datetime columns (:issue:`54781`)
- Bug in ``Series.list`` methods not preserving the original :class:`Index`. (:issue:`58425`)
- Bug in ``Series.list`` methods not preserving the original name. (:issue:`60522`)
- Bug in printing a :class:`DataFrame` with a :class:`DataFrame` stored in :attr:`DataFrame.attrs` raised a ``ValueError`` (:issue:`60455`)

.. ***DO NOT USE THIS SECTION***
Expand Down
24 changes: 20 additions & 4 deletions pandas/core/arrays/arrow/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,10 @@ def len(self) -> Series:

value_lengths = pc.list_value_length(self._pa_array)
return Series(
value_lengths, dtype=ArrowDtype(value_lengths.type), index=self._data.index
value_lengths,
dtype=ArrowDtype(value_lengths.type),
index=self._data.index,
name=self._data.name,
)

def __getitem__(self, key: int | slice) -> Series:
Expand Down Expand Up @@ -162,7 +165,10 @@ def __getitem__(self, key: int | slice) -> Series:
# key = pc.add(key, pc.list_value_length(self._pa_array))
element = pc.list_element(self._pa_array, key)
return Series(
element, dtype=ArrowDtype(element.type), index=self._data.index
element,
dtype=ArrowDtype(element.type),
index=self._data.index,
name=self._data.name,
)
elif isinstance(key, slice):
if pa_version_under11p0:
Expand All @@ -181,7 +187,12 @@ def __getitem__(self, key: int | slice) -> Series:
if step is None:
step = 1
sliced = pc.list_slice(self._pa_array, start, stop, step)
return Series(sliced, dtype=ArrowDtype(sliced.type), index=self._data.index)
return Series(
sliced,
dtype=ArrowDtype(sliced.type),
index=self._data.index,
name=self._data.name,
)
else:
raise ValueError(f"key must be an int or slice, got {type(key).__name__}")

Expand Down Expand Up @@ -223,7 +234,12 @@ def flatten(self) -> Series:
counts = pa.compute.list_value_length(self._pa_array)
flattened = pa.compute.list_flatten(self._pa_array)
index = self._data.index.repeat(counts.fill_null(pa.scalar(0, counts.type)))
return Series(flattened, dtype=ArrowDtype(flattened.type), index=index)
return Series(
flattened,
dtype=ArrowDtype(flattened.type),
index=index,
name=self._data.name,
)


class StructAccessor(ArrowAccessor):
Expand Down
18 changes: 15 additions & 3 deletions pandas/tests/series/accessors/test_list_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ def test_list_getitem(list_dtype):
ser = Series(
[[1, 2, 3], [4, None, 5], None],
dtype=ArrowDtype(list_dtype),
name="a",
)
actual = ser.list[1]
expected = Series([2, None, None], dtype="int64[pyarrow]")
expected = Series([2, None, None], dtype="int64[pyarrow]", name="a")
tm.assert_series_equal(actual, expected)


Expand All @@ -37,9 +38,15 @@ def test_list_getitem_index():
[[1, 2, 3], [4, None, 5], None],
dtype=ArrowDtype(pa.list_(pa.int64())),
index=[1, 3, 7],
name="a",
)
actual = ser.list[1]
expected = Series([2, None, None], dtype="int64[pyarrow]", index=[1, 3, 7])
expected = Series(
[2, None, None],
dtype="int64[pyarrow]",
index=[1, 3, 7],
name="a",
)
tm.assert_series_equal(actual, expected)


Expand All @@ -48,6 +55,7 @@ def test_list_getitem_slice():
[[1, 2, 3], [4, None, 5], None],
dtype=ArrowDtype(pa.list_(pa.int64())),
index=[1, 3, 7],
name="a",
)
if pa_version_under11p0:
with pytest.raises(
Expand All @@ -60,6 +68,7 @@ def test_list_getitem_slice():
[[2, 3], [None, 5], None],
dtype=ArrowDtype(pa.list_(pa.int64())),
index=[1, 3, 7],
name="a",
)
tm.assert_series_equal(actual, expected)

Expand All @@ -68,22 +77,25 @@ def test_list_len():
ser = Series(
[[1, 2, 3], [4, None], None],
dtype=ArrowDtype(pa.list_(pa.int64())),
name="a",
)
actual = ser.list.len()
expected = Series([3, 2, None], dtype=ArrowDtype(pa.int32()))
expected = Series([3, 2, None], dtype=ArrowDtype(pa.int32()), name="a")
tm.assert_series_equal(actual, expected)


def test_list_flatten():
ser = Series(
[[1, 2, 3], None, [4, None], [], [7, 8]],
dtype=ArrowDtype(pa.list_(pa.int64())),
name="a",
)
actual = ser.list.flatten()
expected = Series(
[1, 2, 3, 4, None, 7, 8],
dtype=ArrowDtype(pa.int64()),
index=[0, 0, 0, 2, 2, 4, 4],
name="a",
)
tm.assert_series_equal(actual, expected)

Expand Down

0 comments on commit 05f7ef9

Please sign in to comment.