Skip to content

Commit

Permalink
Enable passing subdiv to holidays
Browse files Browse the repository at this point in the history
- England must now be passed as country="UK", subdiv="England"
  • Loading branch information
j0nnyr0berts committed Oct 7, 2022
1 parent 9aff502 commit 2476542
Showing 1 changed file with 61 additions and 65 deletions.
126 changes: 61 additions & 65 deletions timeserio/preprocessing/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted
from typing import List, Union
from typing import List, Union, Optional

from .. import ini
from .utils import _as_list_of_str, CallableMixin
Expand All @@ -16,9 +16,9 @@
SECONDS_IN_MINUTE = 60
MINUTES_IN_DAY = MINUTES_IN_HOUR * HOURS_IN_DAY

# Named time-of-day intervals, each a (start_time, end_time) pair of strings.
PEAK_INTERVAL = ("16:00", "19:00")
DAY_INTERVAL = ("06:00", "23:00")
MORNING_INTERVAL = ("5:00", "10:00")


def get_fractional_hour_from_series(series: pd.Series) -> pd.Series:
Expand Down Expand Up @@ -52,7 +52,9 @@ def get_fractional_year_from_series(series: pd.Series) -> pd.Series:


def get_is_holiday_from_series(
series: pd.Series, country: str = "UnitedKingdom"
series: pd.Series,
country: str = "UnitedKingdom",
subdiv: Optional[str] = None,
) -> pd.Series:
"""Return 1 if day is a public holiday.
Expand All @@ -61,7 +63,9 @@ def get_is_holiday_from_series(
supported countries.
"""
years = series.dt.year.unique()
holiday_dates = holidays.CountryHoliday(country, years=years)
holiday_dates = holidays.country_holidays(
country, subdiv=subdiv, years=years
)
return series.dt.date.isin(holiday_dates).astype(int)


Expand All @@ -77,10 +81,7 @@ def get_is_weekday_from_series(series: pd.Series) -> pd.Series:


