Skip to content

Commit

Permalink
Update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
rlskoeser committed Aug 16, 2024
1 parent 1a75640 commit 329fa3d
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 32 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@ requires = [
"wheel"
]
build-backend = "setuptools.build_meta"

[tool.mypy]
plugins = ["numpy.typing.mypy_plugin"]
22 changes: 13 additions & 9 deletions src/undate/date.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from enum import IntEnum

# Pre 3.10 requires Union for multiple types, e.g. Union[int, None] instead of int | None
from typing import Optional, Dict, Union


import numpy as np

#: timedelta for single day
Expand All @@ -17,26 +21,26 @@ class Date(np.ndarray):
# extend np.datetime64 datatype
# adapted from https://stackoverflow.com/a/27129510/9706217

def __new__(cls, year: int, month: int = None, day: int = None):
def __new__(cls, year: int, month: Optional[int] = None, day: Optional[int] = None):
if isinstance(year, np.datetime64):
data = year
_data = year
else:
datestr = str(year)
if month is not None:
datestr = f"{year}-{month:02d}"
if day is not None:
datestr = f"{datestr}-{day:02d}"
data = np.datetime64(datestr)
_data = np.datetime64(datestr)

data = np.asarray(data, dtype="datetime64")
data = np.asarray(_data, dtype="datetime64")

# expected format depends on granularity / how much of date is known
expected_granularity = "Y"
# expected dtype depends on date unit / how much of date is known
expected_unit = "Y"
if day is not None and month is not None:
expected_granularity = "D"
expected_unit = "D"
elif month:
expected_granularity = "M"
expected_dtype = f"datetime64[{expected_granularity}]"
expected_unit = "M"
expected_dtype = f"datetime64[{expected_unit}]"

if data.dtype != expected_dtype:
raise Exception(
Expand Down
45 changes: 22 additions & 23 deletions src/undate/undate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from calendar import monthrange

# Pre 3.10 requires Union for multiple types, e.g. Union[int, None] instead of int | None
from typing import Optional, Dict, Union
from typing import Optional, Dict, Union, Any

import numpy as np
from numpy.typing import ArrayLike, DTypeLike

from undate.date import Date, DatePrecision, ONE_DAY, ONE_YEAR, ONE_MONTH_MAX
from undate.dateformat.base import BaseDateFormat
Expand All @@ -17,8 +20,8 @@ class Undate:
#: symbol for unknown digits within a date value
MISSING_DIGIT: str = "X"

earliest: datetime.date
latest: datetime.date
earliest: Date
latest: Date
#: A string to label a specific undate, e.g. "German Unity Date 2022" for Oct. 3, 2022.
#: Labels are not taken into account when comparing undate objects.
label: Union[str, None] = None
Expand Down Expand Up @@ -64,19 +67,17 @@ def __init__(
min_year = int(str(year).replace(self.MISSING_DIGIT, "0"))
max_year = int(str(year).replace(self.MISSING_DIGIT, "9"))
else:
# min_year = datetime.MINYEAR
# max_year = datetime.MAXYEAR
# numpy datetime is stored as 64-bit integer, so length
# depends on the span; assume days for now

# numpy datetime is stored as 64-bit integer, so min/max
# depends on the time unit; assume days for now
# See https://numpy.org/doc/stable/reference/arrays.datetime.html#datetime-units
max_year = int(2.5e16)
min_year = int(-2.5e16)

# if month is passed in as a string but completely unknown,
# treat as none
# TODO: we should preserve this information somehow;
# difference between just a year and and an unknown month within a year
# maybe in terms of granularity / size ?
# maybe in terms of date precision ?
if month == "XX":
month = None

Expand Down Expand Up @@ -124,9 +125,6 @@ def __init__(

# for unknowns, assume smallest possible value for earliest and
# largest valid for latest
# self.earliest = datetime.date(min_year, min_month, min_day)
# self.latest = datetime.date(max_year, max_month, max_day)

self.earliest = Date(min_year, min_month, min_day)
self.latest = Date(max_year, max_month, max_day)

Expand Down Expand Up @@ -245,7 +243,7 @@ def __gt__(self, other: object) -> bool:
# strictly greater than must rule out equals
return not (self < other or self == other)

def __le__(self, other: Union["Undate", datetime.date]) -> bool:
def __le__(self, other: object) -> bool:
return self == other or self < other

def __contains__(self, other: object) -> bool:
Expand All @@ -256,15 +254,17 @@ def __contains__(self, other: object) -> bool:
if self == other:
return False

return (
self.earliest <= other.earliest
and self.latest >= other.latest
# is precision sufficient for comparing partially known dates?
and self.precision > other.precision
return all(
[
self.earliest <= other.earliest,
self.latest >= other.latest,
# is precision sufficient for comparing partially known dates?
self.precision > other.precision,
]
)

@staticmethod
def from_datetime_date(dt_date):
def from_datetime_date(dt_date: datetime.date):
"""Initialize an :class:`Undate` object from a :class:`datetime.date`"""
return Undate(dt_date.year, dt_date.month, dt_date.day)

Expand All @@ -284,7 +284,7 @@ def is_known(self, part: str) -> bool:
def is_partially_known(self, part: str) -> bool:
return isinstance(self.initial_values[part], str)

def duration(self) -> datetime.timedelta:
def duration(self): # -> np.timedelta64:
"""What is the duration of this date?
Calculate based on earliest and latest date within range,
taking into account the precision of the date even if not all
Expand Down Expand Up @@ -313,7 +313,6 @@ def duration(self) -> datetime.timedelta:

# if granularity == month but not known month, duration = 31
if delta.astype(int) > 31:
# return datetime.timedelta(days=31)
return ONE_MONTH_MAX
return delta

Expand Down Expand Up @@ -394,11 +393,11 @@ def __eq__(self, other) -> bool:
# consider interval equal if both dates are equal
return self.earliest == other.earliest and self.latest == other.latest

def duration(self) -> datetime.timedelta:
def duration(self): # -> np.timedelta64:
"""Calculate the duration between two undates.
:returns: A duration
:rtype: timedelta
:rtype: numpy.timedelta64
"""
# what is the duration of this date range?

Expand Down

0 comments on commit 329fa3d

Please sign in to comment.