diff --git a/timeserio/preprocessing/datetime.py b/timeserio/preprocessing/datetime.py index 045f1cd..6d4a35a 100644 --- a/timeserio/preprocessing/datetime.py +++ b/timeserio/preprocessing/datetime.py @@ -3,7 +3,7 @@ import pandas as pd from sklearn.base import BaseEstimator, TransformerMixin from sklearn.utils.validation import check_is_fitted -from typing import List, Union +from typing import List, Union, Optional from .. import ini from .utils import _as_list_of_str, CallableMixin @@ -16,9 +16,9 @@ SECONDS_IN_MINUTE = 60 MINUTES_IN_DAY = MINUTES_IN_HOUR * HOURS_IN_DAY -PEAK_INTERVAL = ('16:00', '19:00') -DAY_INTERVAL = ('06:00', '23:00') -MORNING_INTERVAL = ('5:00', '10:00') +PEAK_INTERVAL = ("16:00", "19:00") +DAY_INTERVAL = ("06:00", "23:00") +MORNING_INTERVAL = ("5:00", "10:00") def get_fractional_hour_from_series(series: pd.Series) -> pd.Series: @@ -52,7 +52,9 @@ def get_fractional_year_from_series(series: pd.Series) -> pd.Series: def get_is_holiday_from_series( - series: pd.Series, country: str = "UnitedKingdom" + series: pd.Series, + country: str = "UnitedKingdom", + subdiv: Optional[str] = None, ) -> pd.Series: """Return 1 if day is a public holiday. @@ -61,7 +63,9 @@ def get_is_holiday_from_series( supported countries. """ years = series.dt.year.unique() - holiday_dates = holidays.CountryHoliday(country, years=years) + holiday_dates = holidays.country_holidays( + country, subdiv=subdiv, years=years + ) return series.dt.date.isin(holiday_dates).astype(int) @@ -77,10 +81,7 @@ def get_is_weekday_from_series(series: pd.Series) -> pd.Series: def get_time_is_in_interval_from_series( - series: pd.Series, - *, - start_time, - end_time, + series: pd.Series, *, start_time, end_time, ) -> pd.Series: """Return if time of day is in a given time interval.""" if isinstance(start_time, str): @@ -107,25 +108,25 @@ def get_time_is_in_interval_from_series( def get_is_peak_hour_from_series(series: pd.Series) -> pd.Series: """Return if time in peak hour interval.""" interval = PEAK_INTERVAL - return get_time_is_in_interval_from_series(series, - start_time=interval[0], - end_time=interval[1]) + return get_time_is_in_interval_from_series( + series, start_time=interval[0], end_time=interval[1] + ) def get_is_daytime_from_series(series: pd.Series) -> pd.Series: """Return if time in daytime interval.""" interval = DAY_INTERVAL - return get_time_is_in_interval_from_series(series, - start_time=interval[0], - end_time=interval[1]) + return get_time_is_in_interval_from_series( + series, start_time=interval[0], end_time=interval[1] + ) def get_is_morning_peak_from_series(series: pd.Series) -> pd.Series: """Return if time in morning peak interval.""" interval = MORNING_INTERVAL - return get_time_is_in_interval_from_series(series, - start_time=interval[0], - end_time=interval[1]) + return get_time_is_in_interval_from_series( + series, start_time=interval[0], end_time=interval[1] + ) def truncate_series(series: pd.Series, truncation_period: str) -> pd.Series: @@ -133,43 +134,43 @@ def truncate_series(series: pd.Series, truncation_period: str) -> pd.Series: SUPPORTED_DATETIME_ATTRS = [ - 'time', - 'hour', - 'month', - 'day', - 'dayofweek', - 'dayofyear', - 'weekday_name', + "time", + "hour", + "month", + "day", + "dayofweek", + "dayofyear", + "weekday_name", ] # Custom datetime featurizers such as daylight etc. CUSTOM_ATTRIBUTES = { - 'fractionalday': get_fractional_day_from_series, - 'fractionalhour': get_fractional_hour_from_series, - 'fractionalyear': get_fractional_year_from_series, - 'month0': get_zero_indexed_month_from_series, - 'is_holiday': get_is_holiday_from_series, - 'is_weekday': get_is_weekday_from_series, - 'is_in_interval': get_time_is_in_interval_from_series, - 'is_peak': get_is_peak_hour_from_series, - 'is_daytime': get_is_daytime_from_series, - 'is_morningpeak': get_is_morning_peak_from_series, - 'dt_truncated': truncate_series + "fractionalday": get_fractional_day_from_series, + "fractionalhour": get_fractional_hour_from_series, + "fractionalyear": get_fractional_year_from_series, + "month0": get_zero_indexed_month_from_series, + "is_holiday": get_is_holiday_from_series, + "is_weekday": get_is_weekday_from_series, + "is_in_interval": get_time_is_in_interval_from_series, + "is_peak": get_is_peak_hour_from_series, + "is_daytime": get_is_daytime_from_series, + "is_morningpeak": get_is_morning_peak_from_series, + "dt_truncated": truncate_series, } class PandasDateTimeFeaturizer(BaseEstimator, TransformerMixin, CallableMixin): """Featurize datetime column by adding specified attributes.""" - valid_attributes = ( - SUPPORTED_DATETIME_ATTRS + list(CUSTOM_ATTRIBUTES.keys()) + valid_attributes = SUPPORTED_DATETIME_ATTRS + list( + CUSTOM_ATTRIBUTES.keys() ) def __init__( self, column=ini.Columns.datetime, - attributes=['month0', 'dayofweek', 'fractionalday'], - kwargs=None + attributes=["month0", "dayofweek", "fractionalday"], + kwargs=None, ): self.column = column self.attributes = attributes @@ -185,7 +186,7 @@ def fit(self, df, y=None): def transform(self, df): if not isinstance(df, pd.DataFrame): - raise TypeError('Input must be a DataFrame.') + raise TypeError("Input must be a DataFrame.") df = df.copy() column = df[self.column] if isinstance(column, pd.DataFrame): # hierarchical index @@ -204,7 +205,7 @@ def transform(self, df): else: raise KeyError( f'Unknown attribute "{attr}"; ' - 'see `PandasDateTimeFeaturizer.valid_attributes`' + "see `PandasDateTimeFeaturizer.valid_attributes`" ) return df @@ -215,7 +216,7 @@ def required_columns(self): def transformed_columns(self, input_columns): input_columns = set(_as_list_of_str(input_columns)) if not self.required_columns <= input_columns: - raise ValueError(f'Required columns are {self.required_columns}') + raise ValueError(f"Required columns are {self.required_columns}") return input_columns | set(self.attributes_) @@ -227,12 +228,7 @@ class _BaseLagFeaturizer(BaseEstimator, TransformerMixin, CallableMixin): duplicate_agg: str = "raise" def __init__( - self, - datetime_column, - columns, - lags, - refit=True, - duplicate_agg="raise" + self, datetime_column, columns, lags, refit=True, duplicate_agg="raise" ): self.datetime_column = datetime_column self.columns = columns @@ -241,15 +237,16 @@ def __init__( self.duplicate_agg = duplicate_agg def fit(self, df, y=None, **fit_params): - if hasattr(self, 'df_') and not self.refit: + if hasattr(self, "df_") and not self.refit: return self columns = _as_list_of_str(self.columns) self.df_ = df.set_index(self.datetime_column)[columns] - if self.duplicate_agg == 'raise': + if self.duplicate_agg == "raise": if any(self.df_.index.duplicated()): raise ValueError( "Input dataframe contains duplicate entries " - "with the same %s", self.datetime_column + "with the same %s", + self.datetime_column, ) else: self.df_ = self.df_.groupby(level=0).agg(self.duplicate_agg) @@ -261,15 +258,18 @@ def _lag_df(self, lag): raise NotImplementedError def transform(self, df): - check_is_fitted(self, 'df_') + check_is_fitted(self, "df_") lags = _as_list_of_str(self.lags) df = df.copy().set_index(self.datetime_column) for lag in lags: lag_df = self._lag_df(lag).add_suffix(f"_{lag}") df = pd.merge( - df, lag_df, - how="left", left_index=True, right_index=True, - suffixes=("", "") + df, + lag_df, + how="left", + left_index=True, + right_index=True, + suffixes=("", ""), ) return df.reset_index() @@ -281,7 +281,7 @@ def required_columns(self): def transformed_columns(self, input_columns): input_columns = set(_as_list_of_str(input_columns)) if not self.required_columns <= input_columns: - raise ValueError(f'Required columns are {self.required_columns}') + raise ValueError(f"Required columns are {self.required_columns}") lags = _as_list_of_str(self.lags) columns = _as_list_of_str(self.columns) new_columns = [f"{col}_{lag}" for lag in lags for col in columns] @@ -362,14 +362,10 @@ def __init__( win_type: str = None, closed: str = None, refit: bool = True, - duplicate_agg: str = 'raise' + duplicate_agg: str = "raise", ): super().__init__( - datetime_column, - columns, - windows, - refit, - duplicate_agg + datetime_column, columns, windows, refit, duplicate_agg ) self.min_periods = min_periods self.center = center @@ -390,5 +386,5 @@ def _lag_df(self, lag): min_periods=self.min_periods, center=self.center, win_type=self.win_type, - closed=self.closed + closed=self.closed, ).mean()