def get_time_is_in_interval_from_series(
series: pd.Series,
*,
start_time,
end_time,
series: pd.Series, *, start_time, end_time,
) -> pd.Series:
"""Return if time of day is in a given time interval."""
if isinstance(start_time, str):
Expand All @@ -107,69 +108,69 @@ def get_time_is_in_interval_from_series(
def get_is_peak_hour_from_series(series: pd.Series) -> pd.Series:
    """Return if time in peak hour interval.

    Delegates to `get_time_is_in_interval_from_series`, passing the two
    elements of `PEAK_INTERVAL` as `start_time` and `end_time`.
    """
    start_time, end_time = PEAK_INTERVAL
    return get_time_is_in_interval_from_series(
        series, start_time=start_time, end_time=end_time
    )


def get_is_daytime_from_series(series: pd.Series) -> pd.Series:
    """Return if time in daytime interval.

    Delegates to `get_time_is_in_interval_from_series`, passing the two
    elements of `DAY_INTERVAL` as `start_time` and `end_time`.
    """
    start_time, end_time = DAY_INTERVAL
    return get_time_is_in_interval_from_series(
        series, start_time=start_time, end_time=end_time
    )


def get_is_morning_peak_from_series(series: pd.Series) -> pd.Series:
    """Return if time in morning peak interval.

    Delegates to `get_time_is_in_interval_from_series`, passing the two
    elements of `MORNING_INTERVAL` as `start_time` and `end_time`.
    """
    start_time, end_time = MORNING_INTERVAL
    return get_time_is_in_interval_from_series(
        series, start_time=start_time, end_time=end_time
    )


def truncate_series(series: pd.Series, truncation_period: str) -> pd.Series:
    """Truncate each datetime in `series` to the start of its period.

    The values are converted to periods of frequency `truncation_period`
    (a pandas period alias such as ``"D"`` or ``"M"``) and then converted
    back to timestamps.
    """
    as_periods = series.dt.to_period(truncation_period)
    return as_periods.dt.to_timestamp()


# Plain datetime attribute names; together with the keys of
# `CUSTOM_ATTRIBUTES` they make up `PandasDateTimeFeaturizer.valid_attributes`.
# NOTE(review): "weekday_name" no longer exists on the pandas `.dt` accessor
# in pandas >= 1.0 (`day_name()` replaced it) - confirm how `transform`
# resolves this name before relying on it.
SUPPORTED_DATETIME_ATTRS = [
    "time",
    "hour",
    "month",
    "day",
    "dayofweek",
    "dayofyear",
    "weekday_name",
]

# Custom datetime featurizers such as daylight etc.
# Maps each custom attribute name to the function that computes it from a
# datetime Series; the keys are appended to `SUPPORTED_DATETIME_ATTRS` to form
# `PandasDateTimeFeaturizer.valid_attributes`.
CUSTOM_ATTRIBUTES = {
    "fractionalday": get_fractional_day_from_series,
    "fractionalhour": get_fractional_hour_from_series,
    "fractionalyear": get_fractional_year_from_series,
    "month0": get_zero_indexed_month_from_series,
    "is_holiday": get_is_holiday_from_series,
    "is_weekday": get_is_weekday_from_series,
    "is_in_interval": get_time_is_in_interval_from_series,
    "is_peak": get_is_peak_hour_from_series,
    "is_daytime": get_is_daytime_from_series,
    "is_morningpeak": get_is_morning_peak_from_series,
    "dt_truncated": truncate_series,
}


class PandasDateTimeFeaturizer(BaseEstimator, TransformerMixin, CallableMixin):
"""Featurize datetime column by adding specified attributes."""

valid_attributes = (
SUPPORTED_DATETIME_ATTRS + list(CUSTOM_ATTRIBUTES.keys())
valid_attributes = SUPPORTED_DATETIME_ATTRS + list(
CUSTOM_ATTRIBUTES.keys()
)

def __init__(
self,
column=ini.Columns.datetime,
attributes=['month0', 'dayofweek', 'fractionalday'],
kwargs=None
attributes=["month0", "dayofweek", "fractionalday"],
kwargs=None,
):
self.column = column
self.attributes = attributes
Expand All @@ -185,7 +186,7 @@ def fit(self, df, y=None):

def transform(self, df):
if not isinstance(df, pd.DataFrame):
raise TypeError('Input must be a DataFrame.')
raise TypeError("Input must be a DataFrame.")
df = df.copy()
column = df[self.column]
if isinstance(column, pd.DataFrame): # hierarchical index
Expand All @@ -204,7 +205,7 @@ def transform(self, df):
else:
raise KeyError(
f'Unknown attribute "{attr}"; '
'see `PandasDateTimeFeaturizer.valid_attributes`'
"see `PandasDateTimeFeaturizer.valid_attributes`"
)
return df

Expand All @@ -215,7 +216,7 @@ def required_columns(self):
def transformed_columns(self, input_columns):
input_columns = set(_as_list_of_str(input_columns))
if not self.required_columns <= input_columns:
raise ValueError(f'Required columns are {self.required_columns}')
raise ValueError(f"Required columns are {self.required_columns}")
return input_columns | set(self.attributes_)


Expand All @@ -227,12 +228,7 @@ class _BaseLagFeaturizer(BaseEstimator, TransformerMixin, CallableMixin):
duplicate_agg: str = "raise"

def __init__(
self,
datetime_column,
columns,
lags,
refit=True,
duplicate_agg="raise"
self, datetime_column, columns, lags, refit=True, duplicate_agg="raise"
):
self.datetime_column = datetime_column
self.columns = columns
Expand All @@ -241,15 +237,16 @@ def __init__(
self.duplicate_agg = duplicate_agg

def fit(self, df, y=None, **fit_params):
if hasattr(self, 'df_') and not self.refit:
if hasattr(self, "df_") and not self.refit:
return self
columns = _as_list_of_str(self.columns)
self.df_ = df.set_index(self.datetime_column)[columns]
if self.duplicate_agg == 'raise':
if self.duplicate_agg == "raise":
if any(self.df_.index.duplicated()):
raise ValueError(
"Input dataframe contains duplicate entries "
"with the same %s", self.datetime_column
"with the same %s",
self.datetime_column,
)
else:
self.df_ = self.df_.groupby(level=0).agg(self.duplicate_agg)
Expand All @@ -261,15 +258,18 @@ def _lag_df(self, lag):
raise NotImplementedError

def transform(self, df):
    """Left-join the lagged feature columns onto `df`.

    For every entry of `self.lags`, the frame returned by `self._lag_df(lag)`
    gets a ``_{lag}`` suffix on its column names and is merged onto `df` by
    index, after `df` has been indexed by `self.datetime_column`.

    Returns a new DataFrame with the index reset; the input is not modified.
    Requires a prior `fit` (checked through the `df_` attribute).
    """
    check_is_fitted(self, "df_")
    lags = _as_list_of_str(self.lags)
    df = df.copy().set_index(self.datetime_column)
    for lag in lags:
        lag_df = self._lag_df(lag).add_suffix(f"_{lag}")
        df = pd.merge(
            df,
            lag_df,
            how="left",  # keep every input row, even without a lagged match
            left_index=True,
            right_index=True,
            suffixes=("", ""),
        )
    return df.reset_index()

Expand All @@ -281,7 +281,7 @@ def required_columns(self):
def transformed_columns(self, input_columns):
input_columns = set(_as_list_of_str(input_columns))
if not self.required_columns <= input_columns:
raise ValueError(f'Required columns are {self.required_columns}')
raise ValueError(f"Required columns are {self.required_columns}")
lags = _as_list_of_str(self.lags)
columns = _as_list_of_str(self.columns)
new_columns = [f"{col}_{lag}" for lag in lags for col in columns]
Expand Down Expand Up @@ -362,14 +362,10 @@ def __init__(
win_type: str = None,
closed: str = None,
refit: bool = True,
duplicate_agg: str = 'raise'
duplicate_agg: str = "raise",
):
super().__init__(
datetime_column,
columns,
windows,
refit,
duplicate_agg
datetime_column, columns, windows, refit, duplicate_agg
)
self.min_periods = min_periods
self.center = center
Expand All @@ -390,5 +386,5 @@ def _lag_df(self, lag):
min_periods=self.min_periods,
center=self.center,
win_type=self.win_type,
closed=self.closed
closed=self.closed,
).mean()

0 comments on commit 2476542

Please sign in to comment.