Skip to content

Commit

Permalink
ENH: numba apply supports positional arguments passed as **kwargs (#58995)
Browse files Browse the repository at this point in the history

* add *args for raw numba apply

* add whatsnew

* fix test_case

* fix pre-commit

* fix test case

* add *args for raw=False as well; merge tests together

* add prepare_function_arguments

* fix mypy

* update get_jit_arguments

* add nopython test in `test_apply_args`

* fix test

* fix pre-commit

* modify prepare_function_arguments

* add tests

* add tests

* add whatsnew

* compat for python 3.12

* pre-commit

* compat for python 3.12

* update doc; use kw-only

* add more tests

* update whatsnew

* pre-commit

* move the tests to test_numba.py

* Update doc/source/whatsnew/v3.0.0.rst

Co-authored-by: Matthew Roeschke <[email protected]>

* Update doc/source/whatsnew/v3.0.0.rst

Co-authored-by: Matthew Roeschke <[email protected]>

---------

Co-authored-by: Matthew Roeschke <[email protected]>
  • Loading branch information
auderson and mroeschke authored Oct 31, 2024
1 parent 13926e5 commit 8be2f8b
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 38 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Other enhancements
- :meth:`Series.cummin` and :meth:`Series.cummax` now support :class:`CategoricalDtype` (:issue:`52335`)
- :meth:`Series.plot` now correctly handles the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
- :meth:`DataFrame.plot.scatter` argument ``c`` now accepts a column of strings, where rows with the same string are colored identically (:issue:`16827` and :issue:`16485`)
- :meth:`DataFrameGroupBy.transform`, :meth:`SeriesGroupBy.transform`, :meth:`DataFrameGroupBy.agg`, :meth:`SeriesGroupBy.agg`, :meth:`RollingGroupby.apply`, :meth:`ExpandingGroupby.apply`, :meth:`Rolling.apply`, :meth:`Expanding.apply`, :meth:`DataFrame.apply` with ``engine="numba"`` now support positional arguments passed as kwargs (:issue:`58995`)
- :meth:`Series.map` can now accept kwargs to pass on to func (:issue:`59814`)
- :meth:`pandas.concat` will raise a ``ValueError`` when ``ignore_index=True`` and ``keys`` is not ``None`` (:issue:`59274`)
- :meth:`str.get_dummies` now accepts a ``dtype`` parameter to specify the dtype of the resulting DataFrame (:issue:`47872`)
Expand Down
15 changes: 10 additions & 5 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,14 +994,15 @@ def wrapper(*args, **kwargs):
self.func, # type: ignore[arg-type]
self.args,
self.kwargs,
num_required_args=1,
)
# error: Argument 1 to "__call__" of "_lru_cache_wrapper" has
# incompatible type "Callable[..., Any] | str | list[Callable
# [..., Any] | str] | dict[Hashable,Callable[..., Any] | str |
# list[Callable[..., Any] | str]]"; expected "Hashable"
nb_looper = generate_apply_looper(
self.func, # type: ignore[arg-type]
**get_jit_arguments(engine_kwargs, kwargs),
**get_jit_arguments(engine_kwargs),
)
result = nb_looper(self.values, self.axis, *args)
# If we made the result 2-D, squeeze it back to 1-D
Expand Down Expand Up @@ -1158,9 +1159,11 @@ def numba_func(values, col_names, df_index, *args):

def apply_with_numba(self) -> dict[int, Any]:
func = cast(Callable, self.func)
args, kwargs = prepare_function_arguments(func, self.args, self.kwargs)
args, kwargs = prepare_function_arguments(
func, self.args, self.kwargs, num_required_args=1
)
nb_func = self.generate_numba_apply_func(
func, **get_jit_arguments(self.engine_kwargs, kwargs)
func, **get_jit_arguments(self.engine_kwargs)
)
from pandas.core._numba.extensions import set_numba_data

Expand Down Expand Up @@ -1298,9 +1301,11 @@ def numba_func(values, col_names_index, index, *args):

def apply_with_numba(self) -> dict[int, Any]:
func = cast(Callable, self.func)
args, kwargs = prepare_function_arguments(func, self.args, self.kwargs)
args, kwargs = prepare_function_arguments(
func, self.args, self.kwargs, num_required_args=1
)
nb_func = self.generate_numba_apply_func(
func, **get_jit_arguments(self.engine_kwargs, kwargs)
func, **get_jit_arguments(self.engine_kwargs)
)

