Skip to content

Commit

Permalink
Backport PR #55619 on branch 2.1.x (BUG: Groupby not keeping string d…
Browse files Browse the repository at this point in the history
…type for empty objects) (#55705)

BUG: Groupby not keeping string dtype for empty objects (#55619)

* BUG: Groupby not keeping string dtype for empty objects

* Fix

---------

Co-authored-by: Thomas Li <[email protected]>
(cherry picked from commit 8afd868)

Co-authored-by: Patrick Hoefler <[email protected]>
  • Loading branch information
lithomas1 and phofl authored Dec 4, 2023
1 parent 15376f8 commit ea39f7b
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 7 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.1.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Fixed regressions
Bug fixes
~~~~~~~~~
- Fixed bug in :class:`.DataFrameGroupBy` reductions not preserving object dtype when ``infer_string`` is set (:issue:`55620`)
- Fixed bug in :meth:`.DataFrameGroupBy.min()` and :meth:`.DataFrameGroupBy.max()` not preserving extension dtype for empty objects (:issue:`55619`)
- Fixed bug in :meth:`.SeriesGroupBy.value_counts` returning incorrect dtype for string columns (:issue:`55627`)
- Fixed bug in :meth:`Categorical.equals` if other has arrow backed string dtype (:issue:`55364`)
- Fixed bug in :meth:`DataFrame.__setitem__` not inferring string dtype for zero-dimensional array with ``infer_string=True`` (:issue:`55366`)
Expand Down
3 changes: 3 additions & 0 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2235,6 +2235,9 @@ def _groupby_op(
# GH#43682
if isinstance(self.dtype, StringDtype):
# StringArray
if op.how not in ["any", "all"]:
# Fail early to avoid conversion to object
op._get_cython_function(op.kind, op.how, np.dtype(object), False)
npvalues = self.to_numpy(object, na_value=np.nan)
else:
raise NotImplementedError(
Expand Down
20 changes: 13 additions & 7 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from pandas.errors import AbstractMethodError
from pandas.util._decorators import cache_readonly

from pandas.core.dtypes.base import ExtensionDtype
from pandas.core.dtypes.cast import (
maybe_cast_pointwise_result,
maybe_downcast_to_dtype,
Expand Down Expand Up @@ -837,10 +838,8 @@ def agg_series(
-------
np.ndarray or ExtensionArray
"""
# test_groupby_empty_with_category gets here with self.ngroups == 0
# and len(obj) > 0

if len(obj) > 0 and not isinstance(obj._values, np.ndarray):
if not isinstance(obj._values, np.ndarray):
# we can preserve a little bit more aggressively with EA dtype
# because maybe_cast_pointwise_result will do a try/except
# with _from_sequence. NB we are assuming here that _from_sequence
Expand All @@ -849,11 +848,18 @@ def agg_series(

result = self._aggregate_series_pure_python(obj, func)

npvalues = lib.maybe_convert_objects(result, try_float=False)
if preserve_dtype:
out = maybe_cast_pointwise_result(npvalues, obj.dtype, numeric_only=True)
if len(obj) == 0 and len(result) == 0 and isinstance(obj.dtype, ExtensionDtype):
cls = obj.dtype.construct_array_type()
out = cls._from_sequence(result)

else:
out = npvalues
npvalues = lib.maybe_convert_objects(result, try_float=False)
if preserve_dtype:
out = maybe_cast_pointwise_result(
npvalues, obj.dtype, numeric_only=True
)
else:
out = npvalues
return out

@final
Expand Down
13 changes: 13 additions & 0 deletions pandas/tests/groupby/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1670,6 +1670,19 @@ def test_groupby_empty_dataset(dtype, kwargs):
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize("func", ["min", "max"])
def test_min_empty_string_dtype(func):
    """Check that groupby ``min``/``max`` on an empty frame keeps its dtype.

    Regression test for GH#55619: the frame, its grouping key and the
    expected result all use the ``string[pyarrow_numpy]`` dtype.
    """
    pytest.importorskip("pyarrow")
    str_dtype = "string[pyarrow_numpy]"

    # Start from a single row so every column is created with the string
    # dtype, then slice the row away to get an empty frame.
    seeded = DataFrame({"a": ["a"], "b": "a", "c": "a"}, dtype=str_dtype)
    df = seeded.iloc[:0]

    reduction = getattr(df.groupby("a"), func)
    result = reduction()

    empty_key_index = Index([], dtype=str_dtype, name="a")
    expected = DataFrame(
        columns=["b", "c"],
        dtype=str_dtype,
        index=empty_key_index,
    )
    tm.assert_frame_equal(result, expected)


def test_corrwith_with_1_axis():
# GH 47723
df = DataFrame({"a": [1, 1, 2], "b": [3, 7, 4]})
Expand Down

0 comments on commit ea39f7b

Please sign in to comment.