From a52e7fed90b025e7826bc44d4f22a0262edceed0 Mon Sep 17 00:00:00 2001
From: Kei
Date: Wed, 17 Apr 2024 17:01:13 +0800
Subject: [PATCH] Temporarily change observed=True, for groupby.transform

---
 pandas/core/groupby/generic.py |  3 ++
 pandas/core/groupby/groupby.py | 69 ++++++++++++++++++++++++++--------
 2 files changed, 56 insertions(+), 16 deletions(-)

diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py
index 0a048d11d0b4d..23b785f282ca1 100644
--- a/pandas/core/groupby/generic.py
+++ b/pandas/core/groupby/generic.py
@@ -2044,8 +2044,11 @@ def _gotitem(self, key, ndim: int, subset=None):
         elif ndim == 1:
             if subset is None:
                 subset = self.obj[key]
+
+            orig_obj = self.orig_obj if not self.observed else None
             return SeriesGroupBy(
                 subset,
                 self.keys,
                 level=self.level,
+                orig_obj=orig_obj,
                 grouper=self._grouper,
diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py
index bc37405b25a16..ec6c92792dd2e 100644
--- a/pandas/core/groupby/groupby.py
+++ b/pandas/core/groupby/groupby.py
@@ -1096,6 +1096,7 @@ class GroupBy(BaseGroupBy[NDFrameT]):
     def __init__(
         self,
         obj: NDFrameT,
         keys: _KeysArgType | None = None,
         level: IndexLabel | None = None,
+        orig_obj: NDFrameT | None = None,
         grouper: ops.BaseGrouper | None = None,
@@ -1117,6 +1118,7 @@ def __init__(
         self.sort = sort
         self.group_keys = group_keys
         self.dropna = dropna
+        self.orig_obj = obj if orig_obj is None else orig_obj
 
         if grouper is None:
             grouper, exclusions, obj = get_grouper(
@@ -1879,24 +1881,59 @@ def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
 
         else:
             # i.e. func in base.reduction_kernels
+            if self.observed:
+                return self._reduction_kernel_transform(
+                    func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
+                )
 
-            # GH#30918 Use _transform_fast only when we know func is an aggregation
-            # If func is a reduction, we need to broadcast the
-            # result to the whole group. Compute func result
-            # and deal with possible broadcasting below.
-            with com.temp_setattr(self, "as_index", True):
-                # GH#49834 - result needs groups in the index for
-                # _wrap_transform_fast_result
-                if func in ["idxmin", "idxmax"]:
-                    func = cast(Literal["idxmin", "idxmax"], func)
-                    result = self._idxmax_idxmin(func, True, *args, **kwargs)
-                else:
-                    if engine is not None:
-                        kwargs["engine"] = engine
-                        kwargs["engine_kwargs"] = engine_kwargs
-                    result = getattr(self, func)(*args, **kwargs)
+            # observed=False: regroup the original object with observed=True,
+            # then run the reduction with that grouper (and its exclusions/obj)
+            # swapped in only for the duration of the call.
+            grouper, exclusions, obj = get_grouper(
+                self.orig_obj,
+                self.keys,
+                level=self.level,
+                sort=self.sort,
+                observed=True,
+                dropna=self.dropna,
+            )
+            exclusions = frozenset(exclusions) if exclusions else frozenset()
+            obj_has_not_changed = self.orig_obj.equals(self.obj)
+
+            with (
+                com.temp_setattr(self, "observed", True),
+                com.temp_setattr(self, "_grouper", grouper),
+                com.temp_setattr(self, "exclusions", exclusions),
+                com.temp_setattr(self, "obj", obj, condition=obj_has_not_changed),
+            ):
+                return self._reduction_kernel_transform(
+                    func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
+                )
+
+    @final
+    def _reduction_kernel_transform(
+        self, func, *args, engine=None, engine_kwargs=None, **kwargs
+    ):
+        """
+        Run the reduction kernel ``func`` with ``as_index=True`` and pass the
+        result to ``_wrap_transform_fast_result``.
+        """
+        # GH#30918 Use _transform_fast only when we know func is an aggregation
+        # If func is a reduction, we need to broadcast the
+        # result to the whole group. Compute func result
+        # and deal with possible broadcasting below.
+        with com.temp_setattr(self, "as_index", True):
+            # GH#49834 - result needs groups in the index for
+            # _wrap_transform_fast_result
+            if func in ["idxmin", "idxmax"]:
+                func = cast(Literal["idxmin", "idxmax"], func)
+                result = self._idxmax_idxmin(func, True, *args, **kwargs)
+            else:
+                if engine is not None:
+                    kwargs["engine"] = engine
+                    kwargs["engine_kwargs"] = engine_kwargs
+                result = getattr(self, func)(*args, **kwargs)
 
-            return self._wrap_transform_fast_result(result)
+        return self._wrap_transform_fast_result(result)
 
     @final
     def _wrap_transform_fast_result(self, result: NDFrameT) -> NDFrameT: