Skip to content

Commit

Permalink
update logic
Browse files Browse the repository at this point in the history
  • Loading branch information
galipremsagar committed Dec 31, 2024
1 parent 8887aa5 commit 6c7d53e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 22 deletions.
25 changes: 9 additions & 16 deletions python/cudf/cudf/core/column/timedelta.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import datetime
import functools
import math
from typing import TYPE_CHECKING, cast

import numpy as np
Expand Down Expand Up @@ -39,20 +40,6 @@
"D": 86_400_000_000_000,
}

_dtype_total_seconds_factor = {
np.dtype("timedelta64[s]"): 1.0,
np.dtype("timedelta64[ms]"): 1e-3,
np.dtype("timedelta64[us]"): 1e-6,
np.dtype("timedelta64[ns]"): 1e-9,
}

_dtype_total_seconds_decimal_round = {
np.dtype("timedelta64[s]"): 1,
np.dtype("timedelta64[ms]"): 3,
np.dtype("timedelta64[us]"): 6,
np.dtype("timedelta64[ns]"): 9,
}


class TimeDeltaColumn(ColumnBase):
"""
Expand Down Expand Up @@ -277,9 +264,15 @@ def time_unit(self) -> str:
return np.datetime_data(self.dtype)[0]

def total_seconds(self) -> ColumnBase:
conversion = _unit_to_nanoseconds_conversion[self.time_unit] / 1e9
# Typecast to decimal128 to avoid floating point precision issues
# https://github.com/rapidsai/cudf/issues/17664
return (
self.astype("int64") * _dtype_total_seconds_factor[self.dtype]
).round(decimals=_dtype_total_seconds_decimal_round[self.dtype])
(self.astype("int64") * conversion)
.astype(cudf.Decimal128Dtype(38, 9))
.round(decimals=abs(int(math.log10(conversion))))
.astype("float64")
)

def ceil(self, freq: str) -> ColumnBase:
raise NotImplementedError("ceil is currently not implemented")
Expand Down
6 changes: 0 additions & 6 deletions python/cudf/cudf/tests/test_timedelta.py
Original file line number Diff line number Diff line change
Expand Up @@ -1522,12 +1522,6 @@ def test_timedelta_series_total_seconds(data, dtype):
@pytest.mark.parametrize("data", _TIMEDELTA_DATA)
@pytest.mark.parametrize("dtype", utils.TIMEDELTA_TYPES)
def test_timedelta_index_total_seconds(request, data, dtype):
request.applymarker(
pytest.mark.xfail(
condition=(1132.324 in data and dtype == "timedelta64[ms]"),
reason="https://github.com/rapidsai/cudf/issues/17664",
)
)
gi = cudf.Index(data, dtype=dtype)
pi = gi.to_pandas()

Expand Down

0 comments on commit 6c7d53e

Please sign in to comment.