from pandas.core._numba.extensions import set_numba_data
Expand Down
11 changes: 9 additions & 2 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ class providing the base-class of operations.
from pandas.core.util.numba_ import (
get_jit_arguments,
maybe_use_numba,
prepare_function_arguments,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -1289,8 +1290,11 @@ def _transform_with_numba(self, func, *args, engine_kwargs=None, **kwargs):

starts, ends, sorted_index, sorted_data = self._numba_prep(df)
numba_.validate_udf(func)
args, kwargs = prepare_function_arguments(
func, args, kwargs, num_required_args=2
)
numba_transform_func = numba_.generate_numba_transform_func(
func, **get_jit_arguments(engine_kwargs, kwargs)
func, **get_jit_arguments(engine_kwargs)
)
result = numba_transform_func(
sorted_data,
Expand Down Expand Up @@ -1325,8 +1329,11 @@ def _aggregate_with_numba(self, func, *args, engine_kwargs=None, **kwargs):

starts, ends, sorted_index, sorted_data = self._numba_prep(df)
numba_.validate_udf(func)
args, kwargs = prepare_function_arguments(
func, args, kwargs, num_required_args=2
)
numba_agg_func = numba_.generate_numba_agg_func(
func, **get_jit_arguments(engine_kwargs, kwargs)
func, **get_jit_arguments(engine_kwargs)
)
result = numba_agg_func(
sorted_data,
Expand Down
47 changes: 24 additions & 23 deletions pandas/core/util/numba_.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,14 @@ def set_use_numba(enable: bool = False) -> None:
GLOBAL_USE_NUMBA = enable


def get_jit_arguments(
engine_kwargs: dict[str, bool] | None = None, kwargs: dict | None = None
) -> dict[str, bool]:
def get_jit_arguments(engine_kwargs: dict[str, bool] | None = None) -> dict[str, bool]:
"""
Return arguments to pass to numba.JIT, falling back on pandas default JIT settings.
Parameters
----------
engine_kwargs : dict, default None
user passed keyword arguments for numba.JIT
kwargs : dict, default None
user passed keyword arguments to pass into the JITed function
Returns
-------
Expand All @@ -55,16 +51,6 @@ def get_jit_arguments(
engine_kwargs = {}

nopython = engine_kwargs.get("nopython", True)
if kwargs:
# Note: in case numba supports keyword-only arguments in
# a future version, we should remove this check. But this
# seems unlikely to happen soon.

raise NumbaUtilError(
"numba does not support keyword-only arguments"
"https://github.com/numba/numba/issues/2916, "
"https://github.com/numba/numba/issues/6846"
)
nogil = engine_kwargs.get("nogil", False)
parallel = engine_kwargs.get("parallel", False)
return {"nopython": nopython, "nogil": nogil, "parallel": parallel}
Expand Down Expand Up @@ -109,7 +95,7 @@ def jit_user_function(func: Callable) -> Callable:


def prepare_function_arguments(
func: Callable, args: tuple, kwargs: dict
func: Callable, args: tuple, kwargs: dict, *, num_required_args: int
) -> tuple[tuple, dict]:
"""
Prepare arguments for jitted function. As numba functions do not support kwargs,
Expand All @@ -118,11 +104,17 @@ def prepare_function_arguments(
Parameters
----------
func : function
user defined function
User defined function
args : tuple
user input positional arguments
User input positional arguments
kwargs : dict
user input keyword arguments
User input keyword arguments
num_required_args : int
The number of leading positional arguments we will pass to udf.
These are not supplied by the user.
e.g. for groupby we require "values", "index" as the first two arguments:
`numba_func(group, group_index, *args)`, in this case num_required_args=2.
See :func:`pandas.core.groupby.numba_.generate_numba_agg_func`
Returns
-------
Expand All @@ -133,17 +125,26 @@ def prepare_function_arguments(
if not kwargs:
return args, kwargs

# the udf should have this pattern: def udf(value, *args, **kwargs):...
# the udf should have this pattern: def udf(arg1, arg2, ..., *args, **kwargs):...
signature = inspect.signature(func)
arguments = signature.bind(_sentinel, *args, **kwargs)
arguments = signature.bind(*[_sentinel] * num_required_args, *args, **kwargs)
arguments.apply_defaults()
# Ref: https://peps.python.org/pep-0362/
# Arguments which could be passed as part of either *args or **kwargs
# will be included only in the BoundArguments.args attribute.
args = arguments.args
kwargs = arguments.kwargs

assert args[0] is _sentinel
args = args[1:]
if kwargs:
# Note: in case numba supports keyword-only arguments in
# a future version, we should remove this check. But this
# seems unlikely to happen soon.

raise NumbaUtilError(
"numba does not support keyword-only arguments"
"https://github.com/numba/numba/issues/2916, "
"https://github.com/numba/numba/issues/6846"
)

args = args[num_required_args:]
return args, kwargs
9 changes: 6 additions & 3 deletions pandas/core/window/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from pandas.core.util.numba_ import (
get_jit_arguments,
maybe_use_numba,
prepare_function_arguments,
)
from pandas.core.window.common import (
flex_binary_moment,
Expand Down Expand Up @@ -1472,14 +1473,16 @@ def apply(
if maybe_use_numba(engine):
if raw is False:
raise ValueError("raw must be `True` when using the numba engine")
numba_args = args
numba_args, kwargs = prepare_function_arguments(
func, args, kwargs, num_required_args=1
)
if self.method == "single":
apply_func = generate_numba_apply_func(
func, **get_jit_arguments(engine_kwargs, kwargs)
func, **get_jit_arguments(engine_kwargs)
)
else:
apply_func = generate_numba_table_func(
func, **get_jit_arguments(engine_kwargs, kwargs)
func, **get_jit_arguments(engine_kwargs)
)
elif engine in ("cython", None):
if engine_kwargs is not None:
Expand Down
10 changes: 10 additions & 0 deletions pandas/tests/apply/test_frame_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,16 @@ def test_apply_args(float_frame, axis, raw, engine, nopython):
tm.assert_frame_equal(result, expected)

if engine == "numba":
# py signature binding
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
float_frame.apply(
lambda x, a: x + a,
b=2,
raw=raw,
engine=engine,
engine_kwargs=engine_kwargs,
)

# keyword-only arguments are not supported in numba
with pytest.raises(
pd.errors.NumbaUtilError,
Expand Down
29 changes: 27 additions & 2 deletions pandas/tests/groupby/aggregate/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,43 @@ def incorrect_function(x):
def test_check_nopython_kwargs():
pytest.importorskip("numba")

def incorrect_function(values, index):
return sum(values) * 2.7
def incorrect_function(values, index, *, a):
return sum(values) * 2.7 + a

def correct_function(values, index, a):
return sum(values) * 2.7 + a

data = DataFrame(
{"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]},
columns=["key", "data"],
)
expected = data.groupby("key").sum() * 2.7

# py signature binding
with pytest.raises(
TypeError, match="missing a required (keyword-only argument|argument): 'a'"
):
data.groupby("key").agg(incorrect_function, engine="numba", b=1)
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
data.groupby("key").agg(correct_function, engine="numba", b=1)

with pytest.raises(
TypeError, match="missing a required (keyword-only argument|argument): 'a'"
):
data.groupby("key")["data"].agg(incorrect_function, engine="numba", b=1)
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
data.groupby("key")["data"].agg(correct_function, engine="numba", b=1)

# numba signature check after binding
with pytest.raises(NumbaUtilError, match="numba does not support"):
data.groupby("key").agg(incorrect_function, engine="numba", a=1)
actual = data.groupby("key").agg(correct_function, engine="numba", a=1)
tm.assert_frame_equal(expected + 1, actual)

with pytest.raises(NumbaUtilError, match="numba does not support"):
data.groupby("key")["data"].agg(incorrect_function, engine="numba", a=1)
actual = data.groupby("key")["data"].agg(correct_function, engine="numba", a=1)
tm.assert_series_equal(expected["data"] + 1, actual)


@pytest.mark.filterwarnings("ignore")
Expand Down
29 changes: 27 additions & 2 deletions pandas/tests/groupby/transform/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,43 @@ def incorrect_function(x):
def test_check_nopython_kwargs():
pytest.importorskip("numba")

def incorrect_function(values, index):
return values + 1
def incorrect_function(values, index, *, a):
return values + a

def correct_function(values, index, a):
return values + a

data = DataFrame(
{"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]},
columns=["key", "data"],
)
# py signature binding
with pytest.raises(
TypeError, match="missing a required (keyword-only argument|argument): 'a'"
):
data.groupby("key").transform(incorrect_function, engine="numba", b=1)
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
data.groupby("key").transform(correct_function, engine="numba", b=1)

with pytest.raises(
TypeError, match="missing a required (keyword-only argument|argument): 'a'"
):
data.groupby("key")["data"].transform(incorrect_function, engine="numba", b=1)
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
data.groupby("key")["data"].transform(correct_function, engine="numba", b=1)

# numba signature check after binding
with pytest.raises(NumbaUtilError, match="numba does not support"):
data.groupby("key").transform(incorrect_function, engine="numba", a=1)
actual = data.groupby("key").transform(correct_function, engine="numba", a=1)
tm.assert_frame_equal(data[["data"]] + 1, actual)

with pytest.raises(NumbaUtilError, match="numba does not support"):
data.groupby("key")["data"].transform(incorrect_function, engine="numba", a=1)
actual = data.groupby("key")["data"].transform(
correct_function, engine="numba", a=1
)
tm.assert_series_equal(data["data"] + 1, actual)


@pytest.mark.filterwarnings("ignore")
Expand Down
Loading

0 comments on commit 8be2f8b

Please sign in to comment.