Commit 5fc682a

pre-commit

svnv-svsv-jm committed Jan 3, 2024
1 parent dccdd22 commit 5fc682a
Showing 7 changed files with 54 additions and 37 deletions.
2 changes: 1 addition & 1 deletion pytorch_forecasting/__init__.py
@@ -38,12 +38,12 @@
BaseModelWithCovariates,
DecoderMLP,
DeepAR,
LSTMModel,
MultiEmbedding,
NBeats,
NHiTS,
RecurrentNetwork,
TemporalFusionTransformer,
LSTMModel,
get_rnn,
)
from pytorch_forecasting.utils import (
1 change: 1 addition & 0 deletions pytorch_forecasting/models/__init__.py
@@ -15,6 +15,7 @@
from pytorch_forecasting.models.nn import GRU, LSTM, MultiEmbedding, get_rnn
from pytorch_forecasting.models.rnn import RecurrentNetwork
from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer

from .lstm import LSTMModel

__all__ = [
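With both `__init__` files updated, `LSTMModel` should be importable from the top-level package as well as from the models subpackage. A minimal check, assuming this fork of pytorch_forecasting is installed:

    # Both imports should resolve after this commit, per the __init__ changes above.
    from pytorch_forecasting import LSTMModel
    from pytorch_forecasting.models import LSTMModel as LSTMModelFromModels

    assert LSTMModel is LSTMModelFromModels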
32 changes: 17 additions & 15 deletions pytorch_forecasting/models/_base_autoregressive.py
@@ -1,14 +1,14 @@
__all__ = ["AutoRegressiveBaseModel"]

from loguru import logger
from typing import List, Union, Any, Sequence, Tuple, Dict, Callable
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union

from loguru import logger
import torch
from torch import Tensor

from pytorch_forecasting.metrics import MultiLoss, DistributionLoss
from pytorch_forecasting.utils import to_list, apply_to_list
from pytorch_forecasting.metrics import DistributionLoss, MultiLoss
from pytorch_forecasting.models.base_model import AutoRegressiveBaseModel as AutoRegressiveBaseModel_
from pytorch_forecasting.utils import apply_to_list, to_list


class AutoRegressiveBaseModel(AutoRegressiveBaseModel_): # pylint: disable=abstract-method
@@ -39,9 +39,7 @@ def output_to_prediction(
single_prediction = to_list(normalized_prediction_parameters)[0].ndim == 2
logger.trace(f"single_prediction={single_prediction}")
if single_prediction: # add time dimension as it is expected
normalized_prediction_parameters = apply_to_list(
normalized_prediction_parameters, lambda x: x.unsqueeze(1)
)
normalized_prediction_parameters = apply_to_list(normalized_prediction_parameters, lambda x: x.unsqueeze(1))
# transform into real space
prediction_parameters = self.transform_output(
prediction=normalized_prediction_parameters, target_scale=target_scale, **kwargs
@@ -95,12 +93,17 @@ def decode_autoregressive(
"""
Make predictions in an auto-regressive manner. Supports only continuous targets.
Args:
decode_one (Callable): function that takes at least the following arguments:
decode_one (Callable):
function that takes at least the following arguments:
* ``idx`` (int): index of decoding step (from 0 to n_decoder_steps-1)
* ``lagged_targets`` (List[torch.Tensor]): list of normalized targets.
List is ``idx + 1`` elements long with the most recent entry at the end, i.e. ``previous_target = lagged_targets[-1]`` and in general ``lagged_targets[-lag]``.
* ``hidden_state`` (Any): Current hidden state required for prediction. Keys are variable names. Only lags that are greater than ``idx`` are included.
* additional arguments are not dynamic but can be passed via the ``**kwargs`` argument. The callable returns a tuple of the (not rescaled) network prediction output and the hidden state for the next auto-regressive step.
List is ``idx + 1`` elements long with the most recent entry at the end, i.e.
``previous_target = lagged_targets[-1]`` and in general ``lagged_targets[-lag]``.
* ``hidden_state`` (Any): Current hidden state required for prediction. Keys are variable
names. Only lags that are greater than ``idx`` are included.
* additional arguments are not dynamic but can be passed via the ``**kwargs`` argument. The
callable returns a tuple of the (not rescaled) network prediction output and the hidden
state for the next auto-regressive step.
first_target (Union[List[torch.Tensor], torch.Tensor]): first target value to use for decoding
first_hidden_state (Any): first hidden state used for decoding
target_scale (Union[List[torch.Tensor], torch.Tensor]): target scale as in ``x``
@@ -130,9 +133,7 @@ def decode_autoregressive(
if isinstance(prediction, Tensor):
logger.trace(f"prediction ({type(prediction)}): {prediction.size()}")
else:
logger.trace(
f"prediction ({type(prediction)}|{len(prediction)}): {[p.size() for p in prediction]}"
)
logger.trace(f"prediction ({type(prediction)}|{len(prediction)}): {[p.size() for p in prediction]}")
# save normalized output for lagged targets
normalized_output.append(current_target)
# set output to unnormalized samples, append each target as n_batch_samples x n_random_samples
@@ -159,7 +160,8 @@ def decode_autoregressive(
logger.trace(f"final_output_multitarget: {final_output_multitarget.size()}")
else:
logger.trace(
f"final_output_multitarget ({type(final_output_multitarget)}): {[o.size() for o in final_output_multitarget]}"
f"final_output_multitarget ({type(final_output_multitarget)})"
f"{[o.size() for o in final_output_multitarget]}"
)
r = [final_output_multitarget[..., i] for i in range(final_output_multitarget.size(-1))]
return r
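To make the `decode_one` contract documented above concrete, here is a small, self-contained sketch. The toy GRU cell and the decoding loop are illustrative assumptions, not the library's implementation of `decode_autoregressive`:

    from typing import List, Tuple

    import torch
    from torch import Tensor, nn

    hidden_size = 8
    gru_cell = nn.GRUCell(input_size=1, hidden_size=hidden_size)
    output_layer = nn.Linear(hidden_size, 1)

    def decode_one(idx: int, lagged_targets: List[Tensor], hidden_state: Tensor) -> Tuple[Tensor, Tensor]:
        """One decoding step: use the most recent normalized target and the current hidden state."""
        previous_target = lagged_targets[-1]  # most recent entry at the end, as documented
        hidden_state = gru_cell(previous_target.unsqueeze(-1), hidden_state)
        prediction = output_layer(hidden_state).squeeze(-1)  # not rescaled
        return prediction, hidden_state

    # Minimal auto-regressive loop mirroring the documented behaviour.
    n_decoder_steps, batch_size = 5, 4
    lagged: List[Tensor] = [torch.zeros(batch_size)]  # first target value used for decoding
    hidden = torch.zeros(batch_size, hidden_size)     # first hidden state used for decoding
    outputs = []
    for idx in range(n_decoder_steps):
        pred, hidden = decode_one(idx, lagged, hidden)
        lagged.append(pred)
        outputs.append(pred)
    print(torch.stack(outputs, dim=1).shape)  # torch.Size([4, 5])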
9 changes: 5 additions & 4 deletions pytorch_forecasting/models/lstm.py
@@ -1,10 +1,10 @@
__all__ = ["LSTMModel"]

from loguru import logger
from typing import List, Union, Any, Sequence, Tuple, Dict
from typing import Any, Dict, List, Sequence, Tuple, Union

from loguru import logger
import torch
from torch import nn, Tensor
from torch import Tensor, nn

from pytorch_forecasting.metrics import MAE, Metric, MultiLoss
from pytorch_forecasting.models.nn import LSTM
@@ -40,7 +40,8 @@ def __init__(
input_size (int, optional):
Input size. Defaults to: inferred from `target`.
loss (Metric):
Loss criterion. Can be different for each target in a multi-target setting thanks to `MultiLoss`. Defaults to `MAE`.
Loss criterion. Can be different for each target in a multi-target setting thanks to
`MultiLoss`. Defaults to `MAE`.
**kwargs:
See :class:`pytorch_forecasting.models.base_model.AutoRegressiveBaseModel`.
"""
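As a brief illustration of the per-target loss mentioned in the docstring above, a `MultiLoss` can bundle one metric per target; the two-target setup is a hypothetical assumption:

    from pytorch_forecasting.metrics import MAE, SMAPE, MultiLoss

    # One metric per target, e.g. for a hypothetical dataset with target=["demand", "price"];
    # the resulting object can then be passed to the model as ``loss=loss``.
    loss = MultiLoss([MAE(), SMAPE()])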
39 changes: 25 additions & 14 deletions pytorch_forecasting/models/tuning.py
@@ -4,17 +4,17 @@

__all__ = ["optimize_hyperparameters"]

from loguru import logger
import copy
import logging
import os
from typing import Any, Dict, Tuple, Union, Optional, Callable, Type, Sequence
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union

import lightning.pytorch as pl
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.tuner import Tuner
from lightning.pytorch.tuner.lr_finder import _LRFinder
from loguru import logger
import numpy as np
import optuna
from optuna import Trial
@@ -24,7 +24,7 @@
from torch import Tensor
from torch.utils.data import DataLoader

from pytorch_forecasting import TemporalFusionTransformer, BaseModel
from pytorch_forecasting import BaseModel, TemporalFusionTransformer
from pytorch_forecasting.data import TimeSeriesDataSet

optuna_logger = logging.getLogger("optuna")
@@ -87,7 +87,8 @@ def optimize_hyperparameters(
**kwargs: Any,
) -> optuna.Study:
"""
Optimize hyperparameters. The learning rate is determined with the PyTorch Lightning learning rate finder.
Optimize hyperparameters. The learning rate is determined with the
PyTorch Lightning learning rate finder.
Args:
train_dataloaders (DataLoader):
@@ -97,30 +98,37 @@
model_path (str):
Folder to which model checkpoints are saved.
monitor (str):
Metric to return. The hyper-parameter (HP) tuner trains a model for a certain HP config, and reads this metric to score configuration. By default, the lower the better.
Metric to return. The hyper-parameter (HP) tuner trains a model for a certain HP config, and
reads this metric to score configuration. By default, the lower the better.
direction (str):
By default, direction is "minimize", meaning that lower values of the specified `monitor` are better. You can change this, e.g. to "maximize".
By default, direction is "minimize", meaning that lower values of the specified `monitor` are
better. You can change this, e.g. to "maximize".
max_epochs (int, optional):
Maximum number of epochs to run training. Defaults to 20.
n_trials (int, optional):
Number of hyperparameter trials to run. Defaults to 100.
timeout (float, optional):
Time in seconds after which training is stopped regardless of number of epochs or validation metric. Defaults to 3600*8.0.
Time in seconds after which training is stopped regardless of number of epochs or validation
metric. Defaults to 3600*8.0.
input_params (dict, optional):
A dictionary, where each `key` contains another dictionary with two keys: `"method"` and `"ranges"`. Example:
A dictionary, where each `key` contains another dictionary with two keys: `"method"` and
`"ranges"`. Example:
>>> {"hidden_size": {
>>> "method": "suggest_int",
>>> "ranges": (16, 265),
>>> }}
The method key has to be a method of the `optuna.Trial` object. The ranges key holds the input ranges for the specified method.
The method key has to be a method of the `optuna.Trial` object. The ranges key holds the input
ranges for the specified method.
input_params_generator (Callable, optional):
A function with the following signature: `fn(trial: optuna.Trial, **kwargs: Any) -> Dict[str, Any]`, returning the parameter values to set up your model for the current trial/run.
A function with the following signature: `fn(trial: optuna.Trial, **kwargs: Any) -> Dict[str, Any]`,
returning the parameter values to set up your model for the current trial/run.
Example:
>>> def fn(trial: optuna.Trial, param_ranges: Tuple[int, int] = (16, 265)) -> Dict[str, Any]:
>>> param = trial.suggest_int("param", *param_ranges, log=True)
>>> model_params = {"param": param}
>>> return model_params
Then, when your model is created (before it is trained and the metrics for the current combination of hyperparameters are reported), this dictionary is used as follows:
Then, when your model is created (before it is trained and the metrics for the current
combination of hyperparameters are reported), this dictionary is used as follows:
>>> model = YourModelClass.from_dataset(
>>> train_dataloaders.dataset,
>>> log_interval=-1,
@@ -133,7 +141,9 @@ def optimize_hyperparameters(
use_learning_rate_finder (bool):
Whether to use the learning rate finder or to optimize the learning rate as a hyperparameter. Defaults to True.
trainer_kwargs (Dict[str, Any], optional):
Additional arguments to the `PyTorch Lightning trainer <https://pytorch-lightning.readthedocs.io/en/latest/trainer.html>` such as `limit_train_batches`. Defaults to {}.
Additional arguments to the
`PyTorch Lightning trainer <https://pytorch-lightning.readthedocs.io/en/latest/trainer.html>`
such as `limit_train_batches`. Defaults to {}.
log_dir (str, optional):
Folder into which to log results for tensorboard. Defaults to "lightning_logs".
study (optuna.Study, optional):
@@ -153,6 +163,9 @@
Returns:
optuna.Study: optuna study results
"""
if generator_params is None:
generator_params = {}

assert isinstance(train_dataloaders.dataset, TimeSeriesDataSet) and isinstance(
val_dataloaders.dataset, TimeSeriesDataSet
), "Dataloaders must be built from TimeSeriesDataSet."
@@ -209,8 +222,6 @@ def objective(trial: optuna.Trial) -> float:
except ValueError as ex:
raise ValueError(f"Error while calling {fn} for {key}.") from ex
else:
if generator_params is None:
generator_params = {}
params = input_params_generator(trial, **generator_params)
kwargs.update(params)
kwargs["loss"] = copy.deepcopy(loss)
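The two ways of describing the search space documented above can be sketched as follows; the parameter names and ranges are illustrative assumptions, and the dataloader/model setup is omitted:

    from typing import Any, Dict, Tuple

    import optuna

    # Option 1: declarative search space -- each entry names an ``optuna.Trial`` method
    # ("method") and the positional arguments passed to it ("ranges").
    input_params: Dict[str, Dict[str, Any]] = {
        "hidden_size": {"method": "suggest_int", "ranges": (16, 265)},
        "dropout": {"method": "suggest_float", "ranges": (0.05, 0.3)},
    }

    # Option 2: a generator with the documented signature
    # ``fn(trial: optuna.Trial, **kwargs: Any) -> Dict[str, Any]``.
    def input_params_generator(trial: optuna.Trial, param_ranges: Tuple[int, int] = (16, 265)) -> Dict[str, Any]:
        hidden_size = trial.suggest_int("hidden_size", *param_ranges, log=True)
        return {"hidden_size": hidden_size}

    # Either object would then be handed to ``optimize_hyperparameters`` alongside the dataloaders,
    # e.g. ``optimize_hyperparameters(..., input_params=input_params)`` or
    # ``optimize_hyperparameters(..., input_params_generator=input_params_generator, generator_params={})``.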
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -1,8 +1,8 @@
import os
import sys

import pandas as pd
import numpy as np
import pandas as pd
import pytest

sys.path.insert(0, os.path.abspath(os.path.join(__file__, "../.."))) # isort:skip
6 changes: 4 additions & 2 deletions tests/test_models/test_tuning.py
@@ -1,7 +1,9 @@
import pytest
import sys, os
import os
import sys
import typing as ty

from loguru import logger
import pytest

from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.models import LSTMModel
