Skip to content

Commit

Permalink
ENH: TimedeltaArray add/sub with NaT preserve reso (#47522)
Browse files Browse the repository at this point in the history
* ENH: TimedeltaArray add/sub with NaT preserve reso

* mypy fixup

* use datetime_data
  • Loading branch information
jbrockmendel authored Jun 28, 2022
1 parent 7dad4e7 commit 612f566
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 26 deletions.
43 changes: 33 additions & 10 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,7 +1087,7 @@ def _cmp_method(self, other, op):
__rdivmod__ = make_invalid_op("__rdivmod__")

@final
def _add_datetimelike_scalar(self, other):
def _add_datetimelike_scalar(self, other) -> DatetimeArray:
if not is_timedelta64_dtype(self.dtype):
raise TypeError(
f"cannot add {type(self).__name__} and {type(other).__name__}"
Expand All @@ -1102,16 +1102,12 @@ def _add_datetimelike_scalar(self, other):
if other is NaT:
# In this case we specifically interpret NaT as a datetime, not
# the timedelta interpretation we would get by returning self + NaT
result = self.asi8.view("m8[ms]") + NaT.to_datetime64()
return DatetimeArray(result)
result = self._ndarray + NaT.to_datetime64().astype(f"M8[{self._unit}]")
# Preserve our resolution
return DatetimeArray._simple_new(result, dtype=result.dtype)

i8 = self.asi8
# Incompatible types in assignment (expression has type "ndarray[Any,
# dtype[signedinteger[_64Bit]]]", variable has type
# "ndarray[Any, dtype[datetime64]]")
result = checked_add_with_arr( # type: ignore[assignment]
i8, other.value, arr_mask=self._isnan
)
result = checked_add_with_arr(i8, other.value, arr_mask=self._isnan)
dtype = DatetimeTZDtype(tz=other.tz) if other.tz else DT64NS_DTYPE
return DatetimeArray(result, dtype=dtype, freq=self.freq)

Expand Down Expand Up @@ -1275,12 +1271,14 @@ def _add_nat(self):
raise TypeError(
f"Cannot add {type(self).__name__} and {type(NaT).__name__}"
)
self = cast("TimedeltaArray | DatetimeArray", self)

# GH#19124 pd.NaT is treated like a timedelta for both timedelta
# and datetime dtypes
result = np.empty(self.shape, dtype=np.int64)
result.fill(iNaT)
return type(self)(result, dtype=self.dtype, freq=None)
result = result.view(self._ndarray.dtype) # preserve reso
return type(self)._simple_new(result, dtype=self.dtype, freq=None)

@final
def _sub_nat(self):
Expand Down Expand Up @@ -1905,6 +1903,13 @@ class TimelikeOps(DatetimeLikeArrayMixin):
def _reso(self) -> int:
return get_unit_from_dtype(self._ndarray.dtype)

@cache_readonly
def _unit(self) -> str:
    """
    Unit string for this array's dtype, e.g. "ns", "us", "ms".

    Obtained by passing ``self.dtype`` to ``dtype_to_unit``.
    """
    # error: Argument 1 to "dtype_to_unit" has incompatible type
    # "ExtensionDtype"; expected "Union[DatetimeTZDtype, dtype[Any]]"
    unit = dtype_to_unit(self.dtype)  # type: ignore[arg-type]
    return unit

def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
if (
ufunc in [np.isnan, np.isinf, np.isfinite]
Expand Down Expand Up @@ -2105,3 +2110,21 @@ def maybe_infer_freq(freq):
freq_infer = True
freq = None
return freq, freq_infer


def dtype_to_unit(dtype: DatetimeTZDtype | np.dtype) -> str:
    """
    Return the unit str corresponding to the dtype's resolution.

    Parameters
    ----------
    dtype : DatetimeTZDtype or np.dtype
        If np.dtype, it must be a datetime64 or timedelta64 dtype, since
        the unit is read with ``np.datetime_data``.

    Returns
    -------
    str
        The unit abbreviation, e.g. "ns", "us", "ms".
    """
    if not isinstance(dtype, DatetimeTZDtype):
        # np.datetime_data gives (unit, count); the unit is the first item.
        return np.datetime_data(dtype)[0]
    # A tz-aware dtype carries its unit directly.
    return dtype.unit
56 changes: 40 additions & 16 deletions pandas/tests/arrays/test_timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
import pandas as pd
from pandas import Timedelta
import pandas._testing as tm
from pandas.core.arrays import TimedeltaArray
from pandas.core.arrays import (
DatetimeArray,
TimedeltaArray,
)


class TestNonNano:
Expand All @@ -25,6 +28,11 @@ def reso(self, unit):
else:
raise NotImplementedError(unit)

@pytest.fixture
def tda(self, unit):
    """
    Build a TimedeltaArray holding the values 0..4 at resolution ``unit``.

    The int64 values are reinterpreted as ``m8[unit]`` and wrapped with
    ``_simple_new``.
    """
    int_values = np.arange(5, dtype=np.int64)
    td64_values = int_values.view(f"m8[{unit}]")
    return TimedeltaArray._simple_new(td64_values, dtype=td64_values.dtype)

def test_non_nano(self, unit, reso):
arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]")
tda = TimedeltaArray._simple_new(arr, dtype=arr.dtype)
Expand All @@ -33,39 +41,55 @@ def test_non_nano(self, unit, reso):
assert tda[0]._reso == reso

@pytest.mark.parametrize("field", TimedeltaArray._field_ops)
def test_fields(self, unit, field):
arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]")
tda = TimedeltaArray._simple_new(arr, dtype=arr.dtype)

as_nano = arr.astype("m8[ns]")
def test_fields(self, tda, field):
as_nano = tda._ndarray.astype("m8[ns]")
tda_nano = TimedeltaArray._simple_new(as_nano, dtype=as_nano.dtype)

result = getattr(tda, field)
expected = getattr(tda_nano, field)
tm.assert_numpy_array_equal(result, expected)

def test_to_pytimedelta(self, unit):
arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]")
tda = TimedeltaArray._simple_new(arr, dtype=arr.dtype)

