From b3520d1d6fc9b4472d2e0b6ea337a1c04dc17bfc Mon Sep 17 00:00:00 2001 From: Kei Date: Tue, 16 Apr 2024 17:54:09 +0800 Subject: [PATCH] Account for categorical dtype --- pandas/core/groupby/groupby.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 3b9b09070495bf..6c5e02288322a9 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -84,6 +84,7 @@ class providing the base-class of operations. needs_i8_conversion, pandas_dtype, ) +from pandas.core.dtypes.dtypes import CategoricalDtype from pandas.core.dtypes.missing import ( isna, na_value_for_dtype, @@ -2009,7 +2010,7 @@ def _convert_result_dtype( converted_result_values = np.empty(out_shape, dtype=out_dtype) if func not in cy_op.cast_blocklist: - res_dtype = cy_op._get_result_dtype(timezone_free_orig_input_values.dtype) + res_dtype = cy_op._get_result_dtype(input_values.dtype) converted_result_values = maybe_downcast_to_dtype( converted_result_values, res_dtype ) @@ -2052,9 +2053,11 @@ def _preprocess_input_values(self, func, input_values: ArrayLike) -> ArrayLike: input_values = input_values.view("int64") elif dtype.kind == "b": input_values = input_values.view("uint8") - - if input_values.dtype == "float16": + elif input_values.dtype == "float16": input_values = input_values.astype(np.float32) + elif isinstance(dtype, CategoricalDtype): + input_values = input_values[0].astype(bool) + input_values = input_values[None, :] if func in ["any", "all"]: input_values = input_values.astype(bool, copy=False).view(np.int8)