Skip to content

Commit

Permalink
Account for categorical dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
Kei committed Apr 16, 2024
1 parent a3be335 commit b3520d1
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class providing the base-class of operations.
needs_i8_conversion,
pandas_dtype,
)
from pandas.core.dtypes.dtypes import CategoricalDtype
from pandas.core.dtypes.missing import (
isna,
na_value_for_dtype,
Expand Down Expand Up @@ -2009,7 +2010,7 @@ def _convert_result_dtype(

converted_result_values = np.empty(out_shape, dtype=out_dtype)
if func not in cy_op.cast_blocklist:
res_dtype = cy_op._get_result_dtype(timezone_free_orig_input_values.dtype)
res_dtype = cy_op._get_result_dtype(input_values.dtype)
converted_result_values = maybe_downcast_to_dtype(
converted_result_values, res_dtype
)
Expand Down Expand Up @@ -2052,9 +2053,11 @@ def _preprocess_input_values(self, func, input_values: ArrayLike) -> ArrayLike:
input_values = input_values.view("int64")
elif dtype.kind == "b":
input_values = input_values.view("uint8")

if input_values.dtype == "float16":
elif input_values.dtype == "float16":
input_values = input_values.astype(np.float32)
elif isinstance(dtype, CategoricalDtype):
input_values = input_values[0].astype(bool)
input_values = input_values[None, :]

if func in ["any", "all"]:
input_values = input_values.astype(bool, copy=False).view(np.int8)
Expand Down

0 comments on commit b3520d1

Please sign in to comment.