Skip to content

Commit

Permalink
Fix model tests: unique directory for checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan Beitner committed Aug 15, 2020
1 parent 53ab0b2 commit e909fc4
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
10 changes: 9 additions & 1 deletion tests/test_models/test_nbeats.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import shutil
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_forecasting.metrics import QuantileLoss
from pytorch_forecasting.models import NBeats

Expand All @@ -12,7 +12,9 @@ def test_integration(dataloaders_fixed_window_without_coveratiates, tmp_path):
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min")

logger = TensorBoardLogger(tmp_path)
checkpoint = ModelCheckpoint(filepath=tmp_path)
trainer = pl.Trainer(
checkpoint_callback=checkpoint,
max_epochs=3,
gpus=0,
weights_summary="top",
Expand All @@ -28,6 +30,12 @@ def test_integration(dataloaders_fixed_window_without_coveratiates, tmp_path):
trainer.fit(
net, train_dataloader=train_dataloader, val_dataloaders=val_dataloader,
)
# check loading
fname = f"{trainer.checkpoint_callback.dirpath}/epoch=0.ckpt"
net = NBeats.load_from_checkpoint(fname)

# check prediction
net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True)
finally:
shutil.rmtree(tmp_path, ignore_errors=True)

Expand Down
4 changes: 3 additions & 1 deletion tests/test_models/test_temporal_fusion_transformer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import shutil
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_forecasting.metrics import QuantileLoss
from pytorch_forecasting.models import TemporalFusionTransformer

Expand All @@ -17,7 +17,9 @@ def test_integration(dataloaders_with_coveratiates, tmp_path):

# check training
logger = TensorBoardLogger(tmp_path)
checkpoint = ModelCheckpoint(filepath=tmp_path)
trainer = pl.Trainer(
checkpoint_callback=checkpoint,
max_epochs=3,
gpus=0,
weights_summary="top",
Expand Down

0 comments on commit e909fc4

Please sign in to comment.