diff --git a/CHANGELOG.md b/CHANGELOG.md
index cee00b8e..1824cc1f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 
 ## [Unreleased]
 
+### Added
+
+- During training, model checkpoints will now be saved at the end of each training epoch in addition to the checkpoints saved at the end of every validation run.
+
 ### Fixed
 
 - Precursor charges are now exported as integers instead of floats in the mzTab output file, in compliance with the mzTab specification.
diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py
index 4bd2165e..d5acacb3 100644
--- a/casanovo/denovo/model_runner.py
+++ b/casanovo/denovo/model_runner.py
@@ -55,17 +55,22 @@ def __init__(
         self.writer = None
 
         # Configure checkpoints.
+        self.callbacks = [
+            ModelCheckpoint(
+                dirpath=config.model_save_folder_path,
+                save_on_train_epoch_end=True,
+            )
+        ]
+
         if config.save_top_k is not None:
-            self.callbacks = [
+            self.callbacks.append(
                 ModelCheckpoint(
                     dirpath=config.model_save_folder_path,
                     monitor="valid_CELoss",
                     mode="min",
                     save_top_k=config.save_top_k,
                 )
-            ]
-        else:
-            self.callbacks = None
+            )
 
     def __enter__(self):
         """Enter the context manager"""
@@ -187,7 +192,7 @@ def initialize_trainer(self, train: bool) -> None:
         additional_cfg = dict(
             devices=devices,
             callbacks=self.callbacks,
-            enable_checkpointing=self.config.save_top_k is not None,
+            enable_checkpointing=True,
             max_epochs=self.config.max_epochs,
             num_sanity_val_steps=self.config.num_sanity_val_steps,
             strategy=self._get_strategy(),
diff --git a/tests/unit_tests/test_runner.py b/tests/unit_tests/test_runner.py
index 7febf3f7..2d0513bd 100644
--- a/tests/unit_tests/test_runner.py
+++ b/tests/unit_tests/test_runner.py
@@ -1,5 +1,7 @@
 """Unit tests specifically for the model_runner module."""
 
+from pathlib import Path
+
 import pytest
 import torch
 
@@ -10,6 +12,7 @@
 def test_initialize_model(tmp_path, mgf_small):
     """Test initializing a new or existing model."""
     config = Config()
+    config.model_save_folder_path = tmp_path
 
     # No model filename given, so train from scratch.
     ModelRunner(config=config).initialize_model(train=True)
@@ -149,3 +152,25 @@ def test_calculate_precision(tmp_path, mgf_small, tiny_config):
 
     assert "valid_aa_precision" in runner.model.history.columns
     assert "valid_pep_precision" in runner.model.history.columns
+
+
+def test_save_final_model(tmp_path, mgf_small, tiny_config):
+    """Test that final model checkpoints are saved."""
+    # Test checkpoint saving when val_check_interval is greater than training steps
+    config = Config(tiny_config)
+    config.val_check_interval = 50
+    model_file = tmp_path / "epoch=19-step=20.ckpt"
+    with ModelRunner(config) as runner:
+        runner.train([mgf_small], [mgf_small])
+
+    assert model_file.exists()
+    Path.unlink(model_file)
+
+    # Test checkpoint saving when val_check_interval is not a factor of training steps
+    config.val_check_interval = 15
+    validation_file = tmp_path / "epoch=14-step=15.ckpt"
+    with ModelRunner(config) as runner:
+        runner.train([mgf_small], [mgf_small])
+
+    assert model_file.exists()
+    assert validation_file.exists()