diff --git a/tests/test_models/test_nbeats.py b/tests/test_models/test_nbeats.py index 8e412984..1a0c33c0 100644 --- a/tests/test_models/test_nbeats.py +++ b/tests/test_models/test_nbeats.py @@ -1,7 +1,7 @@ import shutil import pytorch_lightning as pl from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_forecasting.metrics import QuantileLoss from pytorch_forecasting.models import NBeats @@ -12,7 +12,9 @@ def test_integration(dataloaders_fixed_window_without_coveratiates, tmp_path): early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min") logger = TensorBoardLogger(tmp_path) + checkpoint = ModelCheckpoint(filepath=tmp_path) trainer = pl.Trainer( + checkpoint_callback=checkpoint, max_epochs=3, gpus=0, weights_summary="top", @@ -28,6 +30,12 @@ def test_integration(dataloaders_fixed_window_without_coveratiates, tmp_path): trainer.fit( net, train_dataloader=train_dataloader, val_dataloaders=val_dataloader, ) + # check loading + fname = f"{trainer.checkpoint_callback.dirpath}/epoch=0.ckpt" + net = NBeats.load_from_checkpoint(fname) + + # check prediction + net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True) finally: shutil.rmtree(tmp_path, ignore_errors=True) diff --git a/tests/test_models/test_temporal_fusion_transformer.py b/tests/test_models/test_temporal_fusion_transformer.py index 80531e7f..6d110f23 100644 --- a/tests/test_models/test_temporal_fusion_transformer.py +++ b/tests/test_models/test_temporal_fusion_transformer.py @@ -1,7 +1,7 @@ import shutil import pytorch_lightning as pl from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_forecasting.metrics import QuantileLoss from pytorch_forecasting.models import TemporalFusionTransformer @@ -17,7 +17,9 @@ def test_integration(dataloaders_with_coveratiates, tmp_path): # check training logger = TensorBoardLogger(tmp_path) + checkpoint = ModelCheckpoint(filepath=tmp_path) trainer = pl.Trainer( + checkpoint_callback=checkpoint, max_epochs=3, gpus=0, weights_summary="top",