Skip to content

Commit

Permalink
Temporarily change observed=True, for groupby.transform
Browse files Browse the repository at this point in the history
  • Loading branch information
Kei committed Apr 17, 2024
1 parent 888b6bc commit a52e7fe
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 16 deletions.
3 changes: 3 additions & 0 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2044,8 +2044,11 @@ def _gotitem(self, key, ndim: int, subset=None):
elif ndim == 1:
if subset is None:
subset = self.obj[key]

orig_obj = self.orig_obj if not self.observed else None
return SeriesGroupBy(
subset,
orig_obj,
self.keys,
level=self.level,
grouper=self._grouper,
Expand Down
80 changes: 64 additions & 16 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,7 @@ class GroupBy(BaseGroupBy[NDFrameT]):
def __init__(
self,
obj: NDFrameT,
orig_obj: NDFrameT | None = None,
keys: _KeysArgType | None = None,
level: IndexLabel | None = None,
grouper: ops.BaseGrouper | None = None,
Expand All @@ -1117,6 +1118,7 @@ def __init__(
self.sort = sort
self.group_keys = group_keys
self.dropna = dropna
self.orig_obj = obj if orig_obj is None else orig_obj

if grouper is None:
grouper, exclusions, obj = get_grouper(
Expand Down Expand Up @@ -1879,24 +1881,70 @@ def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):

else:
# i.e. func in base.reduction_kernels
if self.observed:
return self._reduction_kernel_transform(
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
)

# GH#30918 Use _transform_fast only when we know func is an aggregation
# If func is a reduction, we need to broadcast the
# result to the whole group. Compute func result
# and deal with possible broadcasting below.
with com.temp_setattr(self, "as_index", True):
# GH#49834 - result needs groups in the index for
# _wrap_transform_fast_result
if func in ["idxmin", "idxmax"]:
func = cast(Literal["idxmin", "idxmax"], func)
result = self._idxmax_idxmin(func, True, *args, **kwargs)
else:
if engine is not None:
kwargs["engine"] = engine
kwargs["engine_kwargs"] = engine_kwargs
result = getattr(self, func)(*args, **kwargs)
grouper, exclusions, obj = get_grouper(
self.orig_obj,
self.keys,
level=self.level,
sort=self.sort,
observed=True,
dropna=self.dropna,
)
exclusions = frozenset(exclusions) if exclusions else frozenset()
obj_has_not_changed = self.orig_obj.equals(self.obj)

with (
com.temp_setattr(self, "observed", True),
com.temp_setattr(self, "_grouper", grouper),
com.temp_setattr(self, "exclusions", exclusions),
com.temp_setattr(self, "obj", obj, condition=obj_has_not_changed),
):
return self._reduction_kernel_transform(
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
)

# with com.temp_setattr(self, "as_index", True):
# # GH#49834 - result needs groups in the index for
# # _wrap_transform_fast_result
# if func in ["idxmin", "idxmax"]:
# func = cast(Literal["idxmin", "idxmax"], func)
# result = self._idxmax_idxmin(func, True, *args, **kwargs)
# else:
# if engine is not None:
# kwargs["engine"] = engine
# kwargs["engine_kwargs"] = engine_kwargs
# result = getattr(self, func)(*args, **kwargs)

# print("result with observed = False\n", result.to_string())
# r = self._wrap_transform_fast_result(result)
# print("reindexed result", r.to_string())
# return r

@final
def _reduction_kernel_transform(
    self, func, *args, engine=None, engine_kwargs=None, **kwargs
):
    """
    Run a reduction kernel by name and hand its result to the fast-transform
    wrapper.

    Parameters
    ----------
    func : str
        Name of the reduction to run. ``"idxmin"`` and ``"idxmax"`` are routed
        through ``self._idxmax_idxmin``; any other name is looked up as a
        method on ``self``.
    *args
        Positional arguments forwarded to the reduction.
    engine : optional
        When not None, forwarded to the reduction as the ``engine`` keyword
        (not forwarded on the ``idxmin``/``idxmax`` path).
    engine_kwargs : optional
        Forwarded as the ``engine_kwargs`` keyword, only when ``engine`` is
        not None.
    **kwargs
        Keyword arguments forwarded to the reduction.

    Returns
    -------
    Whatever ``self._wrap_transform_fast_result`` returns for the reduction
    result.
    """
    # GH#30918 Use _transform_fast only when we know func is an aggregation
    # If func is a reduction, we need to broadcast the
    # result to the whole group. Compute func result
    # and deal with possible broadcasting below.
    with com.temp_setattr(self, "as_index", True):
        # GH#49834 - result needs groups in the index for
        # _wrap_transform_fast_result
        if func in ["idxmin", "idxmax"]:
            func = cast(Literal["idxmin", "idxmax"], func)
            result = self._idxmax_idxmin(func, True, *args, **kwargs)
        else:
            if engine is not None:
                # Only pass the engine keywords when an engine was requested,
                # so reductions that do not accept them keep working.
                kwargs["engine"] = engine
                kwargs["engine_kwargs"] = engine_kwargs
            result = getattr(self, func)(*args, **kwargs)

    # as_index has been restored by the time the result is wrapped.
    return self._wrap_transform_fast_result(result)

@final
def _wrap_transform_fast_result(self, result: NDFrameT) -> NDFrameT:
Expand Down

0 comments on commit a52e7fe

Please sign in to comment.