diff --git a/asv_bench/benchmarks/groupby.py b/asv_bench/benchmarks/groupby.py
index 4c0f3ddd826b7..6617b3c8b4cca 100644
--- a/asv_bench/benchmarks/groupby.py
+++ b/asv_bench/benchmarks/groupby.py
@@ -57,6 +57,38 @@
     },
 }
 
+# These aggregations don't have a kernel implemented for them yet
+_numba_unsupported_methods = [
+    "all",
+    "any",
+    "bfill",
+    "count",
+    "cumcount",
+    "cummax",
+    "cummin",
+    "cumprod",
+    "cumsum",
+    "describe",
+    "diff",
+    "ffill",
+    "first",
+    "head",
+    "last",
+    "median",
+    "nunique",
+    "pct_change",
+    "prod",
+    "quantile",
+    "rank",
+    "sem",
+    "shift",
+    "size",
+    "skew",
+    "tail",
+    "unique",
+    "value_counts",
+]
+
 
 class ApplyDictReturn:
     def setup(self):
@@ -453,9 +485,10 @@ class GroupByMethods:
         ],
         ["direct", "transformation"],
         [1, 5],
+        ["cython", "numba"],
     ]
 
-    def setup(self, dtype, method, application, ncols):
+    def setup(self, dtype, method, application, ncols, engine):
         if method in method_blocklist.get(dtype, {}):
             raise NotImplementedError  # skip benchmark
 
@@ -474,6 +507,19 @@ def setup(self, dtype, method, application, ncols):
             # DataFrameGroupBy doesn't have these methods
             raise NotImplementedError
 
+        # Numba currently doesn't support multiple transform functions
+        # or strings for transform, grouping on multiple columns,
+        # grouping on datetimes, and we lack kernels for many methods
+        if engine == "numba" and (
+            method in _numba_unsupported_methods
+            or ncols > 1
+            or application == "transformation"
+            or dtype == "datetime"
+        ):
+            raise NotImplementedError
+
         if method == "describe":
             ngroups = 20
         elif method == "skew":
@@ -505,17 +551,30 @@ def setup(self, dtype, method, application, ncols):
         if len(cols) == 1:
             cols = cols[0]
 
+        # Not everything supports the engine keyword yet
+        kwargs = {}
+        if engine == "numba":
+            kwargs["engine"] = engine
+
         if application == "transformation":
-            self.as_group_method = lambda: df.groupby("key")[cols].transform(method)
-            self.as_field_method = lambda: df.groupby(cols)["key"].transform(method)
+            self.as_group_method = lambda: df.groupby("key")[cols].transform(
+                method, **kwargs
+            )
+            self.as_field_method = lambda: df.groupby(cols)["key"].transform(
+                method, **kwargs
+            )
         else:
-            self.as_group_method = getattr(df.groupby("key")[cols], method)
-            self.as_field_method = getattr(df.groupby(cols)["key"], method)
+            self.as_group_method = partial(
+                getattr(df.groupby("key")[cols], method), **kwargs
+            )
+            self.as_field_method = partial(
+                getattr(df.groupby(cols)["key"], method), **kwargs
+            )
 
-    def time_dtype_as_group(self, dtype, method, application, ncols):
+    def time_dtype_as_group(self, dtype, method, application, ncols, engine):
         self.as_group_method()
 
-    def time_dtype_as_field(self, dtype, method, application, ncols):
+    def time_dtype_as_field(self, dtype, method, application, ncols, engine):
         self.as_field_method()
 
@@ -532,8 +591,12 @@ class GroupByCythonAgg:
         [
             "sum",
             "prod",
-            "min",
-            "max",
+            # TODO: uncomment min/max
+            # Currently, min/max are implemented very inefficiently
+            # because they re-use the Window min/max kernels,
+            # so they would time out the ASVs
+            # "min",
+            # "max",
             "mean",
             "median",
             "var",
@@ -554,6 +617,22 @@ def time_frame_agg(self, dtype, method):
         self.df.groupby("key").agg(method)
 
 
+class GroupByNumbaAgg(GroupByCythonAgg):
+    """
+    Benchmarks specifically targeting our numba aggregation algorithms
+    (using a big enough dataframe with a simple key, so that a large part
+    of the time is actually spent in the grouped aggregation).
+    """
+
+    def setup(self, dtype, method):
+        if method in _numba_unsupported_methods:
+            raise NotImplementedError
+        super().setup(dtype, method)
+
+    def time_frame_agg(self, dtype, method):
+        self.df.groupby("key").agg(method, engine="numba")
+
+
 class GroupByCythonAggEaDtypes:
     """
     Benchmarks specifically targeting our cython aggregation algorithms
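The new engine axis multiplies the GroupByMethods parameter grid, which is why the numba-specific combinations are skipped up front in setup. A minimal sketch of what GroupByNumbaAgg.time_frame_agg ends up timing (the frame shape here is illustrative rather than the exact setup inherited from GroupByCythonAgg, and numba must be installed):

import numpy as np
import pandas as pd

N = 1_000_000
df = pd.DataFrame(np.random.randn(N, 10))
df["key"] = np.random.randint(0, 100, size=N)

df.groupby("key").agg("sum", engine="numba")  # first call pays the JIT compile
df.groupby("key").agg("sum", engine="numba")  # the steady-state cost ASV measures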
diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst
index 19e314cbf5ed8..137be168985d2 100644
--- a/doc/source/whatsnew/v2.1.0.rst
+++ b/doc/source/whatsnew/v2.1.0.rst
@@ -108,6 +108,7 @@ Other enhancements
 - :meth:`SeriesGroupby.transform` and :meth:`DataFrameGroupby.transform` now support passing in a string as the function for ``engine="numba"`` (:issue:`53579`)
 - Added ``engine_kwargs`` parameter to :meth:`DataFrame.to_excel` (:issue:`53220`)
 - Added a new parameter ``by_row`` to :meth:`Series.apply`. When set to ``False`` the supplied callables will always operate on the whole Series (:issue:`53400`).
+- Groupby aggregations with ``engine="numba"`` (such as :meth:`DataFrameGroupBy.sum`) can now preserve the dtype of the input instead of casting to ``float64`` (:issue:`44952`)
 - Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to lzma.LZMAFile (:issue:`52979`)
 - Performance improvement in :func:`concat` with homogeneous ``np.float64`` or ``np.float32`` dtypes (:issue:`52685`)
 - Performance improvement in :meth:`DataFrame.filter` when ``items`` is given (:issue:`52941`)
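Concretely, the new whatsnew entry describes behavior like the following (a sketch assuming numba is installed; the column names are illustrative):

import pandas as pd

df = pd.DataFrame({"key": [1, 1, 2], "a": [1, 2, 3]})  # "a" is int64

res = df.groupby("key")["a"].sum(engine="numba")
print(res.dtype)  # int64; before this change the numba engine always returned float64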
diff --git a/pandas/core/_numba/executor.py b/pandas/core/_numba/executor.py
index b5a611560bde7..24599148356fa 100644
--- a/pandas/core/_numba/executor.py
+++ b/pandas/core/_numba/executor.py
@@ -3,6 +3,7 @@
 import functools
 from typing import (
     TYPE_CHECKING,
+    Any,
     Callable,
 )
 
@@ -15,8 +16,86 @@
 
 
 @functools.cache
+def make_looper(func, result_dtype, nopython, nogil, parallel):
+    if TYPE_CHECKING:
+        import numba
+    else:
+        numba = import_optional_dependency("numba")
+
+    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
+    def column_looper(
+        values: np.ndarray,
+        start: np.ndarray,
+        end: np.ndarray,
+        min_periods: int,
+        *args,
+    ):
+        result = np.empty((values.shape[0], len(start)), dtype=result_dtype)
+        na_positions = {}
+        for i in numba.prange(values.shape[0]):
+            output, na_pos = func(
+                values[i], result_dtype, start, end, min_periods, *args
+            )
+            result[i] = output
+            if len(na_pos) > 0:
+                na_positions[i] = np.array(na_pos)
+        return result, na_positions
+
+    return column_looper
+
+
+default_dtype_mapping: dict[np.dtype, Any] = {
+    np.dtype("int8"): np.int64,
+    np.dtype("int16"): np.int64,
+    np.dtype("int32"): np.int64,
+    np.dtype("int64"): np.int64,
+    np.dtype("uint8"): np.uint64,
+    np.dtype("uint16"): np.uint64,
+    np.dtype("uint32"): np.uint64,
+    np.dtype("uint64"): np.uint64,
+    np.dtype("float32"): np.float64,
+    np.dtype("float64"): np.float64,
+    np.dtype("complex64"): np.complex128,
+    np.dtype("complex128"): np.complex128,
+}
+
+
+# TODO: Preserve complex dtypes
+
+float_dtype_mapping: dict[np.dtype, Any] = {
+    np.dtype("int8"): np.float64,
+    np.dtype("int16"): np.float64,
+    np.dtype("int32"): np.float64,
+    np.dtype("int64"): np.float64,
+    np.dtype("uint8"): np.float64,
+    np.dtype("uint16"): np.float64,
+    np.dtype("uint32"): np.float64,
+    np.dtype("uint64"): np.float64,
+    np.dtype("float32"): np.float64,
+    np.dtype("float64"): np.float64,
+    np.dtype("complex64"): np.float64,
+    np.dtype("complex128"): np.float64,
+}
+
+identity_dtype_mapping: dict[np.dtype, Any] = {
+    np.dtype("int8"): np.int8,
+    np.dtype("int16"): np.int16,
+    np.dtype("int32"): np.int32,
+    np.dtype("int64"): np.int64,
+    np.dtype("uint8"): np.uint8,
+    np.dtype("uint16"): np.uint16,
+    np.dtype("uint32"): np.uint32,
+    np.dtype("uint64"): np.uint64,
+    np.dtype("float32"): np.float32,
+    np.dtype("float64"): np.float64,
+    np.dtype("complex64"): np.complex64,
+    np.dtype("complex128"): np.complex128,
+}
+
+
 def generate_shared_aggregator(
     func: Callable[..., Scalar],
+    dtype_mapping: dict[np.dtype, np.dtype],
     nopython: bool,
     nogil: bool,
     parallel: bool,
@@ -29,6 +108,9 @@ def generate_shared_aggregator(
     ----------
     func : function
         aggregation function to be applied to each column
+    dtype_mapping : dict
+        Mapping from the dtype of the input values to the dtype
+        used for the result.
     nopython : bool
         nopython to be passed into numba.jit
     nogil : bool
@@ -40,22 +122,35 @@ def generate_shared_aggregator(
     -------
     Numba function
     """
-    if TYPE_CHECKING:
-        import numba
-    else:
-        numba = import_optional_dependency("numba")
 
-    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
-    def column_looper(
-        values: np.ndarray,
-        start: np.ndarray,
-        end: np.ndarray,
-        min_periods: int,
-        *args,
-    ):
-        result = np.empty((len(start), values.shape[1]), dtype=np.float64)
-        for i in numba.prange(values.shape[1]):
-            result[:, i] = func(values[:, i], start, end, min_periods, *args)
+    # A wrapper around the looper function that dispatches on dtype,
+    # since numba cannot do that in nopython mode.
+
+    # It also post-processes the result by inserting nans where the number of
+    # observations is less than min_periods; this cannot be done inside numba
+    # nopython mode (casting int -> float there hits a type-unification error).
+    def looper_wrapper(values, start, end, min_periods, **kwargs):
+        result_dtype = dtype_mapping[values.dtype]
+        column_looper = make_looper(func, result_dtype, nopython, nogil, parallel)
+        # Need to unpack kwargs since numba only supports *args
+        result, na_positions = column_looper(
+            values, start, end, min_periods, *kwargs.values()
+        )
+        if result.dtype.kind == "i":
+            # If any na_positions were recorded, convert the whole block to
+            # float64. This is OK since int dtype cannot hold nan, and since
+            # all columns share the same windows, min_periods failing for one
+            # column at an index means it fails for every column at that index.
+            for na_pos in na_positions.values():
+                if len(na_pos) > 0:
+                    result = result.astype("float64")
+                    break
+            # TODO: Optimize this
+            for i, na_pos in na_positions.items():
+                if len(na_pos) > 0:
+                    result[i, na_pos] = np.nan
         return result
 
-    return column_looper
+    return looper_wrapper
diff --git a/pandas/core/_numba/kernels/mean_.py b/pandas/core/_numba/kernels/mean_.py
index 725989e093441..8774ff72af852 100644
--- a/pandas/core/_numba/kernels/mean_.py
+++ b/pandas/core/_numba/kernels/mean_.py
@@ -60,10 +60,11 @@ def remove_mean(
 @numba.jit(nopython=True, nogil=True, parallel=False)
 def sliding_mean(
     values: np.ndarray,
+    result_dtype: np.dtype,
     start: np.ndarray,
     end: np.ndarray,
     min_periods: int,
-) -> np.ndarray:
+) -> tuple[np.ndarray, list[int]]:
     N = len(start)
     nobs = 0
     sum_x = 0.0
@@ -75,7 +76,7 @@ def sliding_mean(
         start
     ) and is_monotonic_increasing(end)
 
-    output = np.empty(N, dtype=np.float64)
+    output = np.empty(N, dtype=result_dtype)
 
     for i in range(N):
         s = start[i]
@@ -147,4 +148,8 @@ def sliding_mean(
             neg_ct = 0
             compensation_remove = 0.0
 
-    return output
+    # na_pos is an empty list since float64 can already hold nans
+    # Use a list comprehension, since numba cannot infer that an empty list
+    # is a list of ints on its own
+    na_pos = [0 for i in range(0)]
+    return output, na_pos
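The executor now resolves the result dtype from the input dtype before compiling, and make_looper is functools.cache'd, so each (kernel, result dtype, jit flags) combination is compiled once and then reused. A quick illustration of the three mapping tables (this imports from an internal module, so the path is subject to change):

import numpy as np
from pandas.core._numba.executor import (
    default_dtype_mapping,
    float_dtype_mapping,
    identity_dtype_mapping,
)

# sum-like kernels widen to 64 bits but keep the kind (ints stay ints)
assert default_dtype_mapping[np.dtype("int32")] is np.int64
# mean/var/std always compute in and return float64
assert float_dtype_mapping[np.dtype("int32")] is np.float64
# min/max return exactly the input dtype
assert identity_dtype_mapping[np.dtype("int32")] is np.int32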
diff --git a/pandas/core/_numba/kernels/min_max_.py b/pandas/core/_numba/kernels/min_max_.py
index acba66a6e4f63..814deeee9d0d5 100644
--- a/pandas/core/_numba/kernels/min_max_.py
+++ b/pandas/core/_numba/kernels/min_max_.py
@@ -15,14 +15,16 @@
 @numba.jit(nopython=True, nogil=True, parallel=False)
 def sliding_min_max(
     values: np.ndarray,
+    result_dtype: np.dtype,
     start: np.ndarray,
     end: np.ndarray,
     min_periods: int,
     is_max: bool,
-) -> np.ndarray:
+) -> tuple[np.ndarray, list[int]]:
     N = len(start)
     nobs = 0
-    output = np.empty(N, dtype=np.float64)
+    output = np.empty(N, dtype=result_dtype)
+    na_pos = []
     # Use deque once numba supports it
     # https://github.com/numba/numba/issues/7417
     Q: list = []
@@ -64,6 +66,9 @@ def sliding_min_max(
         if Q and curr_win_size > 0 and nobs >= min_periods:
             output[i] = values[Q[0]]
         else:
-            output[i] = np.nan
+            if values.dtype.kind != "i":
+                output[i] = np.nan
+            else:
+                na_pos.append(i)
 
-    return output
+    return output, na_pos
diff --git a/pandas/core/_numba/kernels/sum_.py b/pandas/core/_numba/kernels/sum_.py
index 056897189fe67..e834f1410f51a 100644
--- a/pandas/core/_numba/kernels/sum_.py
+++ b/pandas/core/_numba/kernels/sum_.py
@@ -8,6 +8,8 @@
 """
 from __future__ import annotations
 
+from typing import Any
+
 import numba
 import numpy as np
 
@@ -16,13 +18,13 @@
 @numba.jit(nopython=True, nogil=True, parallel=False)
 def add_sum(
-    val: float,
+    val: Any,
     nobs: int,
-    sum_x: float,
-    compensation: float,
+    sum_x: Any,
+    compensation: Any,
     num_consecutive_same_value: int,
-    prev_value: float,
-) -> tuple[int, float, float, int, float]:
+    prev_value: Any,
+) -> tuple[int, Any, Any, int, Any]:
     if not np.isnan(val):
         nobs += 1
         y = val - compensation
@@ -41,8 +43,8 @@ def add_sum(
 
 @numba.jit(nopython=True, nogil=True, parallel=False)
 def remove_sum(
-    val: float, nobs: int, sum_x: float, compensation: float
-) -> tuple[int, float, float]:
+    val: Any, nobs: int, sum_x: Any, compensation: Any
+) -> tuple[int, Any, Any]:
     if not np.isnan(val):
         nobs -= 1
         y = -val - compensation
@@ -55,21 +57,29 @@ def remove_sum(
 @numba.jit(nopython=True, nogil=True, parallel=False)
 def sliding_sum(
     values: np.ndarray,
+    result_dtype: np.dtype,
     start: np.ndarray,
     end: np.ndarray,
     min_periods: int,
-) -> np.ndarray:
+) -> tuple[np.ndarray, list[int]]:
+    dtype = values.dtype
+
+    na_val: object = np.nan
+    if dtype.kind == "i":
+        na_val = 0
+
     N = len(start)
     nobs = 0
-    sum_x = 0.0
-    compensation_add = 0.0
-    compensation_remove = 0.0
+    sum_x = 0
+    compensation_add = 0
+    compensation_remove = 0
+    na_pos = []
 
     is_monotonic_increasing_bounds = is_monotonic_increasing(
         start
     ) and is_monotonic_increasing(end)
 
-    output = np.empty(N, dtype=np.float64)
+    output = np.empty(N, dtype=result_dtype)
 
     for i in range(N):
         s = start[i]
@@ -119,20 +129,22 @@ def sliding_sum(
         )
 
         if nobs == 0 == min_periods:
-            result = 0.0
+            result: object = 0
         elif nobs >= min_periods:
             if num_consecutive_same_value >= nobs:
                 result = prev_value * nobs
             else:
                 result = sum_x
         else:
-            result = np.nan
+            result = na_val
+            if dtype.kind == "i":
+                na_pos.append(i)
 
         output[i] = result
 
         if not is_monotonic_increasing_bounds:
             nobs = 0
-            sum_x = 0.0
-            compensation_remove = 0.0
+            sum_x = 0
+            compensation_remove = 0
 
-    return output
+    return output, na_pos
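add_sum and remove_sum keep a Kahan-style compensation term even now that the accumulator can start as an int: without it, small addends are silently dropped once the running float sum is large. A self-contained demonstration of the compensated update used above (exact under IEEE-754 double precision):

vals = [float(2**53)] + [1.0] * 4

naive = 0.0
for v in vals:
    naive += v

sum_x = compensation = 0.0
for v in vals:
    y = v - compensation  # re-inject the error left over from the previous add
    t = sum_x + y
    compensation = t - sum_x - y  # the low-order bits that t just rounded away
    sum_x = t

print(naive)  # 9007199254740992.0, the four 1.0s were absorbed
print(sum_x)  # 9007199254740996.0, exact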
diff --git a/pandas/core/_numba/kernels/var_.py b/pandas/core/_numba/kernels/var_.py
index d3243f4928dca..e0f46ba6e3805 100644
--- a/pandas/core/_numba/kernels/var_.py
+++ b/pandas/core/_numba/kernels/var_.py
@@ -68,11 +68,12 @@ def remove_var(
 @numba.jit(nopython=True, nogil=True, parallel=False)
 def sliding_var(
     values: np.ndarray,
+    result_dtype: np.dtype,
     start: np.ndarray,
     end: np.ndarray,
     min_periods: int,
     ddof: int = 1,
-) -> np.ndarray:
+) -> tuple[np.ndarray, list[int]]:
     N = len(start)
     nobs = 0
     mean_x = 0.0
@@ -85,7 +86,7 @@ def sliding_var(
         start
     ) and is_monotonic_increasing(end)
 
-    output = np.empty(N, dtype=np.float64)
+    output = np.empty(N, dtype=result_dtype)
 
     for i in range(N):
         s = start[i]
@@ -154,4 +155,8 @@ def sliding_var(
             ssqdm_x = 0.0
             compensation_remove = 0.0
 
-    return output
+    # na_pos is an empty list since float64 can already hold nans
+    # Use a list comprehension, since numba cannot infer that an empty list
+    # is a list of ints on its own
+    na_pos = [0 for i in range(0)]
+    return output, na_pos
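All kernels now share one calling convention: they receive the target result_dtype and return (output, na_pos), where na_pos lists the window positions whose result is missing, so that looper_wrapper can upcast integer blocks and write the NaNs afterwards. A toy, non-jitted kernel following the same contract (hypothetical and for illustration only; the real kernels above are numba-compiled and use compensated sums and monotonic-bounds shortcuts):

import numpy as np

def sliding_count(values, result_dtype, start, end, min_periods):
    N = len(start)
    output = np.empty(N, dtype=result_dtype)
    na_pos = []
    for i in range(N):
        window = values[start[i] : end[i]]
        nobs = int(np.sum(~np.isnan(window)))
        if nobs >= min_periods:
            output[i] = nobs
        else:
            # an integer output cannot hold NaN: record the position and let
            # the wrapper cast the block to float64 and insert NaN afterwards
            output[i] = 0
            na_pos.append(i)
    return output, na_pos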
diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py
index e447377db9e55..5b6d28ac9ab4a 100644
--- a/pandas/core/groupby/groupby.py
+++ b/pandas/core/groupby/groupby.py
@@ -135,6 +135,8 @@ class providing the base-class of operations.
 )
 
 if TYPE_CHECKING:
+    from typing import Any
+
     from pandas.core.window import (
         ExpandingGroupby,
         ExponentialMovingWindowGroupby,
@@ -1480,8 +1482,9 @@ def _numba_prep(self, data: DataFrame):
     def _numba_agg_general(
         self,
         func: Callable,
+        dtype_mapping: dict[np.dtype, Any],
         engine_kwargs: dict[str, bool] | None,
-        *aggregator_args,
+        **aggregator_kwargs,
     ):
         """
         Perform groupby with a standard numerical aggregation function (e.g. mean)
@@ -1496,19 +1499,26 @@ def _numba_agg_general(
         data = self._obj_with_exclusions
         df = data if data.ndim == 2 else data.to_frame()
 
-        starts, ends, sorted_index, sorted_data = self._numba_prep(df)
+
+        sorted_df = df.take(self.grouper._sort_idx, axis=self.axis)
+        sorted_ids = self.grouper._sorted_ids
+        _, _, ngroups = self.grouper.group_info
+        starts, ends = lib.generate_slices(sorted_ids, ngroups)
         aggregator = executor.generate_shared_aggregator(
-            func, **get_jit_arguments(engine_kwargs)
+            func, dtype_mapping, **get_jit_arguments(engine_kwargs)
+        )
+        result = sorted_df._mgr.apply(
+            aggregator, start=starts, end=ends, **aggregator_kwargs
         )
-        result = aggregator(sorted_data, starts, ends, 0, *aggregator_args)
+        result.axes[1] = self.grouper.result_index
+        result = df._constructor(result)
 
-        index = self.grouper.result_index
         if data.ndim == 1:
-            result_kwargs = {"name": data.name}
-            result = result.ravel()
+            result = result.squeeze("columns")
+            result.name = data.name
         else:
-            result_kwargs = {"columns": data.columns}
-        return data._constructor(result, index=index, **result_kwargs)
+            result.columns = data.columns
+        return result
 
     @final
     def _transform_with_numba(self, func, *args, engine_kwargs=None, **kwargs):
@@ -2189,7 +2199,9 @@ def mean(
         if maybe_use_numba(engine):
             from pandas.core._numba.kernels import sliding_mean
 
-            return self._numba_agg_general(sliding_mean, engine_kwargs)
+            return self._numba_agg_general(
+                sliding_mean, executor.float_dtype_mapping, engine_kwargs, min_periods=0
+            )
         else:
             result = self._cython_agg_general(
                 "mean",
@@ -2356,7 +2368,15 @@ def std(
         if maybe_use_numba(engine):
             from pandas.core._numba.kernels import sliding_var
 
-            return np.sqrt(self._numba_agg_general(sliding_var, engine_kwargs, ddof))
+            return np.sqrt(
+                self._numba_agg_general(
+                    sliding_var,
+                    executor.float_dtype_mapping,
+                    engine_kwargs,
+                    min_periods=0,
+                    ddof=ddof,
+                )
+            )
         else:
             return self._cython_agg_general(
                 "std",
@@ -2457,7 +2477,13 @@
         if maybe_use_numba(engine):
             from pandas.core._numba.kernels import sliding_var
 
-            return self._numba_agg_general(sliding_var, engine_kwargs, ddof)
+            return self._numba_agg_general(
+                sliding_var,
+                executor.float_dtype_mapping,
+                engine_kwargs,
+                min_periods=0,
+                ddof=ddof,
+            )
         else:
             return self._cython_agg_general(
                 "var",
@@ -2786,7 +2812,9 @@ def sum(
             return self._numba_agg_general(
                 sliding_sum,
+                executor.default_dtype_mapping,
                 engine_kwargs,
+                min_periods=min_count,
             )
         else:
             # If we are grouping on categoricals we want unobserved categories to
@@ -2899,7 +2927,13 @@ def min(
         if maybe_use_numba(engine):
             from pandas.core._numba.kernels import sliding_min_max
 
-            return self._numba_agg_general(sliding_min_max, engine_kwargs, False)
+            return self._numba_agg_general(
+                sliding_min_max,
+                executor.identity_dtype_mapping,
+                engine_kwargs,
+                min_periods=min_count,
+                is_max=False,
+            )
         else:
             return self._agg_general(
                 numeric_only=numeric_only,
@@ -2959,7 +2993,13 @@ def max(
         if maybe_use_numba(engine):
             from pandas.core._numba.kernels import sliding_min_max
 
-            return self._numba_agg_general(sliding_min_max, engine_kwargs, True)
+            return self._numba_agg_general(
+                sliding_min_max,
+                executor.identity_dtype_mapping,
+                engine_kwargs,
+                min_periods=min_count,
+                is_max=True,
+            )
         else:
             return self._agg_general(
                 numeric_only=numeric_only,
diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py
index 7220b44c7af9d..a08ffcc9f7200 100644
--- a/pandas/core/window/rolling.py
+++ b/pandas/core/window/rolling.py
@@ -624,7 +624,7 @@ def _numba_apply(
         self,
         func: Callable[..., Any],
         engine_kwargs: dict[str, bool] | None = None,
-        *func_args,
+        **func_kwargs,
    ):
         window_indexer = self._get_window_indexer()
         min_periods = (
@@ -646,10 +646,15 @@ def _numba_apply(
             step=self.step,
         )
         self._check_window_bounds(start, end, len(values))
+        # For now, map everything to float to match the Cython impl,
+        # even though this needlessly discards the input dtype
+        # TODO: Could preserve correct dtypes in future
+        # xref #53214
+        dtype_mapping = executor.float_dtype_mapping
         aggregator = executor.generate_shared_aggregator(
-            func, **get_jit_arguments(engine_kwargs)
+            func, dtype_mapping, **get_jit_arguments(engine_kwargs)
         )
-        result = aggregator(values, start, end, min_periods, *func_args)
+        result = aggregator(values.T, start, end, min_periods, **func_kwargs).T
         result = result.T if self.axis == 1 else result
         index = self._slice_axis_for_step(obj.index, result)
         if obj.ndim == 1:
@@ -1466,7 +1471,7 @@ def max(
         else:
             from pandas.core._numba.kernels import sliding_min_max
 
-            return self._numba_apply(sliding_min_max, engine_kwargs, True)
+            return self._numba_apply(sliding_min_max, engine_kwargs, is_max=True)
 
         window_func = window_aggregations.roll_max
         return self._apply(window_func, name="max", numeric_only=numeric_only)
@@ -1488,7 +1493,7 @@ def min(
         else:
             from pandas.core._numba.kernels import sliding_min_max
 
-            return self._numba_apply(sliding_min_max, engine_kwargs, False)
+            return self._numba_apply(sliding_min_max, engine_kwargs, is_max=False)
 
         window_func = window_aggregations.roll_min
         return self._apply(window_func, name="min", numeric_only=numeric_only)
@@ -1547,7 +1552,7 @@ def std(
                 raise NotImplementedError("std not supported with method='table'")
             from pandas.core._numba.kernels import sliding_var
 
-            return zsqrt(self._numba_apply(sliding_var, engine_kwargs, ddof))
+            return zsqrt(self._numba_apply(sliding_var, engine_kwargs, ddof=ddof))
         window_func = window_aggregations.roll_var
 
         def zsqrt_func(values, begin, end, min_periods):
@@ -1571,7 +1576,7 @@ 
def var( raise NotImplementedError("var not supported with method='table'") from pandas.core._numba.kernels import sliding_var - return self._numba_apply(sliding_var, engine_kwargs, ddof) + return self._numba_apply(sliding_var, engine_kwargs, ddof=ddof) window_func = partial(window_aggregations.roll_var, ddof=ddof) return self._apply( window_func, diff --git a/pandas/io/formats/info.py b/pandas/io/formats/info.py index 55dacd0c268ff..260620e145105 100644 --- a/pandas/io/formats/info.py +++ b/pandas/io/formats/info.py @@ -14,8 +14,6 @@ Sequence, ) -import numpy as np - from pandas._config import get_option from pandas.io.formats import format as fmt @@ -1099,4 +1097,4 @@ def _get_dataframe_dtype_counts(df: DataFrame) -> Mapping[str, int]: Create mapping between datatypes and their number of occurrences. """ # groupby dtype.name to collect e.g. Categorical columns - return df.dtypes.value_counts().groupby(lambda x: x.name).sum().astype(np.intp) + return df.dtypes.value_counts().groupby(lambda x: x.name).sum() diff --git a/pandas/tests/groupby/aggregate/test_numba.py b/pandas/tests/groupby/aggregate/test_numba.py index 10ed32a334d18..2514e988e4e80 100644 --- a/pandas/tests/groupby/aggregate/test_numba.py +++ b/pandas/tests/groupby/aggregate/test_numba.py @@ -157,8 +157,7 @@ def test_multifunc_numba_vs_cython_frame(agg_kwargs): grouped = data.groupby(0) result = grouped.agg(**agg_kwargs, engine="numba") expected = grouped.agg(**agg_kwargs, engine="cython") - # check_dtype can be removed if GH 44952 is addressed - tm.assert_frame_equal(result, expected, check_dtype=False) + tm.assert_frame_equal(result, expected) @td.skip_if_no("numba") @@ -194,6 +193,7 @@ def test_multifunc_numba_udf_frame(agg_kwargs, expected_func): result = grouped.agg(**agg_kwargs, engine="numba") expected = grouped.agg(expected_func, engine="cython") # check_dtype can be removed if GH 44952 is addressed + # Currently, UDFs still always return float64 while reductions can preserve dtype tm.assert_frame_equal(result, expected, check_dtype=False) diff --git a/pandas/tests/groupby/conftest.py b/pandas/tests/groupby/conftest.py index 7e7b97d9273dc..c5e30513f69de 100644 --- a/pandas/tests/groupby/conftest.py +++ b/pandas/tests/groupby/conftest.py @@ -196,8 +196,23 @@ def nopython(request): ("sum", {}), ("min", {}), ("max", {}), + ("sum", {"min_count": 2}), + ("min", {"min_count": 2}), + ("max", {"min_count": 2}), + ], + ids=[ + "mean", + "var_1", + "var_0", + "std_1", + "std_0", + "sum", + "min", + "max", + "sum-min_count", + "min-min_count", + "max-min_count", ], - ids=["mean", "var_1", "var_0", "std_1", "std_0", "sum", "min", "max"], ) def numba_supported_reductions(request): """reductions supported with engine='numba'""" diff --git a/pandas/tests/groupby/test_numba.py b/pandas/tests/groupby/test_numba.py index 867bdbf583388..7d4440b595dff 100644 --- a/pandas/tests/groupby/test_numba.py +++ b/pandas/tests/groupby/test_numba.py @@ -26,9 +26,7 @@ def test_cython_vs_numba_frame( engine="numba", engine_kwargs=engine_kwargs, **kwargs ) expected = getattr(gb, func)(**kwargs) - # check_dtype can be removed if GH 44952 is addressed - check_dtype = func not in ("sum", "min", "max") - tm.assert_frame_equal(result, expected, check_dtype=check_dtype) + tm.assert_frame_equal(result, expected) def test_cython_vs_numba_getitem( self, sort, nogil, parallel, nopython, numba_supported_reductions @@ -41,9 +39,7 @@ def test_cython_vs_numba_getitem( engine="numba", engine_kwargs=engine_kwargs, **kwargs ) expected = getattr(gb, func)(**kwargs) - # 
check_dtype can be removed if GH 44952 is addressed - check_dtype = func not in ("sum", "min", "max") - tm.assert_series_equal(result, expected, check_dtype=check_dtype) + tm.assert_series_equal(result, expected) def test_cython_vs_numba_series( self, sort, nogil, parallel, nopython, numba_supported_reductions @@ -56,9 +52,7 @@ def test_cython_vs_numba_series( engine="numba", engine_kwargs=engine_kwargs, **kwargs ) expected = getattr(gb, func)(**kwargs) - # check_dtype can be removed if GH 44952 is addressed - check_dtype = func not in ("sum", "min", "max") - tm.assert_series_equal(result, expected, check_dtype=check_dtype) + tm.assert_series_equal(result, expected) def test_as_index_false_unsupported(self, numba_supported_reductions): func, kwargs = numba_supported_reductions
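The new min_count parametrizations exercise the na_pos path end to end: for an int64 column, an unsatisfied min_count forces the float64 upcast implemented in looper_wrapper (a sketch, assuming numba is installed):

import pandas as pd

df = pd.DataFrame({"key": [1, 1, 2], "a": [1, 2, 3]})  # "a" is int64

res = df.groupby("key")["a"].sum(engine="numba", min_count=2)
print(res)
# key
# 1    3.0
# 2    NaN
# Name: a, dtype: float64   (upcast: group 2 has fewer than 2 observations)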