as_nano = arr.astype("m8[ns]")
def test_to_pytimedelta(self, tda):
as_nano = tda._ndarray.astype("m8[ns]")
tda_nano = TimedeltaArray._simple_new(as_nano, dtype=as_nano.dtype)

result = tda.to_pytimedelta()
expected = tda_nano.to_pytimedelta()
tm.assert_numpy_array_equal(result, expected)

def test_total_seconds(self, unit):
arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]")
tda = TimedeltaArray._simple_new(arr, dtype=arr.dtype)

as_nano = arr.astype("m8[ns]")
def test_total_seconds(self, unit, tda):
as_nano = tda._ndarray.astype("m8[ns]")
tda_nano = TimedeltaArray._simple_new(as_nano, dtype=as_nano.dtype)

result = tda.total_seconds()
expected = tda_nano.total_seconds()
tm.assert_numpy_array_equal(result, expected)

@pytest.mark.parametrize(
    "nat", [np.datetime64("NaT", "ns"), np.datetime64("NaT", "us")]
)
def test_add_nat_datetimelike_scalar(self, nat, tda):
    """
    Adding a datetime64 NaT on either side of ``tda`` must produce an
    all-NaT DatetimeArray with the same ``_reso`` as ``tda``.
    """
    # Check both operand orders: tda + nat first, then nat + tda.
    for left, right in ((tda, nat), (nat, tda)):
        summed = left + right
        assert isinstance(summed, DatetimeArray)
        assert summed._reso == tda._reso
        assert summed.isna().all()

def test_add_pdnat(self, tda):
    """
    Adding ``pd.NaT`` on either side of ``tda`` must produce an all-NaT
    TimedeltaArray with the same ``_reso`` as ``tda``.
    """
    # Check both operand orders: tda + NaT first, then NaT + tda.
    for left, right in ((tda, pd.NaT), (pd.NaT, tda)):
        summed = left + right
        assert isinstance(summed, TimedeltaArray)
        assert summed._reso == tda._reso
        assert summed.isna().all()


class TestTimedeltaArray:
@pytest.mark.parametrize("dtype", [int, np.int32, np.int64, "uint32", "uint64"])
Expand Down

0 comments on commit 612f566

Please sign in to comment.