diff --git a/CHANGELOG.md b/CHANGELOG.md index 94b7e56f..4e93eadf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,10 @@ # Release Notes -## v0.9.3 UNRELEASED +## v0.10.0 UNRELEASED ### Added +- Added new `N-HiTS` network that has consistently beaten `N-BEATS` (#890) - Allow using [torchmetrics](https://torchmetrics.readthedocs.io/) as loss metrics (#776) - Enable fitting `EncoderNormalizer()` with limited data history using `max_length` argument (#782) - More flexible `MultiEmbedding()` with convenience `output_size` and `input_size` properties (#829) diff --git a/README.md b/README.md index 444d9a12..200ee3ea 100755 --- a/README.md +++ b/README.md @@ -69,6 +69,7 @@ The documentation provides a [comparison of available models](https://pytorch-fo - [N-BEATS: Neural basis expansion analysis for interpretable time series forecasting](http://arxiv.org/abs/1905.10437) which has (if used as ensemble) outperformed all other methods including ensembles of traditional statical methods in the M4 competition. The M4 competition is arguably the most important benchmark for univariate time series forecasting. +- [N-HiTS: Neural Hierarchical Interpolation for Time Series Forecasting](http://arxiv.org/abs/2201.12886) which supports covariates and has consistently beaten N-BEATS. It is also particularly well-suited for long-horizon forecasting. - [DeepAR: Probabilistic forecasting with autoregressive recurrent networks](https://www.sciencedirect.com/science/article/pii/S0169207019301888) which is the one of the most popular forecasting algorithms and is often used as a baseline - Simple standard networks for baselining: LSTM and GRU networks as well as a MLP on the decoder diff --git a/docs/source/models.rst b/docs/source/models.rst index 2cf82912..e0e8d197 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -27,6 +27,7 @@ and you should take into account. Here is an overview over the pros and cons of :py:class:`~pytorch_forecasting.models.rnn.RecurrentNetwork`, "x", "x", "x", "", "", "", "", "x", "", 2 :py:class:`~pytorch_forecasting.models.mlp.DecoderMLP`, "x", "x", "x", "x", "", "x", "", "x", "x", 1 :py:class:`~pytorch_forecasting.models.nbeats.NBeats`, "", "", "x", "", "", "", "", "", "", 1 + :py:class:`~pytorch_forecasting.models.nhits.NHiTS`, "x", "x", "x", "", "", "", "", "", "", 1 :py:class:`~pytorch_forecasting.models.deepar.DeepAR`, "x", "x", "x", "", "x", "x", "", "x", "", 3 :py:class:`~pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`, "x", "x", "x", "x", "", "x", "", "x", "x", 4 @@ -85,6 +86,9 @@ multiple targets and even hetrogeneous targets where some are continuous variabl i.e. regression and classification at the same time. :py:class:`~pytorch_forecasting.models.deepar.DeepAR` can handle multiple targets but only works for regression tasks. +For long forecast horizon forecasts, :py:class:`~pytorch_forecasting.models.nhits.NHiTS` is an excellent choice +as it uses interpolation capabilities. + Supporting uncertainty ~~~~~~~~~~~~~~~~~~~~~~~ @@ -123,7 +127,8 @@ the lifetime of a model. The :py:class:`~pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer` is a rather large model but might benefit from being trained with. -For example, :py:class:`~pytorch_forecasting.models.nbeats.NBeats` is an efficient model. +For example, :py:class:`~pytorch_forecasting.models.nbeats.NBeats` or :py:class:`~pytorch_forecasting.models.nhits.NHiTS` are +efficient models. Autoregressive models such as :py:class:`~pytorch_forecasting.models.deepar.DeepAR` might be quick to train but might be slow at inference time (in case of :py:class:`~pytorch_forecasting.models.deepar.DeepAR` this is driven by sampling results probabilistically multiple times, effectively increasing the computational burden linearly with the diff --git a/docs/source/tutorials/ar.ipynb b/docs/source/tutorials/ar.ipynb index 434416fd..9625c433 100644 --- a/docs/source/tutorials/ar.ipynb +++ b/docs/source/tutorials/ar.ipynb @@ -334,7 +334,7 @@ ], "source": [ "# find optimal learning rate\n", - "res = trainer.tuner.lr_find(net, train_dataloader=train_dataloader, val_dataloaders=val_dataloader, min_lr=1e-5)\n", + "res = trainer.tuner.lr_find(net, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, min_lr=1e-5)\n", "print(f\"suggested learning rate: {res.suggestion()}\")\n", "fig = res.plot(show=True, suggest=True)\n", "fig.show()\n", @@ -617,7 +617,7 @@ "\n", "trainer.fit(\n", " net,\n", - " train_dataloader=train_dataloader,\n", + " train_dataloaders=train_dataloader,\n", " val_dataloaders=val_dataloader,\n", ")" ] diff --git a/docs/source/tutorials/building.ipynb b/docs/source/tutorials/building.ipynb index 06506455..1e35a35e 100644 --- a/docs/source/tutorials/building.ipynb +++ b/docs/source/tutorials/building.ipynb @@ -3803,7 +3803,7 @@ "\n", "model = FullyConnectedForDistributionLossModel.from_dataset(dataset, hidden_size=10, n_hidden_layers=2, log_interval=1)\n", "trainer = Trainer(fast_dev_run=True)\n", - "trainer.fit(model, train_dataloader=dataloader, val_dataloaders=dataloader)" + "trainer.fit(model, train_dataloaders=dataloader, val_dataloaders=dataloader)" ] } ], diff --git a/docs/source/tutorials/stallion.ipynb b/docs/source/tutorials/stallion.ipynb index 5d4a9aa0..9df5ac43 100644 --- a/docs/source/tutorials/stallion.ipynb +++ b/docs/source/tutorials/stallion.ipynb @@ -1049,7 +1049,7 @@ "# find optimal learning rate\n", "res = trainer.tuner.lr_find(\n", " tft,\n", - " train_dataloader=train_dataloader,\n", + " train_dataloaders=train_dataloader,\n", " val_dataloaders=val_dataloader,\n", " max_lr=10.0,\n", " min_lr=1e-6,\n", @@ -1577,7 +1577,7 @@ "# fit network\n", "trainer.fit(\n", " tft,\n", - " train_dataloader=train_dataloader,\n", + " train_dataloaders=train_dataloader,\n", " val_dataloaders=val_dataloader,\n", ")" ] diff --git a/pytorch_forecasting/__init__.py b/pytorch_forecasting/__init__.py index 31371736..3ee46e66 100644 --- a/pytorch_forecasting/__init__.py +++ b/pytorch_forecasting/__init__.py @@ -40,6 +40,7 @@ DeepAR, MultiEmbedding, NBeats, + NHiTS, RecurrentNetwork, TemporalFusionTransformer, get_rnn, @@ -66,6 +67,7 @@ "MultiNormalizer", "TemporalFusionTransformer", "NBeats", + "NHiTS", "Baseline", "DeepAR", "BaseModel", diff --git a/pytorch_forecasting/metrics.py b/pytorch_forecasting/metrics.py index 72ac211e..371c3c83 100644 --- a/pytorch_forecasting/metrics.py +++ b/pytorch_forecasting/metrics.py @@ -264,19 +264,30 @@ def __len__(self) -> int: """ return len(self.metrics) - def update(self, y_pred: torch.Tensor, y_actual: torch.Tensor): + def update(self, y_pred: torch.Tensor, y_actual: torch.Tensor, **kwargs): """ Update composite metric Args: y_pred: network output y_actual: actual values + **kwargs: arguments to update function Returns: torch.Tensor: metric value on which backpropagation can be applied """ for idx, metric in enumerate(self.metrics): - metric.update(y_pred[idx], (y_actual[0][idx], y_actual[1])) + try: + metric.update( + y_pred[idx], + (y_actual[0][idx], y_actual[1]), + **{ + name: value[idx] if isinstance(value, (list, tuple)) else value + for name, value in kwargs.items() + }, + ) + except TypeError: # silently update without kwargs if not supported + metric.update(y_pred[idx], (y_actual[0][idx], y_actual[1])) def compute(self) -> torch.Tensor: """ @@ -949,7 +960,7 @@ def update( self._update_losses_and_lengths(losses, lengths) def loss(self, y_pred, target, scaling): - return (y_pred - target).abs() / scaling.unsqueeze(-1) + return (self.to_prediction(y_pred) - target).abs() / scaling.unsqueeze(-1) def calculate_scaling(self, target, lengths, encoder_target, encoder_lengths): # calcualte mean(abs(diff(targets))) diff --git a/pytorch_forecasting/models/__init__.py b/pytorch_forecasting/models/__init__.py index 496a3691..adbe7933 100644 --- a/pytorch_forecasting/models/__init__.py +++ b/pytorch_forecasting/models/__init__.py @@ -11,12 +11,14 @@ from pytorch_forecasting.models.deepar import DeepAR from pytorch_forecasting.models.mlp import DecoderMLP from pytorch_forecasting.models.nbeats import NBeats +from pytorch_forecasting.models.nhits import NHiTS from pytorch_forecasting.models.nn import GRU, LSTM, MultiEmbedding, get_rnn from pytorch_forecasting.models.rnn import RecurrentNetwork from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer __all__ = [ "NBeats", + "NHiTS", "TemporalFusionTransformer", "RecurrentNetwork", "DeepAR", diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py index 24fbc020..79126c4c 100644 --- a/pytorch_forecasting/models/base_model.py +++ b/pytorch_forecasting/models/base_model.py @@ -36,9 +36,11 @@ QuantileLoss, convert_torchmetric_to_pytorch_forecasting_metric, ) +from pytorch_forecasting.models.nn.embeddings import MultiEmbedding from pytorch_forecasting.optim import Ranger from pytorch_forecasting.utils import ( OutputMixIn, + TupleOutputMixIn, apply_to_list, create_mask, get_embedding_size, @@ -131,7 +133,7 @@ def _concatenate_output( } -class BaseModel(LightningModule): +class BaseModel(LightningModule, TupleOutputMixIn): """ BaseModel from which new timeseries models should inherit from. The ``hparams`` of the created object will default to the parameters indicated in :py:meth:`~__init__`. @@ -192,6 +194,7 @@ def __init__( loss: Metric = SMAPE(), logging_metrics: nn.ModuleList = nn.ModuleList([]), reduce_on_plateau_patience: int = 1000, + reduce_on_plateau_reduction: float = 2.0, reduce_on_plateau_min_lr: float = 1e-5, weight_decay: float = 0.0, optimizer_params: Dict[str, Any] = None, @@ -215,6 +218,7 @@ def __init__( Defaults to []. reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10. Defaults to 1000 + reduce_on_plateau_reduction (float): reduction in learning rate when encountering plateau. Defaults to 2.0. reduce_on_plateau_min_lr (float): minimum learning rate for reduce on plateua learning rate scheduler. Defaults to 1e-5 weight_decay (float): weight decay. Defaults to 0.0. @@ -513,7 +517,7 @@ def step( # multiply monotinicity loss by large number to ensure relevance and take to the power of 2 # for smoothness of loss function monotinicity_loss = 10 * torch.pow(monotinicity_loss, 2) - if isinstance(self.loss, MASE): + if isinstance(self.loss, (MASE, MultiLoss)): loss = self.loss( prediction, y, encoder_target=x["encoder_target"], encoder_lengths=x["encoder_lengths"] ) @@ -526,10 +530,9 @@ def step( # calculate loss prediction = out["prediction"] - if isinstance(self.loss, MASE): - loss = self.loss( - prediction, y, encoder_target=x["encoder_target"], encoder_lengths=x["encoder_lengths"] - ) + if isinstance(self.loss, (MASE, MultiLoss)): + mase_kwargs = dict(encoder_target=x["encoder_target"], encoder_lengths=x["encoder_lengths"]) + loss = self.loss(prediction, y, **mase_kwargs) else: loss = self.loss(prediction, y) @@ -595,27 +598,6 @@ def log_metrics( batch_size=len(x["decoder_target"]), ) - def to_network_output(self, **results): - """ - Convert output into a named (and immuatable) tuple. - - This allows tracing the modules as graphs and prevents modifying the output. - - Returns: - named tuple - """ - if hasattr(self, "_output_class"): - Output = self._output_class - else: - OutputTuple = namedtuple("output", results) - - class Output(OutputMixIn, OutputTuple): - pass - - self._output_class = Output - - return self._output_class(**results) - def forward( self, x: Dict[str, Union[torch.Tensor, List[torch.Tensor]]] ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: @@ -956,7 +938,7 @@ def configure_optimizers(self): "scheduler": ReduceLROnPlateau( optimizer, mode="min", - factor=0.2, + factor=1.0 / self.hparams.reduce_on_plateau_reduction, patience=self.hparams.reduce_on_plateau_patience, cooldown=self.hparams.reduce_on_plateau_patience, min_lr=self.hparams.reduce_on_plateau_min_lr, @@ -1347,6 +1329,25 @@ class BaseModelWithCovariates(BaseModel): as bag of embeddings """ + @property + def target_positions(self) -> torch.LongTensor: + """ + Positions of target variable(s) in covariates. + + Returns: + torch.LongTensor: tensor of positions. + """ + # todo: expand for categorical targets + if "target" in self.hparams: + target = self.hparams.target + else: + target = self.dataset_parameters["target"] + return torch.tensor( + [self.hparams.x_reals.index(name) for name in to_list(target)], + device=self.device, + dtype=torch.long, + ) + @property def reals(self) -> List[str]: """List of all continuous variables in model""" @@ -1454,6 +1455,47 @@ def from_dataset( new_kwargs.update(kwargs) return super().from_dataset(dataset, **new_kwargs) + def extract_features( + self, + x, + embeddings: MultiEmbedding = None, + period: str = "all", + ) -> torch.Tensor: + """ + Extract features + + Args: + x (Dict[str, torch.Tensor]): input from the dataloader + embeddings (MultiEmbedding): embeddings for categorical variables + period (str, optional): One of "encoder", "decoder" or "all". Defaults to "all". + + Returns: + torch.Tensor: tensor with selected variables + """ + # select period + if period == "encoder": + x_cat = x["encoder_cat"] + x_cont = x["encoder_cont"] + elif period == "decoder": + x_cat = x["decoder_cat"] + x_cont = x["decoder_cont"] + elif period == "all": + x_cat = torch.cat([x["encoder_cat"], x["decoder_cat"]], dim=1) # concatenate in time dimension + x_cont = torch.cat([x["encoder_cont"], x["decoder_cont"]], dim=1) # concatenate in time dimension + else: + raise ValueError(f"Unknown type: {type}") + + # create dictionary of encoded vectors + input_vectors = embeddings(x_cat) + input_vectors.update( + { + name: x_cont[..., idx].unsqueeze(-1) + for idx, name in enumerate(self.hparams.x_reals) + if name in self.reals + } + ) + return input_vectors + def calculate_prediction_actual_by_variable( self, x: Dict[str, torch.Tensor], @@ -1983,21 +2025,6 @@ class AutoRegressiveBaseModelWithCovariates(BaseModelWithCovariates, AutoRegress as bag of embeddings """ - @property - def target_positions(self) -> torch.LongTensor: - """ - Positions of target variable(s) in covariates. - - Returns: - torch.LongTensor: tensor of positions. - """ - # todo: expand for categorical targets - return torch.tensor( - [self.hparams.x_reals.index(name) for name in to_list(self.hparams.target)], - device=self.device, - dtype=torch.long, - ) - @property def lagged_target_positions(self) -> Dict[int, torch.LongTensor]: """ diff --git a/pytorch_forecasting/models/nbeats/__init__.py b/pytorch_forecasting/models/nbeats/__init__.py index d23f5fbb..62b35495 100644 --- a/pytorch_forecasting/models/nbeats/__init__.py +++ b/pytorch_forecasting/models/nbeats/__init__.py @@ -48,6 +48,9 @@ def __init__( the most important benchmark for univariate time series forecasting. + The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently shown to consistently outperform + N-BEATS. + Args: stack_types: One of the following values: “generic”, “seasonality" or “trend". A list of strings of length 1 or ‘num_stacks’. Default and recommended value diff --git a/pytorch_forecasting/models/nhits/__init__.py b/pytorch_forecasting/models/nhits/__init__.py new file mode 100644 index 00000000..e201518e --- /dev/null +++ b/pytorch_forecasting/models/nhits/__init__.py @@ -0,0 +1,492 @@ +""" +N-HiTS model for timeseries forecasting with covariates. +""" +from copy import copy +from typing import Dict, List, Optional, Tuple, Union + +from matplotlib import pyplot as plt +import numpy as np +import torch +from torch import nn + +from pytorch_forecasting.data import TimeSeriesDataSet +from pytorch_forecasting.data.encoders import NaNLabelEncoder +from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric, MultiLoss +from pytorch_forecasting.models.base_model import BaseModelWithCovariates +from pytorch_forecasting.models.nhits.sub_modules import NHiTS as NHiTSModule +from pytorch_forecasting.models.nn.embeddings import MultiEmbedding +from pytorch_forecasting.utils import create_mask + + +class NHiTS(BaseModelWithCovariates): + def __init__( + self, + output_size: Union[int, List[int]] = 1, + static_categoricals: List[str] = [], + static_reals: List[str] = [], + time_varying_categoricals_encoder: List[str] = [], + time_varying_categoricals_decoder: List[str] = [], + categorical_groups: Dict[str, List[str]] = {}, + time_varying_reals_encoder: List[str] = [], + time_varying_reals_decoder: List[str] = [], + embedding_sizes: Dict[str, Tuple[int, int]] = {}, + embedding_paddings: List[str] = [], + embedding_labels: Dict[str, np.ndarray] = {}, + x_reals: List[str] = [], + x_categoricals: List[str] = [], + context_length: int = 1, + prediction_length: int = 1, + static_hidden_size: Optional[int] = None, + shared_weights: bool = True, + activation: str = "ReLU", + initialization: str = "lecun_normal", + n_blocks: List[int] = [1, 1, 1], + n_layers: Union[int, List[int]] = 2, + hidden_size: int = 512, + pooling_sizes: Optional[List[int]] = None, + downsample_frequencies: Optional[List[int]] = None, + pooling_mode: str = "max", + interpolation_mode: str = "linear", + batch_normalization: bool = False, + dropout: float = 0.0, + learning_rate: float = 1e-2, + log_interval: int = -1, + log_gradient_flow: bool = False, + log_val_interval: int = None, + weight_decay: float = 1e-3, + loss: MultiHorizonMetric = None, + reduce_on_plateau_patience: int = 1000, + backcast_loss_ratio: float = 0.0, + logging_metrics: nn.ModuleList = None, + **kwargs, + ): + """ + Initialize N-HiTS Model - use its :py:meth:`~from_dataset` method if possible. + + Based on the article + `N-HiTS: Neural Hierarchical Interpolation for Time Series Forecasting `_. + The network has shown to increase accuracy by ~25% against + :py:class:`~pytorch_forecasting.models.nbeats.NBeats` and also supports covariates. + + Args: + hidden_size (int): size of hidden layers and can range from 8 to 1024 - use 32-128 if no + covariates are employed. Defaults to 512. + static_hidden_size (Optional[int], optional): size of hidden layers for static variables. + Defaults to hidden_size. + loss: loss to optimize. Defaults to MASE(). + shared_weights (bool, optional): if True, weights of blocks are shared in each stack. Defaults to True. + initialization (str, optional): Initialization method. One of ['orthogonal', 'he_uniform', 'glorot_uniform', + 'glorot_normal', 'lecun_normal']. Defaults to "lecun_normal". + n_blocks (List[int], optional): list of blocks used in each stack (i.e. length of stacks). + Defaults to [1, 1, 1]. + n_layers (Union[int, List[int]], optional): Number of layers per block or list of number of + layers used by blocks in each stack (i.e. length of stacks). Defaults to 2. + pooling_sizes (Optional[List[int]], optional): List of pooling sizes for input for each stack, + i.e. higher means more smoothing of input. Using an ordering of higher to lower in the list + improves results. + Defaults to a heuristic. + pooling_mode (str, optional): Pooling mode for summarizing input. One of ['max','average']. + Defaults to "max". + downsample_frequencies (Optional[List[int]], optional): Downsample multiplier of output for each stack, i.e. + higher means more interpolation at forecast time is required. Should be equal or higher + than pooling_sizes but smaller equal prediction_length. + Defaults to a heuristic to match pooling_sizes. + interpolation_mode (str, optional): Interpolation mode for forecasting. One of ['linear', 'nearest', + 'cubic-x'] where 'x' is replaced by a batch size for the interpolation. Defaults to "linear". + batch_normalization (bool, optional): Whether carry out batch normalization. Defaults to False. + dropout (float, optional): dropout rate for hidden layers. Defaults to 0.0. + activation (str, optional): activation function. One of ['ReLU', 'Softplus', 'Tanh', 'SELU', + 'LeakyReLU', 'PReLU', 'Sigmoid']. Defaults to "ReLU". + output_size: number of outputs (typically number of quantiles for QuantileLoss and one target or list + of output sizes but currently only point-forecasts allowed). Set automatically. + static_categoricals: names of static categorical variables + static_reals: names of static continuous variables + time_varying_categoricals_encoder: names of categorical variables for encoder + time_varying_categoricals_decoder: names of categorical variables for decoder + time_varying_reals_encoder: names of continuous variables for encoder + time_varying_reals_decoder: names of continuous variables for decoder + categorical_groups: dictionary where values + are list of categorical variables that are forming together a new categorical + variable which is the key in the dictionary + x_reals: order of continuous variables in tensor passed to forward function + x_categoricals: order of categorical variables in tensor passed to forward function + hidden_continuous_size: default for hidden size for processing continous variables (similar to categorical + embedding size) + hidden_continuous_sizes: dictionary mapping continuous input indices to sizes for variable selection + (fallback to hidden_continuous_size if index is not in dictionary) + embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and + embedding size + embedding_paddings: list of indices for embeddings which transform the zero's embedding to a zero vector + embedding_labels: dictionary mapping (string) indices to list of categorical labels + learning_rate: learning rate + log_interval: log predictions every x batches, do not log if 0 or less, log interpretation if > 0. If < 1.0 + , will log multiple entries per batch. Defaults to -1. + log_val_interval: frequency with which to log validation set metrics, defaults to log_interval + log_gradient_flow: if to log gradient flow, this takes time and should be only done to diagnose training + failures + prediction_length: Length of the prediction. Also known as 'horizon'. + context_length: Number of time units that condition the predictions. Also known as 'lookback period'. + Should be between 1-10 times the prediction length. + backcast_loss_ratio: weight of backcast in comparison to forecast when calculating the loss. + A weight of 1.0 means that forecast and backcast loss is weighted the same (regardless of backcast and + forecast lengths). Defaults to 0.0, i.e. no weight. + log_gradient_flow: if to log gradient flow, this takes time and should be only done to diagnose training + failures + reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10 + logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that are logged during training. + Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) + **kwargs: additional arguments to :py:class:`~BaseModel`. + """ + if logging_metrics is None: + logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) + if loss is None: + loss = MASE() + + if activation == "SELU": + self.hparams.initialization = "lecun_normal" + + # provide default downsampling sizes + n_stacks = len(n_blocks) + if pooling_sizes is None: + pooling_sizes = np.exp2(np.round(np.linspace(0.49, np.log2(prediction_length / 2), n_stacks))) + pooling_sizes = [int(x) for x in pooling_sizes[::-1]] + if downsample_frequencies is None: + downsample_frequencies = [min(prediction_length, int(np.power(x, 1.5))) for x in pooling_sizes] + + # set static hidden size + if static_hidden_size is None: + static_hidden_size = hidden_size + + # set layers + if isinstance(n_layers, int): + n_layers = [n_layers] * n_stacks + + self.save_hyperparameters() + super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) + + self.embeddings = MultiEmbedding( + embedding_sizes=self.hparams.embedding_sizes, + categorical_groups=self.hparams.categorical_groups, + embedding_paddings=self.hparams.embedding_paddings, + x_categoricals=self.hparams.x_categoricals, + ) + if isinstance(self.hparams.output_size, int): + output_size = self.hparams.output_size + else: + output_size = sum(self.hparams.output_size) + + self.model = NHiTSModule( + context_length=self.hparams.context_length, + prediction_length=self.hparams.prediction_length, + output_size=output_size, + static_size=self.static_size, + covariate_size=self.covariate_size, + static_hidden_size=self.hparams.static_hidden_size, + n_blocks=self.hparams.n_blocks, + n_layers=self.hparams.n_layers, + hidden_size=self.n_stacks * [2 * [self.hparams.hidden_size]], + pooling_sizes=self.hparams.pooling_sizes, + downsample_frequencies=self.hparams.downsample_frequencies, + pooling_mode=self.hparams.pooling_mode, + interpolation_mode=self.hparams.interpolation_mode, + dropout=self.hparams.dropout, + activation=self.hparams.activation, + initialization=self.hparams.initialization, + batch_normalization=self.hparams.batch_normalization, + shared_weights=self.hparams.shared_weights, + ) + + @property + def covariate_size(self) -> int: + """Covariate size. + + Returns: + int: size of time-dependent covariates + """ + return len(set(self.hparams.time_varying_reals_decoder) - set(self.target_names)) + sum( + self.embeddings.output_size[name] for name in self.hparams.time_varying_categoricals_encoder + ) + + @property + def static_size(self) -> int: + """Static covariate size. + + Returns: + int: size of static covariates + """ + return len(self.hparams.static_reals) + sum( + self.embeddings.output_size[name] for name in self.hparams.static_categoricals + ) + + @property + def n_stacks(self) -> int: + """Number of stacks. + + Returns: + int: number of stacks. + """ + return len(self.hparams.n_blocks) + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Pass forward of network. + + Args: + x (Dict[str, torch.Tensor]): input from dataloader generated from + :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. + + Returns: + Dict[str, torch.Tensor]: output of model + """ + # covariates + if self.covariate_size > 0: + encoder_features = self.extract_features(x, self.embeddings, period="encoder") + encoder_x_t = torch.concat( + [encoder_features[name] for name in self.encoder_variables if name not in self.target_names], + dim=2, + ) + decoder_features = self.extract_features(x, self.embeddings, period="decoder") + decoder_x_t = torch.concat([decoder_features[name] for name in self.decoder_variables], dim=2) + else: + encoder_x_t = None + decoder_x_t = None + + # statics + if self.static_size > 0: + x_s = torch.concat([encoder_features[name][:, 0] for name in self.static_variables], dim=1) + else: + x_s = None + + # target + encoder_y = x["encoder_cont"][..., self.target_positions] + encoder_mask = create_mask(x["encoder_lengths"].max(), x["encoder_lengths"], inverse=True) + + # run model + forecast, backcast, block_forecasts, block_backcasts = self.model( + encoder_y, encoder_mask, encoder_x_t, decoder_x_t, x_s + ) + + # create output + block_predictions = torch.cat([block_backcasts.detach(), block_forecasts.detach()], dim=1) + + if forecast.size(2) > 1: # multi-output + n_outputs = forecast.size(2) + forecast = [forecast[:, :, i] for i in range(n_outputs)] + backcast = [encoder_y[:, :, i] - backcast[:, :, i] for i in range(n_outputs)] + + n_blocks = block_predictions.size(3) + block_predictions = [block_predictions[:, :, i] for i in range(n_outputs)] + block_predictions = tuple( + self.transform_output([b[..., block] for b in block_predictions], target_scale=x["target_scale"]) + for block in range(n_blocks) + ) + else: + block_predictions = tuple( + self.transform_output(block_predictions[..., i], target_scale=x["target_scale"]) + for i in range(block_predictions[0].size(-1)) + ) + backcast = encoder_y - backcast + + return self.to_network_output( + prediction=self.transform_output(forecast, target_scale=x["target_scale"]), + backcast=self.transform_output(backcast, target_scale=x["target_scale"]), + block_predictions=block_predictions, + ) + + @classmethod + def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): + """ + Convenience function to create network from :py:class`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. + + Args: + dataset (TimeSeriesDataSet): dataset where sole predictor is the target. + **kwargs: additional arguments to be passed to ``__init__`` method. + + Returns: + NBeats + """ + # validate arguments + assert not isinstance( + dataset.target_normalizer, NaNLabelEncoder + ), "only regression tasks are supported - target must not be categorical" + assert ( + dataset.min_encoder_length == dataset.max_encoder_length + ), "only fixed encoder length is allowed, but min_encoder_length != max_encoder_length" + + assert ( + dataset.max_prediction_length == dataset.min_prediction_length + ), "only fixed prediction length is allowed, but max_prediction_length != min_prediction_length" + + assert dataset.randomize_length is None, "length has to be fixed, but randomize_length is not None" + assert not dataset.add_relative_time_idx, "add_relative_time_idx has to be False" + + new_kwargs = copy(kwargs) + new_kwargs.update( + {"prediction_length": dataset.max_prediction_length, "context_length": dataset.max_encoder_length} + ) + new_kwargs.update(cls.deduce_default_output_parameters(dataset, kwargs, MASE())) + + assert (isinstance(new_kwargs["output_size"], int) and new_kwargs["output_size"] == 1) or all( + o == 1 for o in new_kwargs["output_size"] + ), "output sizes can only be of size 1, i.e. point forecasts" + + # initialize class + return super().from_dataset(dataset, **new_kwargs) + + def step(self, x, y, batch_idx) -> Dict[str, torch.Tensor]: + """ + Take training / validation step. + """ + log, out = super().step(x, y, batch_idx=batch_idx) + + if self.hparams.backcast_loss_ratio > 0: # add loss from backcast + backcast = out["backcast"] + backcast_weight = ( + self.hparams.backcast_loss_ratio * self.hparams.prediction_length / self.hparams.context_length + ) + backcast_weight = backcast_weight / (backcast_weight + 1) # normalize + forecast_weight = 1 - backcast_weight + if isinstance(self.loss, (MultiLoss, MASE)): + backcast_loss = ( + self.loss( + backcast, + (x["encoder_target"], None), + encoder_target=x["decoder_target"], + encoder_lengths=x["decoder_lengths"], + ) + * backcast_weight + ) + else: + backcast_loss = self.loss(backcast, x["encoder_target"]) * backcast_weight + label = ["val", "train"][self.training] + self.log( + f"{label}_backcast_loss", + backcast_loss, + on_epoch=True, + on_step=self.training, + batch_size=len(x["decoder_target"]), + ) + self.log( + f"{label}_forecast_loss", + log["loss"], + on_epoch=True, + on_step=self.training, + batch_size=len(x["decoder_target"]), + ) + log["loss"] = log["loss"] * forecast_weight + backcast_loss + + # log interpretation + self.log_interpretation(x, out, batch_idx=batch_idx) + return log, out + + def plot_interpretation( + self, + x: Dict[str, torch.Tensor], + output: Dict[str, torch.Tensor], + idx: int, + ax=None, + ) -> plt.Figure: + """ + Plot interpretation. + + Plot two pannels: prediction and backcast vs actuals and + decomposition of prediction into different block predictions which capture different frequencies. + + Args: + x (Dict[str, torch.Tensor]): network input + output (Dict[str, torch.Tensor]): network output + idx (int): index of sample for which to plot the interpretation. + ax (List[matplotlib axes], optional): list of two matplotlib axes onto which to plot the interpretation. + Defaults to None. + + Returns: + plt.Figure: matplotlib figure + """ + if isinstance(x["encoder_target"], (tuple, list)): # multi-target + figs = [] + for i in range(len(self.target_names)): + if ax is not None: + ax_i = ax[i] + else: + ax_i = None + + figs.append( + self.plot_interpretation( + dict(encoder_target=x["encoder_target"][i], decoder_target=x["decoder_target"][i]), + dict( + backcast=output["backcast"][i], + prediction=output["prediction"][i], + block_predictions=output["block_predictions"][i], + ), + idx=idx, + ax=ax_i, + ) + ) + return figs + + if ax is None: + fig, ax = plt.subplots(2, 1, figsize=(6, 8)) + else: + fig = ax[0].get_figure() + + time = torch.arange(-self.hparams.context_length, self.hparams.prediction_length) + + # plot target vs prediction + ax[0].plot(time, torch.cat([x["encoder_target"][idx], x["decoder_target"][idx]]).detach().cpu(), label="target") + ax[0].plot( + time, + torch.cat( + [ + output["backcast"][idx].detach(), + output["prediction"][idx].detach(), + ], + dim=0, + ).cpu(), + label="prediction", + ) + ax[0].set_xlabel("Time") + + # plot blocks + prop_cycle = iter(plt.rcParams["axes.prop_cycle"]) + next(prop_cycle) # prediction + next(prop_cycle) # observations + + for pooling_size, block_prediction in zip(self.hparams.pooling_sizes, output["block_predictions"][1:]): + ax[1].plot( + time, + block_prediction[idx].detach().cpu(), + label=f"Pooling size: {pooling_size}", + c=next(prop_cycle)["color"], + ) + ax[1].set_xlabel("Time") + ax[1].set_ylabel("Decomposition") + + fig.legend() + return fig + + def log_interpretation(self, x, out, batch_idx): + """ + Log interpretation of network predictions in tensorboard. + """ + label = ["val", "train"][self.training] + if self.log_interval > 0 and batch_idx % self.log_interval == 0: + fig = self.plot_interpretation(x, out, idx=0) + name = f"{label.capitalize()} interpretation of item 0 in " + if self.training: + name += f"step {self.global_step}" + else: + name += f"batch {batch_idx}" + self.logger.experiment.add_figure(name, fig, global_step=self.global_step) + if isinstance(fig, (list, tuple)): + for idx, f in enumerate(fig): + self.logger.experiment.add_figure( + f"{self.target_names[idx]} {name}", + f, + global_step=self.global_step, + ) + else: + self.logger.experiment.add_figure( + name, + fig, + global_step=self.global_step, + ) diff --git a/pytorch_forecasting/models/nhits/sub_modules.py b/pytorch_forecasting/models/nhits/sub_modules.py new file mode 100644 index 00000000..06c2505b --- /dev/null +++ b/pytorch_forecasting/models/nhits/sub_modules.py @@ -0,0 +1,348 @@ +from functools import partial +from typing import List, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class StaticFeaturesEncoder(nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + layers = [nn.Dropout(p=0.5), nn.Linear(in_features=in_features, out_features=out_features), nn.ReLU()] + self.encoder = nn.Sequential(*layers) + + def forward(self, x): + x = self.encoder(x) + return x + + +class IdentityBasis(nn.Module): + def __init__(self, backcast_size: int, forecast_size: int, interpolation_mode: str): + super().__init__() + assert (interpolation_mode in ["linear", "nearest"]) or ("cubic" in interpolation_mode) + self.forecast_size = forecast_size + self.backcast_size = backcast_size + self.interpolation_mode = interpolation_mode + + def forward( + self, theta: torch.Tensor, encoder_x_t: torch.Tensor, decoder_x_t: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + + backcast = theta[:, : self.backcast_size] + knots = theta[:, self.backcast_size :] + + if self.interpolation_mode == "nearest": + knots = knots[:, None, :] + forecast = F.interpolate(knots, size=self.forecast_size, mode=self.interpolation_mode) + forecast = forecast[:, 0, :] + elif self.interpolation_mode == "linear": + knots = knots[:, None, :] + forecast = F.interpolate( + knots, size=self.forecast_size, mode=self.interpolation_mode + ) # , align_corners=True) + forecast = forecast[:, 0, :] + elif "cubic" in self.interpolation_mode: + batch_size = int(self.interpolation_mode.split("-")[-1]) + knots = knots[:, None, None, :] + forecast = torch.zeros((len(knots), self.forecast_size)).to(knots.device) + n_batches = int(np.ceil(len(knots) / batch_size)) + for i in range(n_batches): + forecast_i = F.interpolate( + knots[i * batch_size : (i + 1) * batch_size], size=self.forecast_size, mode="bicubic" + ) # , align_corners=True) + forecast[i * batch_size : (i + 1) * batch_size] += forecast_i[:, 0, 0, :] + + return backcast, forecast + + +def init_weights(module, initialization): + if type(module) == torch.nn.Linear: + if initialization == "orthogonal": + torch.nn.init.orthogonal_(module.weight) + elif initialization == "he_uniform": + torch.nn.init.kaiming_uniform_(module.weight) + elif initialization == "he_normal": + torch.nn.init.kaiming_normal_(module.weight) + elif initialization == "glorot_uniform": + torch.nn.init.xavier_uniform_(module.weight) + elif initialization == "glorot_normal": + torch.nn.init.xavier_normal_(module.weight) + elif initialization == "lecun_normal": + pass # torch.nn.init.normal_(module.weight, 0.0, std=1/np.sqrt(module.weight.numel())) + else: + assert 1 < 0, f"Initialization {initialization} not found" + + +ACTIVATIONS = ["ReLU", "Softplus", "Tanh", "SELU", "LeakyReLU", "PReLU", "Sigmoid"] + + +class NHiTSBlock(nn.Module): + """ + N-HiTS block which takes a basis function as an argument. + """ + + def __init__( + self, + context_length: int, + prediction_length: int, + output_size: int, + covariate_size: int, + static_size: int, + static_hidden_size: int, + n_theta: int, + hidden_size: List[int], + pooling_sizes: int, + pooling_mode: str, + basis: nn.Module, + n_layers: int, + batch_normalization: bool, + dropout: float, + activation: str, + ): + super().__init__() + + assert pooling_mode in ["max", "average"] + + self.context_length_pooled = int(np.ceil(context_length / pooling_sizes)) + + if static_size == 0: + static_hidden_size = 0 + + self.context_length = context_length + self.output_size = output_size + self.n_theta = n_theta + self.prediction_length = prediction_length + self.static_size = static_size + self.static_hidden_size = static_hidden_size + self.covariate_size = covariate_size + self.pooling_sizes = pooling_sizes + self.batch_normalization = batch_normalization + self.dropout = dropout + + self.hidden_size = [ + self.context_length_pooled * self.output_size + + (self.context_length + self.prediction_length) * self.covariate_size + + self.static_hidden_size + ] + hidden_size + + assert activation in ACTIVATIONS, f"{activation} is not in {ACTIVATIONS}" + activ = getattr(nn, activation)() + + if pooling_mode == "max": + self.pooling_layer = nn.MaxPool1d(kernel_size=self.pooling_sizes, stride=self.pooling_sizes, ceil_mode=True) + elif pooling_mode == "average": + self.pooling_layer = nn.AvgPool1d(kernel_size=self.pooling_sizes, stride=self.pooling_sizes, ceil_mode=True) + + hidden_layers = [] + for i in range(n_layers): + hidden_layers.append(nn.Linear(in_features=self.hidden_size[i], out_features=self.hidden_size[i + 1])) + hidden_layers.append(activ) + + if self.batch_normalization: + hidden_layers.append(nn.BatchNorm1d(num_features=self.hidden_size[i + 1])) + + if self.dropout > 0: + hidden_layers.append(nn.Dropout(p=self.dropout)) + + output_layer = [nn.Linear(in_features=self.hidden_size[-1], out_features=n_theta * output_size)] + layers = hidden_layers + output_layer + + # static_size is computed with data, static_hidden_size is provided by user, if 0 no statics are used + if (self.static_size > 0) and (self.static_hidden_size > 0): + self.static_encoder = StaticFeaturesEncoder(in_features=static_size, out_features=static_hidden_size) + self.layers = nn.Sequential(*layers) + self.basis = basis + + def forward( + self, encoder_y: torch.Tensor, encoder_x_t: torch.Tensor, decoder_x_t: torch.Tensor, x_s: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = len(encoder_y) + + encoder_y = encoder_y.transpose(1, 2) + # Pooling layer to downsample input + encoder_y = self.pooling_layer(encoder_y) + encoder_y = encoder_y.transpose(1, 2).reshape(batch_size, -1) + + if self.covariate_size > 0: + encoder_y = torch.cat( + ( + encoder_y, + encoder_x_t.reshape(batch_size, -1), + decoder_x_t.reshape(batch_size, -1), + ), + 1, + ) + + # Static exogenous + if (self.static_size > 0) and (self.static_hidden_size > 0): + x_s = self.static_encoder(x_s) + encoder_y = torch.cat((encoder_y, x_s), 1) + + # Compute local projection weights and projection + theta = self.layers(encoder_y).reshape(-1, self.n_theta) + backcast, forecast = self.basis(theta, encoder_x_t, decoder_x_t) + backcast = backcast.reshape(-1, self.output_size, self.context_length).transpose(1, 2) + forecast = forecast.reshape(-1, self.output_size, self.prediction_length).transpose(1, 2) + + return backcast, forecast + + +class NHiTS(nn.Module): + """ + N-HiTS Model. + """ + + def __init__( + self, + context_length, + prediction_length, + output_size: int, + static_size, + covariate_size, + static_hidden_size, + n_blocks: list, + n_layers: list, + hidden_size: list, + pooling_sizes: list, + downsample_frequencies: list, + pooling_mode, + interpolation_mode, + dropout, + activation, + initialization, + batch_normalization, + shared_weights, + ): + super().__init__() + + self.prediction_length = prediction_length + self.context_length = context_length + + blocks = self.create_stack( + n_blocks=n_blocks, + context_length=context_length, + prediction_length=prediction_length, + output_size=output_size, + covariate_size=covariate_size, + static_size=static_size, + static_hidden_size=static_hidden_size, + n_layers=n_layers, + hidden_size=hidden_size, + pooling_sizes=pooling_sizes, + downsample_frequencies=downsample_frequencies, + pooling_mode=pooling_mode, + interpolation_mode=interpolation_mode, + batch_normalization=batch_normalization, + dropout=dropout, + activation=activation, + shared_weights=shared_weights, + initialization=initialization, + ) + self.blocks = torch.nn.ModuleList(blocks) + + def create_stack( + self, + n_blocks, + context_length, + prediction_length, + output_size, + covariate_size, + static_size, + static_hidden_size, + n_layers, + hidden_size, + pooling_sizes, + downsample_frequencies, + pooling_mode, + interpolation_mode, + batch_normalization, + dropout, + activation, + shared_weights, + initialization, + ): + + block_list = [] + for i in range(len(n_blocks)): + for block_id in range(n_blocks[i]): + + # Batch norm only on first block + if (len(block_list) == 0) and (batch_normalization): + batch_normalization_block = True + else: + batch_normalization_block = False + + # Shared weights + if shared_weights and block_id > 0: + nbeats_block = block_list[-1] + else: + n_theta = context_length + max(prediction_length // downsample_frequencies[i], 1) + basis = IdentityBasis( + backcast_size=context_length, + forecast_size=prediction_length, + interpolation_mode=interpolation_mode, + ) + + nbeats_block = NHiTSBlock( + context_length=context_length, + prediction_length=prediction_length, + output_size=output_size, + covariate_size=covariate_size, + static_size=static_size, + static_hidden_size=static_hidden_size, + n_theta=n_theta, + hidden_size=hidden_size[i], + pooling_sizes=pooling_sizes[i], + pooling_mode=pooling_mode, + basis=basis, + n_layers=n_layers[i], + batch_normalization=batch_normalization_block, + dropout=dropout, + activation=activation, + ) + + # Select type of evaluation and apply it to all layers of block + init_function = partial(init_weights, initialization=initialization) + nbeats_block.layers.apply(init_function) + block_list.append(nbeats_block) + return block_list + + def forward( + self, + encoder_y, + encoder_mask, + encoder_x_t, + decoder_x_t, + x_s, + ): + + residuals = ( + encoder_y # .flip(dims=(1,)) # todo: check if flip is required or should be rather replaced by scatter + ) + # encoder_x_t = encoder_x_t.flip(dims=(-1,)) + # encoder_mask = encoder_mask.flip(dims=(-1,)) + encoder_mask = encoder_mask.unsqueeze(-1) + + level = encoder_y[:, -1:].repeat(1, self.prediction_length, 1) # Level with Naive1 + block_forecasts = [level] + block_backcasts = [encoder_y[:, -1:].repeat(1, self.context_length, 1)] + + forecast = level + for block in self.blocks: + block_backcast, block_forecast = block( + encoder_y=residuals, encoder_x_t=encoder_x_t, decoder_x_t=decoder_x_t, x_s=x_s + ) + residuals = (residuals - block_backcast) * encoder_mask + + forecast = forecast + block_forecast + block_forecasts.append(block_forecast) + block_backcasts.append(block_backcast) + + # (n_batch, n_t, n_outputs, n_blocks) + block_forecasts = torch.stack(block_forecasts, dim=-1) + block_backcasts = torch.stack(block_backcasts, dim=-1) + backcast = residuals + + return forecast, backcast, block_forecasts, block_backcasts diff --git a/pytorch_forecasting/models/nn/__init__.py b/pytorch_forecasting/models/nn/__init__.py index 5a79a9ce..7a310737 100644 --- a/pytorch_forecasting/models/nn/__init__.py +++ b/pytorch_forecasting/models/nn/__init__.py @@ -1,4 +1,11 @@ +from re import S +from typing import Dict + +import torch +from torch import embedding, nn + from pytorch_forecasting.models.nn.embeddings import MultiEmbedding from pytorch_forecasting.models.nn.rnn import GRU, LSTM, HiddenState, get_rnn +from pytorch_forecasting.utils import TupleOutputMixIn -__all__ = ["MultiEmbedding", "get_rnn", "LSTM", "GRU", "HiddenState"] +__all__ = ["MultiEmbedding", "get_rnn", "LSTM", "GRU", "HiddenState", "TupleOutputMixIn"] diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py index 96adfe82..dff3e87d 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py @@ -37,8 +37,8 @@ def on_validation_end(self, trainer, pl_module): def optimize_hyperparameters( - train_dataloader: DataLoader, - val_dataloader: DataLoader, + train_dataloaders: DataLoader, + val_dataloaders: DataLoader, model_path: str, max_epochs: int = 20, n_trials: int = 100, @@ -64,8 +64,8 @@ def optimize_hyperparameters( the PyTorch Lightning learning rate finder. Args: - train_dataloader (DataLoader): dataloader for training model - val_dataloader (DataLoader): dataloader for validating model + train_dataloaders (DataLoader): dataloader for training model + val_dataloaders (DataLoader): dataloader for validating model model_path (str): folder to which model checkpoints are saved max_epochs (int, optional): Maximum number of epochs to run training. Defaults to 20. n_trials (int, optional): Number of hyperparameter trials to run. Defaults to 100. @@ -101,8 +101,8 @@ def optimize_hyperparameters( Returns: optuna.Study: optuna study results """ - assert isinstance(train_dataloader.dataset, TimeSeriesDataSet) and isinstance( - val_dataloader.dataset, TimeSeriesDataSet + assert isinstance(train_dataloaders.dataset, TimeSeriesDataSet) and isinstance( + val_dataloaders.dataset, TimeSeriesDataSet ), "dataloaders must be built from timeseriesdataset" logging_level = { @@ -155,7 +155,7 @@ def objective(trial: optuna.Trial) -> float: hidden_size = trial.suggest_int("hidden_size", *hidden_size_range, log=True) kwargs["loss"] = copy.deepcopy(loss) model = TemporalFusionTransformer.from_dataset( - train_dataloader.dataset, + train_dataloaders.dataset, dropout=trial.suggest_uniform("dropout", *dropout_range), hidden_size=hidden_size, hidden_continuous_size=trial.suggest_int( @@ -179,8 +179,8 @@ def objective(trial: optuna.Trial) -> float: ) res = lr_trainer.tuner.lr_find( model, - train_dataloaders=train_dataloader, - val_dataloaders=val_dataloader, + train_dataloaders=train_dataloaders, + val_dataloaders=val_dataloaders, early_stop_threshold=10000, min_lr=learning_rate_range[0], num_training=100, @@ -206,7 +206,7 @@ def objective(trial: optuna.Trial) -> float: model.hparams.learning_rate = trial.suggest_loguniform("learning_rate", *learning_rate_range) # fit - trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader) + trainer.fit(model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders) # report result return metrics_callback.metrics[-1]["val_loss"].item() diff --git a/pytorch_forecasting/utils.py b/pytorch_forecasting/utils.py index 9c600010..407aa428 100644 --- a/pytorch_forecasting/utils.py +++ b/pytorch_forecasting/utils.py @@ -1,6 +1,7 @@ """ Helper functions for PyTorch forecasting """ +from collections import namedtuple from contextlib import redirect_stdout import os from typing import Any, Callable, Dict, List, Tuple, Union @@ -338,6 +339,31 @@ def keys(self): return self._fields +class TupleOutputMixIn: + """MixIn to give output a namedtuple-like access capabilities with ``to_network_output() function``.""" + + def to_network_output(self, **results): + """ + Convert output into a named (and immuatable) tuple. + + This allows tracing the modules as graphs and prevents modifying the output. + + Returns: + named tuple + """ + if hasattr(self, "_output_class"): + Output = self._output_class + else: + OutputTuple = namedtuple("output", results) + + class Output(OutputMixIn, OutputTuple): + pass + + self._output_class = Output + + return self._output_class(**results) + + def move_to_device( x: Union[ Dict[str, Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]], diff --git a/tests/test_models/conftest.py b/tests/test_models/conftest.py index 05255122..7a56d306 100644 --- a/tests/test_models/conftest.py +++ b/tests/test_models/conftest.py @@ -137,11 +137,21 @@ def dataloaders_with_covariates(data_with_covariates): time_varying_known_reals=["discount"], time_varying_unknown_reals=["target"], static_categoricals=["agency"], - add_relative_time_idx=True, + add_relative_time_idx=False, target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False), ) +@pytest.fixture(scope="session") +def dataloaders_multi_target(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + time_varying_unknown_reals=["target", "discount"], + target=["target", "discount"], + add_relative_time_idx=False, + ) + + @pytest.fixture(scope="session") def dataloaders_fixed_window_without_covariates(): data = generate_ar_data(seasonality=10.0, timesteps=50, n_series=2) diff --git a/tests/test_models/test_nbeats.py b/tests/test_models/test_nbeats.py index b85c632e..7503a8f1 100644 --- a/tests/test_models/test_nbeats.py +++ b/tests/test_models/test_nbeats.py @@ -75,3 +75,10 @@ def model(dataloaders_fixed_window_without_covariates): def test_pickle(model): pkl = pickle.dumps(model) pickle.loads(pkl) + + +def test_interpretation(model, dataloaders_fixed_window_without_covariates): + raw_predictions, x = model.predict( + dataloaders_fixed_window_without_covariates["val"], mode="raw", return_x=True, fast_dev_run=True + ) + model.plot_interpretation(x, raw_predictions, idx=0) diff --git a/tests/test_models/test_nhits.py b/tests/test_models/test_nhits.py new file mode 100644 index 00000000..1cf636f1 --- /dev/null +++ b/tests/test_models/test_nhits.py @@ -0,0 +1,103 @@ +import pickle +import shutil + +import pytest +import pytorch_lightning as pl +from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.loggers import TensorBoardLogger + +from pytorch_forecasting.models import NHiTS + + +def _integration(dataloader, tmp_path, gpus): + train_dataloader = dataloader["train"] + val_dataloader = dataloader["val"] + test_dataloader = dataloader["test"] + + early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min") + + logger = TensorBoardLogger(tmp_path) + trainer = pl.Trainer( + max_epochs=2, + gpus=gpus, + gradient_clip_val=0.1, + callbacks=[early_stop_callback], + enable_checkpointing=True, + default_root_dir=tmp_path, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + logger=logger, + ) + + net = NHiTS.from_dataset( + train_dataloader.dataset, + learning_rate=0.15, + log_gradient_flow=True, + log_interval=1000, + hidden_size=8, + backcast_loss_ratio=1.0, + ) + net.size() + try: + trainer.fit( + net, + train_dataloaders=train_dataloader, + val_dataloaders=val_dataloader, + ) + test_outputs = trainer.test(net, dataloaders=test_dataloader) + assert len(test_outputs) > 0 + # check loading + net = NHiTS.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) + + # check prediction + net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True) + finally: + shutil.rmtree(tmp_path, ignore_errors=True) + + net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True) + + +@pytest.mark.parametrize("dataloader", ["with_covariates", "fixed_window_without_covariates", "multi_target"]) +def test_integration( + dataloaders_with_covariates, + dataloaders_fixed_window_without_covariates, + dataloaders_multi_target, + tmp_path, + gpus, + dataloader, +): + if dataloader == "with_covariates": + dataloader = dataloaders_with_covariates + elif dataloader == "fixed_window_without_covariates": + dataloader = dataloaders_fixed_window_without_covariates + elif dataloader == "multi_target": + dataloader = dataloaders_multi_target + else: + raise ValueError(f"dataloader {dataloader} unknown") + _integration(dataloader, tmp_path=tmp_path, gpus=gpus) + + +@pytest.fixture(scope="session") +def model(dataloaders_with_covariates): + dataset = dataloaders_with_covariates["train"].dataset + net = NHiTS.from_dataset( + dataset, + learning_rate=0.15, + hidden_size=8, + log_gradient_flow=True, + log_interval=1000, + backcast_loss_ratio=1.0, + ) + return net + + +def test_pickle(model): + pkl = pickle.dumps(model) + pickle.loads(pkl) + + +def test_interpretation(model, dataloaders_with_covariates): + raw_predictions, x = model.predict(dataloaders_with_covariates["val"], mode="raw", return_x=True, fast_dev_run=True) + model.plot_prediction(x, raw_predictions, idx=0, add_loss_to_title=True) + model.plot_interpretation(x, raw_predictions, idx=0) diff --git a/tests/test_models/test_temporal_fusion_transformer.py b/tests/test_models/test_temporal_fusion_transformer.py index 43466512..ef81907b 100644 --- a/tests/test_models/test_temporal_fusion_transformer.py +++ b/tests/test_models/test_temporal_fusion_transformer.py @@ -271,8 +271,8 @@ def test_hyperparameter_optimization_integration(dataloaders_with_covariates, tm val_dataloader = dataloaders_with_covariates["val"] try: optimize_hyperparameters( - train_dataloader=train_dataloader, - val_dataloader=val_dataloader, + train_dataloaders=train_dataloader, + val_dataloaders=val_dataloader, model_path=tmp_path, max_epochs=1, n_trials=3,