From a3be335d50d761db825a40d4aa6f1c27827dbf00 Mon Sep 17 00:00:00 2001 From: Kei Date: Mon, 15 Apr 2024 22:58:00 +0800 Subject: [PATCH] Update impl to work with transform result ONLY --- pandas/core/groupby/groupby.py | 217 ++++++++++++++++++++++++++++++--- 1 file changed, 203 insertions(+), 14 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 93c907501c9a37..3b9b09070495bf 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -69,6 +69,7 @@ class providing the base-class of operations. from pandas.core.dtypes.cast import ( coerce_indexer_dtype, ensure_dtype_can_hold_na, + maybe_downcast_to_dtype, ) from pandas.core.dtypes.common import ( is_bool_dtype, @@ -102,6 +103,7 @@ class providing the base-class of operations. IntegerArray, SparseArray, ) +from pandas.core.arrays.datetimes import DatetimeArray from pandas.core.arrays.string_ import StringDtype from pandas.core.arrays.string_arrow import ( ArrowStringArray, @@ -124,6 +126,7 @@ class providing the base-class of operations. GroupByIndexingMixin, GroupByNthSelector, ) +from pandas.core.groupby.ops import WrappedCythonOp from pandas.core.indexes.api import ( Index, MultiIndex, @@ -1945,6 +1948,139 @@ def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): return self._wrap_transform_fast_result(result) + @final + def _convert_to_out_dtype_ignore_unobserved( + self, func, result, numeric_only: bool = False + ): + cy_op = WrappedCythonOp( + kind="aggregate", how=func, has_dropped_na=self._grouper.has_dropped_na + ) + result_generator = self._create_result_generator(result) + + def convert_result_dtype(input_values: ArrayLike) -> ArrayLike: + if input_values.ndim == 1: + # expand to 2d, dispatch, then squeeze if appropriate + input_values_2d = input_values[None, :] + converted_result_values = self._convert_result_dtype( + func, input_values_2d, cy_op, result_generator + ) + + if converted_result_values.shape[0] == 1: + return converted_result_values[0] + + return self._convert_result_dtype( + func, input_values, cy_op, result_generator + ) + + input_mgr = self._get_data_to_aggregate(numeric_only=numeric_only, name=func) + new_mgr = input_mgr.grouped_reduce(convert_result_dtype) + res = self._wrap_agged_manager(new_mgr) + out = self._wrap_aggregated_output(res) + return out + + @final + def _create_result_generator(self, result): + for block in result._mgr.blocks: + values = block.values + if values.ndim == 1: + values = values[None, :] + + yield from values + + @final + def _convert_result_dtype( + self, func, input_values: ArrayLike, cy_op, result_generator + ): + orig_input_values = input_values + timezone_free_orig_input_values = self._remove_timezone(orig_input_values) + input_values = self._preprocess_input_values( + func, timezone_free_orig_input_values + ) + input_values = cy_op._get_cython_vals(input_values) + + transposed_input_values = input_values.T + transposed_out_shape = cy_op._get_output_shape( + self._grouper.ngroups, transposed_input_values + ) + out_shape = tuple(np.flip(transposed_out_shape)) + out_dtype = self._get_final_out_dtype( + cy_op, func, timezone_free_orig_input_values, input_values + ) + + converted_result_values = np.empty(out_shape, dtype=out_dtype) + if func not in cy_op.cast_blocklist: + res_dtype = cy_op._get_result_dtype(timezone_free_orig_input_values.dtype) + converted_result_values = maybe_downcast_to_dtype( + converted_result_values, res_dtype + ) + out_dtype = converted_result_values.dtype + + for i in range(out_shape[0]): + result_col = next(result_generator) + + if result_col.dtype == "float64" and out_dtype != "float64": + placeholder = np.empty( + [ + 1, + ], + dtype=out_dtype, + ) + result_col[isna(result_col)] = placeholder[0] + converted_result_values[i] = result_col.astype(out_dtype) + else: + converted_result_values[i] = result_col + + if isinstance(orig_input_values, DatetimeArray): + converted_result_values = orig_input_values._from_backing_data( + converted_result_values + ) + return converted_result_values + + @final + def _remove_timezone(self, orig_input_values: ArrayLike) -> ArrayLike: + if isinstance(orig_input_values, DatetimeArray): + orig_input_values = orig_input_values._ndarray.view("M8[ns]") + + return orig_input_values + + @final + def _preprocess_input_values(self, func, input_values: ArrayLike) -> ArrayLike: + dtype = input_values.dtype + is_datetimelike = dtype.kind in "mM" + + if is_datetimelike: + input_values = input_values.view("int64") + elif dtype.kind == "b": + input_values = input_values.view("uint8") + + if input_values.dtype == "float16": + input_values = input_values.astype(np.float32) + + if func in ["any", "all"]: + input_values = input_values.astype(bool, copy=False).view(np.int8) + + return input_values + + @final + def _get_final_out_dtype( + self, + cy_op, + func, + timezone_free_orig_input_values: ArrayLike, + input_values: ArrayLike, + ): + out_dtype = cy_op._get_out_dtype(timezone_free_orig_input_values.dtype) + if func in ["idxmin", "idxmax"]: + index = self.obj.index + out_dtype = index.dtype + + elif func in ["any", "all"]: + out_dtype = bool + elif timezone_free_orig_input_values.dtype == object and func in ["skew"]: + out_dtype = object + + return out_dtype + @final def _wrap_transform_fast_result(self, result: NDFrameT) -> NDFrameT: """ @@ -2090,12 +2226,16 @@ def any(self, skipna: bool = True) -> NDFrameT: 1 False True 7 True True """ - return self._cython_agg_general( + result = self._cython_agg_general( "any", alt=lambda x: Series(x, copy=False).any(skipna=skipna), skipna=skipna, ) + return self._convert_to_out_dtype_ignore_unobserved( + "any", result, numeric_only=False + ) + @final @Substitution(name="groupby") @Substitution(see_also=_common_see_also) @@ -2148,12 +2288,16 @@ def all(self, skipna: bool = True) -> NDFrameT: 1 False True 7 True True """ - return self._cython_agg_general( + result = self._cython_agg_general( "all", alt=lambda x: Series(x, copy=False).all(skipna=skipna), skipna=skipna, ) + return self._convert_to_out_dtype_ignore_unobserved( + "all", result, numeric_only=False + ) + @final @Substitution(name="groupby") @Substitution(see_also=_common_see_also) @@ -2347,7 +2491,11 @@ def mean( alt=lambda x: Series(x, copy=False).mean(numeric_only=numeric_only), numeric_only=numeric_only, ) - return result.__finalize__(self.obj, method="groupby") + result = result.__finalize__(self.obj, method="groupby") + + return self._convert_to_out_dtype_ignore_unobserved( + "mean", result, numeric_only=numeric_only + ) @final def median(self, numeric_only: bool = False) -> NDFrameT: @@ -2434,7 +2582,11 @@ def median(self, numeric_only: bool = False) -> NDFrameT: alt=lambda x: Series(x, copy=False).median(numeric_only=numeric_only), numeric_only=numeric_only, ) - return result.__finalize__(self.obj, method="groupby") + result = result.__finalize__(self.obj, method="groupby") + + return self._convert_to_out_dtype_ignore_unobserved( + "median", result, numeric_only=numeric_only + ) @final @Substitution(name="groupby") @@ -2539,13 +2691,17 @@ def std( ) ) else: - return self._cython_agg_general( + result = self._cython_agg_general( "std", alt=lambda x: Series(x, copy=False).std(ddof=ddof), numeric_only=numeric_only, ddof=ddof, ) + return self._convert_to_out_dtype_ignore_unobserved( + "std", result, numeric_only=numeric_only + ) + @final @Substitution(name="groupby") @Substitution(see_also=_common_see_also) @@ -2647,13 +2803,17 @@ def var( ddof=ddof, ) else: - return self._cython_agg_general( + result = self._cython_agg_general( "var", alt=lambda x: Series(x, copy=False).var(ddof=ddof), numeric_only=numeric_only, ddof=ddof, ) + return self._convert_to_out_dtype_ignore_unobserved( + "var", result, numeric_only=numeric_only + ) + @final def _value_counts( self, @@ -2868,13 +3028,17 @@ def sem(self, ddof: int = 1, numeric_only: bool = False) -> NDFrameT: f"{type(self).__name__}.sem called with " f"numeric_only={numeric_only} and dtype {self.obj.dtype}" ) - return self._cython_agg_general( + result = self._cython_agg_general( "sem", alt=lambda x: Series(x, copy=False).sem(ddof=ddof), numeric_only=numeric_only, ddof=ddof, ) + return self._convert_to_out_dtype_ignore_unobserved( + "sem", result, numeric_only=numeric_only + ) + @final @Substitution(name="groupby") @Substitution(see_also=_common_see_also) @@ -3042,7 +3206,9 @@ def sum( npfunc=np.sum, ) - return result + return self._convert_to_out_dtype_ignore_unobserved( + "sum", result, numeric_only=numeric_only + ) @final @doc( @@ -3086,10 +3252,14 @@ def sum( ), ) def prod(self, numeric_only: bool = False, min_count: int = 0) -> NDFrameT: - return self._agg_general( + result = self._agg_general( numeric_only=numeric_only, min_count=min_count, alias="prod", npfunc=np.prod ) + return self._convert_to_out_dtype_ignore_unobserved( + "prod", result, numeric_only=numeric_only + ) + @final @doc( _groupby_agg_method_engine_remove_nan_template, @@ -3152,7 +3322,7 @@ def min( is_max=False, ) else: - return self._agg_general( + result = self._agg_general( numeric_only=numeric_only, min_count=min_count, alias="min", @@ -3160,6 +3330,10 @@ def min( **kwargs, ) + return self._convert_to_out_dtype_ignore_unobserved( + "min", result, numeric_only=numeric_only + ) + @final @doc( _groupby_agg_method_engine_template, @@ -3221,13 +3395,17 @@ def max( is_max=True, ) else: - return self._agg_general( + result = self._agg_general( numeric_only=numeric_only, min_count=min_count, alias="max", npfunc=np.max, ) + return self._convert_to_out_dtype_ignore_unobserved( + "max", result, numeric_only=numeric_only + ) + @final def first( self, numeric_only: bool = False, min_count: int = -1, skipna: bool = True @@ -3306,7 +3484,7 @@ def first(x: Series): else: # pragma: no cover raise TypeError(type(obj)) - return self._agg_general( + result = self._agg_general( numeric_only=numeric_only, min_count=min_count, alias="first", @@ -3314,6 +3492,10 @@ def first(x: Series): skipna=skipna, ) + return self._convert_to_out_dtype_ignore_unobserved( + "first", result, numeric_only=numeric_only + ) + @final def last( self, numeric_only: bool = False, min_count: int = -1, skipna: bool = True @@ -3375,7 +3557,7 @@ def last(x: Series): else: # pragma: no cover raise TypeError(type(obj)) - return self._agg_general( + result = self._agg_general( numeric_only=numeric_only, min_count=min_count, alias="last", @@ -3383,6 +3565,10 @@ def last(x: Series): skipna=skipna, ) + return self._convert_to_out_dtype_ignore_unobserved( + "last", result, numeric_only=numeric_only + ) + @final def ohlc(self) -> DataFrame: """ @@ -5575,7 +5761,10 @@ def _idxmax_idxmin( alias=how, skipna=skipna, ) - return result + + return self._convert_to_out_dtype_ignore_unobserved( + how, result, numeric_only=numeric_only + ) def _wrap_idxmax_idxmin(self, res: NDFrameT) -> NDFrameT: index = self.obj.index