Skip to content

Commit

Permalink
Merge pull request #71 from jdb78/feature/add_future_prediction_in_tu…
Browse files Browse the repository at this point in the history
…torial

Explain how to predict into future
  • Loading branch information
jdb78 authored Sep 30, 2020
2 parents e7ab8b7 + 6327886 commit e97189f
Show file tree
Hide file tree
Showing 4 changed files with 290 additions and 74 deletions.
347 changes: 276 additions & 71 deletions docs/source/tutorials/stallion.ipynb

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion pytorch_forecasting/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ def plot_prediction(
out: Dict[str, torch.Tensor],
idx: int = 0,
add_loss_to_title: Union[TensorMetric, bool] = False,
show_future_observed: bool = True,
ax=None,
) -> plt.Figure:
"""
Expand All @@ -302,6 +303,7 @@ def plot_prediction(
out: network output
idx: index of prediction to plot
add_loss_to_title: if to add loss to title or loss function to calculate. Default to False.
show_future_observed: if to show actuals for future. Defaults to True.
ax: matplotlib axes to plot on
Returns:
Expand Down Expand Up @@ -346,8 +348,10 @@ def plot_prediction(
plotter = ax.plot
else:
plotter = ax.scatter

# plot observed prediction
plotter(x_pred, y[-n_pred:], label=None, c=obs_color)
if show_future_observed:
plotter(x_pred, y[-n_pred:], label=None, c=obs_color)

# plot prediction
plotter(x_pred, self.loss.to_prediction(y_hat.unsqueeze(0))[0], label="predicted", c=pred_color)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,7 @@ def plot_prediction(
idx: int,
plot_attention: bool = True,
add_loss_to_title: bool = False,
show_future_observed: bool = True,
ax=None,
) -> plt.Figure:
"""
Expand All @@ -735,14 +736,17 @@ def plot_prediction(
idx (int): sample index
plot_attention: if to plot attention on secondary axis
add_loss_to_title: if to add loss to title. Default to False.
show_future_observed: if to show actuals for future. Defaults to True.
ax: matplotlib axes to plot on
Returns:
plt.Figure: matplotlib figure
"""

# plot prediction as normal
fig = super().plot_prediction(x, out, idx=idx, add_loss_to_title=add_loss_to_title, ax=ax)
fig = super().plot_prediction(
x, out, idx=idx, add_loss_to_title=add_loss_to_title, show_future_observed=show_future_observed, ax=ax
)

# add attention on secondary axis
if plot_attention:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def optimize_hyperparameters(
use_learning_rate_finder: bool = True,
trainer_kwargs: Dict[str, Any] = {},
log_dir: str = "lightning_logs",
study: optuna.Study = None,
**kwargs,
) -> optuna.Study:
"""
Expand Down Expand Up @@ -77,6 +78,7 @@ def optimize_hyperparameters(
`PyTorch Lightning trainer <https://pytorch-lightning.readthedocs.io/en/latest/trainer.html>`_ such
as ``limit_train_batches``. Defaults to {}.
log_dir (str, optional): Folder into which to log results for tensorboard. Defaults to "lightning_logs".
study (optuna.Study, optional): study to resume. Will create new study by default.
**kwargs: Additional arguments for the :py:class:`~TemporalFusionTransformer`.
Returns:
Expand Down Expand Up @@ -165,6 +167,7 @@ def objective(trial: optuna.Trial) -> float:

# setup optuna and run
pruner = optuna.pruners.SuccessiveHalvingPruner()
study = optuna.create_study(direction="minimize", pruner=pruner)
if study is None:
study = optuna.create_study(direction="minimize", pruner=pruner)
study.optimize(objective, n_trials=n_trials, timeout=timeout)
return study

0 comments on commit e97189f

Please sign in to comment.