Skip to content

Commit

Permalink
BUG: numba raises for string columns or index
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl committed Nov 26, 2023
1 parent 762b61d commit dab87c7
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -496,8 +496,8 @@ Conversion
Strings
^^^^^^^
- Bug in :func:`pandas.api.types.is_string_dtype` while checking object array with no elements is of the string dtype (:issue:`54661`)
- Bug in :meth:`DataFrame.apply` failing when ``engine="numba"`` and columns or index have ``StringDtype`` (:issue:`56189`)
- Bug in :meth:`Series.str.startswith` and :meth:`Series.str.endswith` with arguments of type ``tuple[str, ...]`` for ``string[pyarrow]`` (:issue:`54942`)
-

Interval
^^^^^^^^
Expand Down
12 changes: 9 additions & 3 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,11 +1172,17 @@ def apply_with_numba(self) -> dict[int, Any]:
)
from pandas.core._numba.extensions import set_numba_data

index = self.obj.index
if index.dtype == "string":
index = index.astype(object)

columns = self.obj.columns
if columns.dtype == "string":
columns = columns.astype(object)

# Convert from numba dict to regular dict
# Our isinstance checks in the df constructor don't pass for numbas typed dict
with set_numba_data(self.obj.index) as index, set_numba_data(
self.columns
) as columns:
with set_numba_data(index) as index, set_numba_data(columns) as columns:
res = dict(nb_func(self.values, columns, index))
return res

Expand Down
19 changes: 18 additions & 1 deletion pandas/tests/apply/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,22 @@ def test_numba_vs_python_noop(float_frame, apply_axis):
tm.assert_frame_equal(result, expected)


def test_numba_vs_python_string_index():
    """Compare the numba and python engines on a frame with string-dtype labels.

    Builds a small frame whose index and columns both use the
    ``string[pyarrow_numpy]`` dtype, applies an identity function along
    ``axis=0`` with each engine, and asserts that the two results are equal.
    """
    # GH#56189
    pytest.importorskip("pyarrow")

    label_dtype = "string[pyarrow_numpy]"
    row_labels = Index(["a", "b"], dtype=label_dtype)
    column_labels = Index(["x", "y"], dtype=label_dtype)
    df = DataFrame(1, index=row_labels, columns=column_labels)

    def func(x):
        return x

    result = df.apply(func, engine="numba", axis=0)
    expected = df.apply(func, engine="python", axis=0)

    # Only the values and labels are compared; the dtypes of the index and
    # columns are deliberately excluded from the comparison.
    comparison_options = {"check_column_type": False, "check_index_type": False}
    tm.assert_frame_equal(result, expected, **comparison_options)


def test_numba_vs_python_indexing():
frame = DataFrame(
{"a": [1, 2, 3], "b": [4, 5, 6], "c": [7.0, 8.0, 9.0]},
Expand Down Expand Up @@ -88,7 +104,8 @@ def test_numba_unsupported_dtypes(apply_axis):
df["c"] = df["c"].astype("double[pyarrow]")

with pytest.raises(
ValueError, match="Column b must have a numeric dtype. Found 'object' instead"
ValueError,
match="Column b must have a numeric dtype. Found 'object|string' instead",
):
df.apply(f, engine="numba", axis=apply_axis)

Expand Down

0 comments on commit dab87c7

Please sign in to comment.