Merge pull request #340 from Noble-Lab/313-save-final-model
313 save final model
Lilferrit authored Jun 28, 2024
2 parents 70ea9fc + a743dc5 · commit 7372eb0
Showing 3 changed files with 39 additions and 5 deletions.
CHANGELOG.md (4 additions, 0 deletions)
@@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 
 ## [Unreleased]
 
+### Added
+
+- During training, model checkpoints will now be saved at the end of each training epoch in addition to the checkpoints saved at the end of every validation run.
+
 ### Fixed
 
 - Precursor charges are now exported as integers instead of floats in the mzTab output file, in compliance with the mzTab specification.
casanovo/denovo/model_runner.py (10 additions, 5 deletions)
@@ -55,17 +55,22 @@ def __init__(
         self.writer = None
 
         # Configure checkpoints.
+        self.callbacks = [
+            ModelCheckpoint(
+                dirpath=config.model_save_folder_path,
+                save_on_train_epoch_end=True,
+            )
+        ]
+
         if config.save_top_k is not None:
-            self.callbacks = [
+            self.callbacks.append(
                 ModelCheckpoint(
                     dirpath=config.model_save_folder_path,
                     monitor="valid_CELoss",
                     mode="min",
                     save_top_k=config.save_top_k,
                 )
-            ]
-        else:
-            self.callbacks = None
+            )
 
     def __enter__(self):
         """Enter the context manager"""
@@ -187,7 +192,7 @@ def initialize_trainer(self, train: bool) -> None:
             additional_cfg = dict(
                 devices=devices,
                 callbacks=self.callbacks,
-                enable_checkpointing=self.config.save_top_k is not None,
+                enable_checkpointing=True,
                 max_epochs=self.config.max_epochs,
                 num_sanity_val_steps=self.config.num_sanity_val_steps,
                 strategy=self._get_strategy(),
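The net effect of this change: previously, when save_top_k was unset, self.callbacks was None and enable_checkpointing was False, so no checkpoints were written at all; now an epoch-end checkpoint is always saved, and the top-k validation checkpoint callback is layered on top when configured. A minimal sketch of the resulting pattern, assuming the lightning.pytorch namespace of recent PyTorch Lightning releases; DemoModel, train_loader, and val_loader are hypothetical stand-ins, not Casanovo code:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

save_top_k = 5  # None would skip the validation-monitoring callback

# Always save a checkpoint at the end of every training epoch.
callbacks = [
    ModelCheckpoint(dirpath="checkpoints/", save_on_train_epoch_end=True)
]
if save_top_k is not None:
    # Additionally keep the k checkpoints with the lowest validation loss
    # (the monitored metric must be logged by the model, as valid_CELoss is
    # in Casanovo).
    callbacks.append(
        ModelCheckpoint(
            dirpath="checkpoints/",
            monitor="valid_CELoss",
            mode="min",
            save_top_k=save_top_k,
        )
    )

trainer = Trainer(callbacks=callbacks, enable_checkpointing=True, max_epochs=20)
# trainer.fit(DemoModel(), train_loader, val_loader)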
tests/unit_tests/test_runner.py (25 additions, 0 deletions)
@@ -1,5 +1,7 @@
 """Unit tests specifically for the model_runner module."""
 
+from pathlib import Path
+
 import pytest
 import torch
 
@@ -10,6 +12,7 @@
 def test_initialize_model(tmp_path, mgf_small):
     """Test initializing a new or existing model."""
     config = Config()
+    config.model_save_folder_path = tmp_path
     # No model filename given, so train from scratch.
     ModelRunner(config=config).initialize_model(train=True)
 
@@ -149,3 +152,25 @@ def test_calculate_precision(tmp_path, mgf_small, tiny_config):
 
     assert "valid_aa_precision" in runner.model.history.columns
     assert "valid_pep_precision" in runner.model.history.columns
+
+
+def test_save_final_model(tmp_path, mgf_small, tiny_config):
+    """Test that final model checkpoints are saved."""
+    # Checkpoint saving when val_check_interval is greater than the number of training steps.
+    config = Config(tiny_config)
+    config.val_check_interval = 50
+    model_file = tmp_path / "epoch=19-step=20.ckpt"
+    with ModelRunner(config) as runner:
+        runner.train([mgf_small], [mgf_small])
+
+    assert model_file.exists()
+    Path.unlink(model_file)
+
+    # Checkpoint saving when val_check_interval is not a factor of the number of training steps.
+    config.val_check_interval = 15
+    validation_file = tmp_path / "epoch=14-step=15.ckpt"
+    with ModelRunner(config) as runner:
+        runner.train([mgf_small], [mgf_small])
+
+    assert model_file.exists()
+    assert validation_file.exists()
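A note on the hard-coded filenames: they rely on Lightning's default checkpoint naming, "epoch={epoch}-step={step}.ckpt", used when ModelCheckpoint is given no explicit filename. Assuming the tiny test config runs one training batch per epoch for 20 epochs (inferred from the expected names, not stated in this diff), the final epoch-end checkpoint is epoch=19-step=20.ckpt; with val_check_interval=50 no validation run is ever reached, so that is the only checkpoint, while val_check_interval=15 additionally triggers a validation checkpoint at global step 15, i.e. epoch=14-step=15.ckpt. A sketch of that naming behavior:

from lightning.pytorch.callbacks import ModelCheckpoint

# No `filename` argument, so checkpoints get Lightning's default template,
# e.g. "epoch=19-step=20.ckpt" after the 20th optimizer step in epoch 19.
ckpt = ModelCheckpoint(dirpath="checkpoints/")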
