Skip to content

Commit

Permalink
BUG: rolling with datetime ArrowDtype (pandas-dev#56370)
Browse files Browse the repository at this point in the history
* BUG: rolling with datetime ArrowDtype

* Dont modify needs_i8_conversion

* More explicit tests

* Fix arrow to_numpy
  • Loading branch information
mroeschke authored and cbpygit committed Jan 2, 2024
1 parent 2a474ea commit ab666f4
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 10 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,7 @@ Groupby/resample/rolling
- Bug in :meth:`DataFrame.resample` when resampling on a :class:`ArrowDtype` of ``pyarrow.timestamp`` or ``pyarrow.duration`` type (:issue:`55989`)
- Bug in :meth:`DataFrame.resample` where bin edges were not correct for :class:`~pandas.tseries.offsets.BusinessDay` (:issue:`55281`)
- Bug in :meth:`DataFrame.resample` where bin edges were not correct for :class:`~pandas.tseries.offsets.MonthBegin` (:issue:`55271`)
- Bug in :meth:`DataFrame.rolling` and :meth:`Series.rolling` where either the ``index`` or ``on`` column was :class:`ArrowDtype` with ``pyarrow.timestamp`` type (:issue:`55849`)

Reshaping
^^^^^^^^^
Expand Down
7 changes: 6 additions & 1 deletion pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
pandas_dtype,
)
from pandas.core.dtypes.dtypes import (
ArrowDtype,
CategoricalDtype,
DatetimeTZDtype,
ExtensionDtype,
Expand Down Expand Up @@ -2531,7 +2532,7 @@ def _validate_inferred_freq(
return freq


def dtype_to_unit(dtype: DatetimeTZDtype | np.dtype) -> str:
def dtype_to_unit(dtype: DatetimeTZDtype | np.dtype | ArrowDtype) -> str:
"""
Return the unit str corresponding to the dtype's resolution.
Expand All @@ -2546,4 +2547,8 @@ def dtype_to_unit(dtype: DatetimeTZDtype | np.dtype) -> str:
"""
if isinstance(dtype, DatetimeTZDtype):
return dtype.unit
elif isinstance(dtype, ArrowDtype):
if dtype.kind not in "mM":
raise ValueError(f"{dtype=} does not have a resolution.")
return dtype.pyarrow_dtype.unit
return np.datetime_data(dtype)[0]
23 changes: 14 additions & 9 deletions pandas/core/window/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
Any,
Callable,
Literal,
cast,
)

import numpy as np
Expand All @@ -39,6 +38,7 @@
is_numeric_dtype,
needs_i8_conversion,
)
from pandas.core.dtypes.dtypes import ArrowDtype
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCSeries,
Expand Down Expand Up @@ -104,6 +104,7 @@
NDFrameT,
QuantileInterpolation,
WindowingRankType,
npt,
)

from pandas import (
Expand Down Expand Up @@ -404,11 +405,12 @@ def _insert_on_column(self, result: DataFrame, obj: DataFrame) -> None:
result[name] = extra_col

@property
def _index_array(self):
def _index_array(self) -> npt.NDArray[np.int64] | None:
# TODO: why do we get here with e.g. MultiIndex?
if needs_i8_conversion(self._on.dtype):
idx = cast("PeriodIndex | DatetimeIndex | TimedeltaIndex", self._on)
return idx.asi8
if isinstance(self._on, (PeriodIndex, DatetimeIndex, TimedeltaIndex)):
return self._on.asi8
elif isinstance(self._on.dtype, ArrowDtype) and self._on.dtype.kind in "mM":
return self._on.to_numpy(dtype=np.int64)
return None

def _resolve_output(self, out: DataFrame, obj: DataFrame) -> DataFrame:
Expand Down Expand Up @@ -439,7 +441,7 @@ def _apply_series(
self, homogeneous_func: Callable[..., ArrayLike], name: str | None = None
) -> Series:
"""
Series version of _apply_blockwise
Series version of _apply_columnwise
"""
obj = self._create_data(self._selected_obj)

Expand All @@ -455,7 +457,7 @@ def _apply_series(
index = self._slice_axis_for_step(obj.index, result)
return obj._constructor(result, index=index, name=obj.name)

def _apply_blockwise(
def _apply_columnwise(
self,
homogeneous_func: Callable[..., ArrayLike],
name: str,
Expand Down Expand Up @@ -614,7 +616,7 @@ def calc(x):
return result

if self.method == "single":
return self._apply_blockwise(homogeneous_func, name, numeric_only)
return self._apply_columnwise(homogeneous_func, name, numeric_only)
else:
return self._apply_tablewise(homogeneous_func, name, numeric_only)

Expand Down Expand Up @@ -1232,7 +1234,9 @@ def calc(x):

return result

return self._apply_blockwise(homogeneous_func, name, numeric_only)[:: self.step]
return self._apply_columnwise(homogeneous_func, name, numeric_only)[
:: self.step
]

@doc(
_shared_docs["aggregate"],
Expand Down Expand Up @@ -1868,6 +1872,7 @@ def _validate(self) -> None:
if (
self.obj.empty
or isinstance(self._on, (DatetimeIndex, TimedeltaIndex, PeriodIndex))
or (isinstance(self._on.dtype, ArrowDtype) and self._on.dtype.kind in "mM")
) and isinstance(self.window, (str, BaseOffset, timedelta)):
self._validate_datetimelike_monotonic()

Expand Down
16 changes: 16 additions & 0 deletions pandas/tests/window/test_timeseries_window.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import numpy as np
import pytest

import pandas.util._test_decorators as td

from pandas import (
DataFrame,
DatetimeIndex,
Index,
MultiIndex,
NaT,
Series,
Expand Down Expand Up @@ -697,3 +700,16 @@ def test_nat_axis_error(msg, axis):
with pytest.raises(ValueError, match=f"{msg} values must not have NaT"):
with tm.assert_produces_warning(FutureWarning, match=warn_msg):
df.rolling("D", axis=axis).mean()


@td.skip_if_no("pyarrow")
def test_arrow_datetime_axis():
# GH 55849
expected = Series(
np.arange(5, dtype=np.float64),
index=Index(
date_range("2020-01-01", periods=5), dtype="timestamp[ns][pyarrow]"
),
)
result = expected.rolling("1D").sum()
tm.assert_series_equal(result, expected)

0 comments on commit ab666f4

Please sign in to comment.