Skip to content

Commit

Permalink
ENH: Add numba engine to df.apply (#55104)
Browse files Browse the repository at this point in the history
* ENH: Add numba engine to df.apply

* complete?

* wip: pass tests

* fix existing tests

* go for green

* fix checks?

* fix pyright

* update docs

* eliminate a blank line

* update from code review + more tests

* fix failing tests

* Simplify w/ context manager

* skip if no numba

* simplify more

* specify dtypes

* address code review

* add errors for invalid columns

* adjust message
  • Loading branch information
lithomas1 authored Oct 22, 2023
1 parent 206f981 commit ac5587c
Show file tree
Hide file tree
Showing 9 changed files with 948 additions and 67 deletions.
575 changes: 575 additions & 0 deletions pandas/core/_numba/extensions.py

Large diffs are not rendered by default.

184 changes: 175 additions & 9 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import abc
from collections import defaultdict
import functools
from functools import partial
import inspect
from typing import (
Expand Down Expand Up @@ -29,14 +30,17 @@
NDFrameT,
npt,
)
from pandas.compat._optional import import_optional_dependency
from pandas.errors import SpecificationError
from pandas.util._decorators import cache_readonly
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.cast import is_nested_object
from pandas.core.dtypes.common import (
is_dict_like,
is_extension_array_dtype,
is_list_like,
is_numeric_dtype,
is_sequence,
)
from pandas.core.dtypes.dtypes import (
Expand Down Expand Up @@ -121,6 +125,8 @@ def __init__(
result_type: str | None,
*,
by_row: Literal[False, "compat", "_compat"] = "compat",
engine: str = "python",
engine_kwargs: dict[str, bool] | None = None,
args,
kwargs,
) -> None:
Expand All @@ -133,6 +139,9 @@ def __init__(
self.args = args or ()
self.kwargs = kwargs or {}

self.engine = engine
self.engine_kwargs = {} if engine_kwargs is None else engine_kwargs

if result_type not in [None, "reduce", "broadcast", "expand"]:
raise ValueError(
"invalid value for result_type, must be one "
Expand Down Expand Up @@ -601,6 +610,13 @@ def apply_list_or_dict_like(self) -> DataFrame | Series:
result: Series, DataFrame, or None
Result when self.func is a list-like or dict-like, None otherwise.
"""

if self.engine == "numba":
raise NotImplementedError(
"The 'numba' engine doesn't support list-like/"
"dict likes of callables yet."
)

if self.axis == 1 and isinstance(self.obj, ABCDataFrame):
return self.obj.T.apply(self.func, 0, args=self.args, **self.kwargs).T

Expand Down Expand Up @@ -768,10 +784,16 @@ def __init__(
) -> None:
if by_row is not False and by_row != "compat":
raise ValueError(f"by_row={by_row} not allowed")
self.engine = engine
self.engine_kwargs = engine_kwargs
super().__init__(
obj, func, raw, result_type, by_row=by_row, args=args, kwargs=kwargs
obj,
func,
raw,
result_type,
by_row=by_row,
engine=engine,
engine_kwargs=engine_kwargs,
args=args,
kwargs=kwargs,
)

# ---------------------------------------------------------------
Expand All @@ -792,6 +814,32 @@ def result_columns(self) -> Index:
def series_generator(self) -> Generator[Series, None, None]:
pass

@staticmethod
@functools.cache
@abc.abstractmethod
def generate_numba_apply_func(
func, nogil=True, nopython=True, parallel=False
) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]:
pass

@abc.abstractmethod
def apply_with_numba(self):
pass

def validate_values_for_numba(self):
# Validate column dtyps all OK
for colname, dtype in self.obj.dtypes.items():
if not is_numeric_dtype(dtype):
raise ValueError(
f"Column {colname} must have a numeric dtype. "
f"Found '{dtype}' instead"
)
if is_extension_array_dtype(dtype):
raise ValueError(
f"Column {colname} is backed by an extension array, "
f"which is not supported by the numba engine."
)

@abc.abstractmethod
def wrap_results_for_axis(
self, results: ResType, res_index: Index
Expand All @@ -815,13 +863,12 @@ def values(self):
def apply(self) -> DataFrame | Series:
"""compute the results"""

if self.engine == "numba" and not self.raw:
raise ValueError(
"The numba engine in DataFrame.apply can only be used when raw=True"
)

# dispatch to handle list-like or dict-like
if is_list_like(self.func):
if self.engine == "numba":
raise NotImplementedError(
"the 'numba' engine doesn't support lists of callables yet"
)
return self.apply_list_or_dict_like()

# all empty
Expand All @@ -830,17 +877,31 @@ def apply(self) -> DataFrame | Series:

# string dispatch
if isinstance(self.func, str):
if self.engine == "numba":
raise NotImplementedError(
"the 'numba' engine doesn't support using "
"a string as the callable function"
)
return self.apply_str()

# ufunc
elif isinstance(self.func, np.ufunc):
if self.engine == "numba":
raise NotImplementedError(
"the 'numba' engine doesn't support "
"using a numpy ufunc as the callable function"
)
with np.errstate(all="ignore"):
results = self.obj._mgr.apply("apply", func=self.func)
# _constructor will retain self.index and self.columns
return self.obj._constructor_from_mgr(results, axes=results.axes)

# broadcasting
if self.result_type == "broadcast":
if self.engine == "numba":
raise NotImplementedError(
"the 'numba' engine doesn't support result_type='broadcast'"
)
return self.apply_broadcast(self.obj)

# one axis empty
Expand Down Expand Up @@ -997,7 +1058,10 @@ def apply_broadcast(self, target: DataFrame) -> DataFrame:
return result

def apply_standard(self):
results, res_index = self.apply_series_generator()
if self.engine == "python":
results, res_index = self.apply_series_generator()
else:
results, res_index = self.apply_series_numba()

# wrap results
return self.wrap_results(results, res_index)
Expand All @@ -1021,6 +1085,19 @@ def apply_series_generator(self) -> tuple[ResType, Index]:

return results, res_index

def apply_series_numba(self):
if self.engine_kwargs.get("parallel", False):
raise NotImplementedError(
"Parallel apply is not supported when raw=False and engine='numba'"
)
if not self.obj.index.is_unique or not self.columns.is_unique:
raise NotImplementedError(
"The index/columns must be unique when raw=False and engine='numba'"
)
self.validate_values_for_numba()
results = self.apply_with_numba()
return results, self.result_index

def wrap_results(self, results: ResType, res_index: Index) -> DataFrame | Series:
from pandas import Series

Expand Down Expand Up @@ -1060,6 +1137,49 @@ class FrameRowApply(FrameApply):
def series_generator(self) -> Generator[Series, None, None]:
return (self.obj._ixs(i, axis=1) for i in range(len(self.columns)))

@staticmethod
@functools.cache
def generate_numba_apply_func(
func, nogil=True, nopython=True, parallel=False
) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]:
numba = import_optional_dependency("numba")
from pandas import Series

# Import helper from extensions to cast string object -> np strings
# Note: This also has the side effect of loading our numba extensions
from pandas.core._numba.extensions import maybe_cast_str

jitted_udf = numba.extending.register_jitable(func)

# Currently the parallel argument doesn't get passed through here
# (it's disabled) since the dicts in numba aren't thread-safe.
@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
def numba_func(values, col_names, df_index):
results = {}
for j in range(values.shape[1]):
# Create the series
ser = Series(
values[:, j], index=df_index, name=maybe_cast_str(col_names[j])
)
results[j] = jitted_udf(ser)
return results

return numba_func

def apply_with_numba(self) -> dict[int, Any]:
nb_func = self.generate_numba_apply_func(
cast(Callable, self.func), **self.engine_kwargs
)
from pandas.core._numba.extensions import set_numba_data

# Convert from numba dict to regular dict
# Our isinstance checks in the df constructor don't pass for numbas typed dict
with set_numba_data(self.obj.index) as index, set_numba_data(
self.columns
) as columns:
res = dict(nb_func(self.values, columns, index))
return res

@property
def result_index(self) -> Index:
return self.columns
Expand Down Expand Up @@ -1143,6 +1263,52 @@ def series_generator(self) -> Generator[Series, None, None]:
object.__setattr__(ser, "_name", name)
yield ser

@staticmethod
@functools.cache
def generate_numba_apply_func(
func, nogil=True, nopython=True, parallel=False
) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]:
numba = import_optional_dependency("numba")
from pandas import Series
from pandas.core._numba.extensions import maybe_cast_str

jitted_udf = numba.extending.register_jitable(func)

@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
def numba_func(values, col_names_index, index):
results = {}
# Currently the parallel argument doesn't get passed through here
# (it's disabled) since the dicts in numba aren't thread-safe.
for i in range(values.shape[0]):
# Create the series
# TODO: values corrupted without the copy
ser = Series(
values[i].copy(),
index=col_names_index,
name=maybe_cast_str(index[i]),
)
results[i] = jitted_udf(ser)

return results

return numba_func

def apply_with_numba(self) -> dict[int, Any]:
nb_func = self.generate_numba_apply_func(
cast(Callable, self.func), **self.engine_kwargs
)

from pandas.core._numba.extensions import set_numba_data

# Convert from numba dict to regular dict
# Our isinstance checks in the df constructor don't pass for numbas typed dict
with set_numba_data(self.obj.index) as index, set_numba_data(
self.columns
) as columns:
res = dict(nb_func(self.values, columns, index))

return res

@property
def result_index(self) -> Index:
return self.index
Expand Down
5 changes: 3 additions & 2 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -10090,6 +10090,9 @@ def apply(
- nogil (release the GIL inside the JIT compiled function)
- parallel (try to apply the function in parallel over the DataFrame)
Note: Due to limitations within numba/how pandas interfaces with numba,
you should only use this if raw=True
Note: The numba compiler only supports a subset of
valid Python/numpy operations.
Expand All @@ -10099,8 +10102,6 @@ def apply(
<https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html>`_
in numba to learn what you can or cannot use in the passed function.
As of right now, the numba engine can only be used with raw=True.
.. versionadded:: 2.2.0
engine_kwargs : dict
Expand Down
12 changes: 12 additions & 0 deletions pandas/tests/apply/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,15 @@ def int_frame_const_col():
columns=["A", "B", "C"],
)
return df


@pytest.fixture(params=["python", "numba"])
def engine(request):
if request.param == "numba":
pytest.importorskip("numba")
return request.param


@pytest.fixture(params=[0, 1])
def apply_axis(request):
return request.param
Loading

0 comments on commit ac5587c

Please sign in to comment.