diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index d7d29665950a6..c61b8f3fb3701 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -54,6 +54,7 @@ Other enhancements - :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`) - :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`) - :meth:`DataFrame.plot.scatter` argument ``c`` now accepts a column of strings, where rows with the same string are colored identically (:issue:`16827` and :issue:`16485`) +- :meth:`DataFrameGroupBy.transform`, :meth:`SeriesGroupBy.transform`, :meth:`DataFrameGroupBy.agg`, :meth:`SeriesGroupBy.agg`, :meth:`RollingGroupby.apply`, :meth:`ExpandingGroupby.apply`, :meth:`Rolling.apply`, :meth:`Expanding.apply`, :meth:`DataFrame.apply` with ``engine="numba"`` now supports positional arguments passed as kwargs (:issue:`58995`) - :meth:`Series.map` can now accept kwargs to pass on to func (:issue:`59814`) - :meth:`pandas.concat` will raise a ``ValueError`` when ``ignore_index=True`` and ``keys`` is not ``None`` (:issue:`59274`) - :meth:`str.get_dummies` now accepts a ``dtype`` parameter to specify the dtype of the resulting DataFrame (:issue:`47872`) diff --git a/pandas/core/apply.py b/pandas/core/apply.py index af2d6243ce4ed..af513d49bcfe0 100644 --- a/pandas/core/apply.py +++ b/pandas/core/apply.py @@ -994,6 +994,7 @@ def wrapper(*args, **kwargs): self.func, # type: ignore[arg-type] self.args, self.kwargs, + num_required_args=1, ) # error: Argument 1 to "__call__" of "_lru_cache_wrapper" has # incompatible type "Callable[..., Any] | str | list[Callable @@ -1001,7 +1002,7 @@ def wrapper(*args, **kwargs): # list[Callable[..., Any] | str]]"; expected "Hashable" nb_looper = generate_apply_looper( self.func, # type: ignore[arg-type] - **get_jit_arguments(engine_kwargs, kwargs), + **get_jit_arguments(engine_kwargs), ) result = nb_looper(self.values, self.axis, *args) # If we made the result 2-D, squeeze it back to 1-D @@ -1158,9 +1159,11 @@ def numba_func(values, col_names, df_index, *args): def apply_with_numba(self) -> dict[int, Any]: func = cast(Callable, self.func) - args, kwargs = prepare_function_arguments(func, self.args, self.kwargs) + args, kwargs = prepare_function_arguments( + func, self.args, self.kwargs, num_required_args=1 + ) nb_func = self.generate_numba_apply_func( - func, **get_jit_arguments(self.engine_kwargs, kwargs) + func, **get_jit_arguments(self.engine_kwargs) ) from pandas.core._numba.extensions import set_numba_data @@ -1298,9 +1301,11 @@ def numba_func(values, col_names_index, index, *args): def apply_with_numba(self) -> dict[int, Any]: func = cast(Callable, self.func) - args, kwargs = prepare_function_arguments(func, self.args, self.kwargs) + args, kwargs = prepare_function_arguments( + func, self.args, self.kwargs, num_required_args=1 + ) nb_func = self.generate_numba_apply_func( - func, **get_jit_arguments(self.engine_kwargs, kwargs) + func, **get_jit_arguments(self.engine_kwargs) ) from pandas.core._numba.extensions import set_numba_data diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index a0bd25525c55f..66db033596872 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -136,6 +136,7 @@ class providing the base-class of operations. from pandas.core.util.numba_ import ( get_jit_arguments, maybe_use_numba, + prepare_function_arguments, ) if TYPE_CHECKING: @@ -1289,8 +1290,11 @@ def _transform_with_numba(self, func, *args, engine_kwargs=None, **kwargs): starts, ends, sorted_index, sorted_data = self._numba_prep(df) numba_.validate_udf(func) + args, kwargs = prepare_function_arguments( + func, args, kwargs, num_required_args=2 + ) numba_transform_func = numba_.generate_numba_transform_func( - func, **get_jit_arguments(engine_kwargs, kwargs) + func, **get_jit_arguments(engine_kwargs) ) result = numba_transform_func( sorted_data, @@ -1325,8 +1329,11 @@ def _aggregate_with_numba(self, func, *args, engine_kwargs=None, **kwargs): starts, ends, sorted_index, sorted_data = self._numba_prep(df) numba_.validate_udf(func) + args, kwargs = prepare_function_arguments( + func, args, kwargs, num_required_args=2 + ) numba_agg_func = numba_.generate_numba_agg_func( - func, **get_jit_arguments(engine_kwargs, kwargs) + func, **get_jit_arguments(engine_kwargs) ) result = numba_agg_func( sorted_data, diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index de024f612516b..d3f00c08e0e2c 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -29,9 +29,7 @@ def set_use_numba(enable: bool = False) -> None: GLOBAL_USE_NUMBA = enable -def get_jit_arguments( - engine_kwargs: dict[str, bool] | None = None, kwargs: dict | None = None -) -> dict[str, bool]: +def get_jit_arguments(engine_kwargs: dict[str, bool] | None = None) -> dict[str, bool]: """ Return arguments to pass to numba.JIT, falling back on pandas default JIT settings. @@ -39,8 +37,6 @@ def get_jit_arguments( ---------- engine_kwargs : dict, default None user passed keyword arguments for numba.JIT - kwargs : dict, default None - user passed keyword arguments to pass into the JITed function Returns ------- @@ -55,16 +51,6 @@ def get_jit_arguments( engine_kwargs = {} nopython = engine_kwargs.get("nopython", True) - if kwargs: - # Note: in case numba supports keyword-only arguments in - # a future version, we should remove this check. But this - # seems unlikely to happen soon. - - raise NumbaUtilError( - "numba does not support keyword-only arguments" - "https://github.com/numba/numba/issues/2916, " - "https://github.com/numba/numba/issues/6846" - ) nogil = engine_kwargs.get("nogil", False) parallel = engine_kwargs.get("parallel", False) return {"nopython": nopython, "nogil": nogil, "parallel": parallel} @@ -109,7 +95,7 @@ def jit_user_function(func: Callable) -> Callable: def prepare_function_arguments( - func: Callable, args: tuple, kwargs: dict + func: Callable, args: tuple, kwargs: dict, *, num_required_args: int ) -> tuple[tuple, dict]: """ Prepare arguments for jitted function. As numba functions do not support kwargs, @@ -118,11 +104,17 @@ def prepare_function_arguments( Parameters ---------- func : function - user defined function + User defined function args : tuple - user input positional arguments + User input positional arguments kwargs : dict - user input keyword arguments + User input keyword arguments + num_required_args : int + The number of leading positional arguments we will pass to udf. + These are not supplied by the user. + e.g. for groupby we require "values", "index" as the first two arguments: + `numba_func(group, group_index, *args)`, in this case num_required_args=2. + See :func:`pandas.core.groupby.numba_.generate_numba_agg_func` Returns ------- @@ -133,9 +125,9 @@ def prepare_function_arguments( if not kwargs: return args, kwargs - # the udf should have this pattern: def udf(value, *args, **kwargs):... + # the udf should have this pattern: def udf(arg1, arg2, ..., *args, **kwargs):... signature = inspect.signature(func) - arguments = signature.bind(_sentinel, *args, **kwargs) + arguments = signature.bind(*[_sentinel] * num_required_args, *args, **kwargs) arguments.apply_defaults() # Ref: https://peps.python.org/pep-0362/ # Arguments which could be passed as part of either *args or **kwargs @@ -143,7 +135,16 @@ def prepare_function_arguments( args = arguments.args kwargs = arguments.kwargs - assert args[0] is _sentinel - args = args[1:] + if kwargs: + # Note: in case numba supports keyword-only arguments in + # a future version, we should remove this check. But this + # seems unlikely to happen soon. + + raise NumbaUtilError( + "numba does not support keyword-only arguments" + "https://github.com/numba/numba/issues/2916, " + "https://github.com/numba/numba/issues/6846" + ) + args = args[num_required_args:] return args, kwargs diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index cf74cc30f3c5d..b1c37ab48fa57 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -65,6 +65,7 @@ from pandas.core.util.numba_ import ( get_jit_arguments, maybe_use_numba, + prepare_function_arguments, ) from pandas.core.window.common import ( flex_binary_moment, @@ -1472,14 +1473,16 @@ def apply( if maybe_use_numba(engine): if raw is False: raise ValueError("raw must be `True` when using the numba engine") - numba_args = args + numba_args, kwargs = prepare_function_arguments( + func, args, kwargs, num_required_args=1 + ) if self.method == "single": apply_func = generate_numba_apply_func( - func, **get_jit_arguments(engine_kwargs, kwargs) + func, **get_jit_arguments(engine_kwargs) ) else: apply_func = generate_numba_table_func( - func, **get_jit_arguments(engine_kwargs, kwargs) + func, **get_jit_arguments(engine_kwargs) ) elif engine in ("cython", None): if engine_kwargs is not None: diff --git a/pandas/tests/apply/test_frame_apply.py b/pandas/tests/apply/test_frame_apply.py index ed7eae4502a64..d36d723c4be6a 100644 --- a/pandas/tests/apply/test_frame_apply.py +++ b/pandas/tests/apply/test_frame_apply.py @@ -90,6 +90,16 @@ def test_apply_args(float_frame, axis, raw, engine, nopython): tm.assert_frame_equal(result, expected) if engine == "numba": + # py signature binding + with pytest.raises(TypeError, match="missing a required argument: 'a'"): + float_frame.apply( + lambda x, a: x + a, + b=2, + raw=raw, + engine=engine, + engine_kwargs=engine_kwargs, + ) + # keyword-only arguments are not supported in numba with pytest.raises( pd.errors.NumbaUtilError, diff --git a/pandas/tests/groupby/aggregate/test_numba.py b/pandas/tests/groupby/aggregate/test_numba.py index 964a80f8f3310..15c1efe5fd1ff 100644 --- a/pandas/tests/groupby/aggregate/test_numba.py +++ b/pandas/tests/groupby/aggregate/test_numba.py @@ -35,18 +35,43 @@ def incorrect_function(x): def test_check_nopython_kwargs(): pytest.importorskip("numba") - def incorrect_function(values, index): - return sum(values) * 2.7 + def incorrect_function(values, index, *, a): + return sum(values) * 2.7 + a + + def correct_function(values, index, a): + return sum(values) * 2.7 + a data = DataFrame( {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=["key", "data"], ) + expected = data.groupby("key").sum() * 2.7 + + # py signature binding + with pytest.raises( + TypeError, match="missing a required (keyword-only argument|argument): 'a'" + ): + data.groupby("key").agg(incorrect_function, engine="numba", b=1) + with pytest.raises(TypeError, match="missing a required argument: 'a'"): + data.groupby("key").agg(correct_function, engine="numba", b=1) + + with pytest.raises( + TypeError, match="missing a required (keyword-only argument|argument): 'a'" + ): + data.groupby("key")["data"].agg(incorrect_function, engine="numba", b=1) + with pytest.raises(TypeError, match="missing a required argument: 'a'"): + data.groupby("key")["data"].agg(correct_function, engine="numba", b=1) + + # numba signature check after binding with pytest.raises(NumbaUtilError, match="numba does not support"): data.groupby("key").agg(incorrect_function, engine="numba", a=1) + actual = data.groupby("key").agg(correct_function, engine="numba", a=1) + tm.assert_frame_equal(expected + 1, actual) with pytest.raises(NumbaUtilError, match="numba does not support"): data.groupby("key")["data"].agg(incorrect_function, engine="numba", a=1) + actual = data.groupby("key")["data"].agg(correct_function, engine="numba", a=1) + tm.assert_series_equal(expected["data"] + 1, actual) @pytest.mark.filterwarnings("ignore") diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index a17d25b2e7e2e..969df8ef4c52b 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -33,18 +33,43 @@ def incorrect_function(x): def test_check_nopython_kwargs(): pytest.importorskip("numba") - def incorrect_function(values, index): - return values + 1 + def incorrect_function(values, index, *, a): + return values + a + + def correct_function(values, index, a): + return values + a data = DataFrame( {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=["key", "data"], ) + # py signature binding + with pytest.raises( + TypeError, match="missing a required (keyword-only argument|argument): 'a'" + ): + data.groupby("key").transform(incorrect_function, engine="numba", b=1) + with pytest.raises(TypeError, match="missing a required argument: 'a'"): + data.groupby("key").transform(correct_function, engine="numba", b=1) + + with pytest.raises( + TypeError, match="missing a required (keyword-only argument|argument): 'a'" + ): + data.groupby("key")["data"].transform(incorrect_function, engine="numba", b=1) + with pytest.raises(TypeError, match="missing a required argument: 'a'"): + data.groupby("key")["data"].transform(correct_function, engine="numba", b=1) + + # numba signature check after binding with pytest.raises(NumbaUtilError, match="numba does not support"): data.groupby("key").transform(incorrect_function, engine="numba", a=1) + actual = data.groupby("key").transform(correct_function, engine="numba", a=1) + tm.assert_frame_equal(data[["data"]] + 1, actual) with pytest.raises(NumbaUtilError, match="numba does not support"): data.groupby("key")["data"].transform(incorrect_function, engine="numba", a=1) + actual = data.groupby("key")["data"].transform( + correct_function, engine="numba", a=1 + ) + tm.assert_series_equal(data["data"] + 1, actual) @pytest.mark.filterwarnings("ignore") diff --git a/pandas/tests/window/test_numba.py b/pandas/tests/window/test_numba.py index 23b17c651f08d..d9ab4723a8f2c 100644 --- a/pandas/tests/window/test_numba.py +++ b/pandas/tests/window/test_numba.py @@ -38,6 +38,11 @@ def arithmetic_numba_supported_operators(request): return request.param +@pytest.fixture +def roll_frame(): + return DataFrame({"A": [1] * 20 + [2] * 12 + [3] * 8, "B": np.arange(40)}) + + @td.skip_if_no("numba") @pytest.mark.filterwarnings("ignore") # Filter warnings when parallel=True and the function can't be parallelized by Numba @@ -67,6 +72,62 @@ def f(x, *args): ) tm.assert_series_equal(result, expected) + def test_apply_numba_with_kwargs(self, roll_frame): + # GH 58995 + # rolling apply + def func(sr, a=0): + return sr.sum() + a + + data = DataFrame(range(10)) + + result = data.rolling(5).apply(func, engine="numba", raw=True, kwargs={"a": 1}) + expected = data.rolling(5).sum() + 1 + tm.assert_frame_equal(result, expected) + + result = data.rolling(5).apply(func, engine="numba", raw=True, args=(1,)) + tm.assert_frame_equal(result, expected) + + # expanding apply + + result = data.expanding().apply(func, engine="numba", raw=True, kwargs={"a": 1}) + expected = data.expanding().sum() + 1 + tm.assert_frame_equal(result, expected) + + result = data.expanding().apply(func, engine="numba", raw=True, args=(1,)) + tm.assert_frame_equal(result, expected) + + # groupby rolling + result = ( + roll_frame.groupby("A") + .rolling(5) + .apply(func, engine="numba", raw=True, kwargs={"a": 1}) + ) + expected = roll_frame.groupby("A").rolling(5).sum() + 1 + tm.assert_frame_equal(result, expected) + + result = ( + roll_frame.groupby("A") + .rolling(5) + .apply(func, engine="numba", raw=True, args=(1,)) + ) + tm.assert_frame_equal(result, expected) + # groupby expanding + + result = ( + roll_frame.groupby("A") + .expanding() + .apply(func, engine="numba", raw=True, kwargs={"a": 1}) + ) + expected = roll_frame.groupby("A").expanding().sum() + 1 + tm.assert_frame_equal(result, expected) + + result = ( + roll_frame.groupby("A") + .expanding() + .apply(func, engine="numba", raw=True, args=(1,)) + ) + tm.assert_frame_equal(result, expected) + def test_numba_min_periods(self): # GH 58868 def last_row(x): @@ -319,13 +380,24 @@ def f(x): @td.skip_if_no("numba") def test_invalid_kwargs_nopython(): + with pytest.raises(TypeError, match="got an unexpected keyword argument 'a'"): + Series(range(1)).rolling(1).apply( + lambda x: x, kwargs={"a": 1}, engine="numba", raw=True + ) with pytest.raises( NumbaUtilError, match="numba does not support keyword-only arguments" ): Series(range(1)).rolling(1).apply( - lambda x: x, kwargs={"a": 1}, engine="numba", raw=True + lambda x, *, a: x, kwargs={"a": 1}, engine="numba", raw=True ) + tm.assert_series_equal( + Series(range(1), dtype=float) + 1, + Series(range(1)) + .rolling(1) + .apply(lambda x, a: (x + a).sum(), kwargs={"a": 1}, engine="numba", raw=True), + ) + @td.skip_if_no("numba") @pytest.mark.slow