Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

313 save final model #340

Merged
merged 11 commits into from
Jun 28, 2024
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased]

### Added

- During training, model checkpoints will now be saved at the end of each training epoch in addition to the checkpoints saved at the end of every validation run.

### Fixed

- Precursor charges are now exported as integers instead of floats in the mzTab output file, in compliance with the mzTab specification.
Expand Down
15 changes: 10 additions & 5 deletions casanovo/denovo/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,22 @@ def __init__(
self.writer = None

# Configure checkpoints.
self.callbacks = [
ModelCheckpoint(
dirpath=config.model_save_folder_path,
save_on_train_epoch_end=True,
)
]

if config.save_top_k is not None:
self.callbacks = [
self.callbacks.append(
ModelCheckpoint(
dirpath=config.model_save_folder_path,
monitor="valid_CELoss",
mode="min",
save_top_k=config.save_top_k,
)
]
else:
self.callbacks = None
)

def __enter__(self):
"""Enter the context manager"""
Expand Down Expand Up @@ -187,7 +192,7 @@ def initialize_trainer(self, train: bool) -> None:
additional_cfg = dict(
devices=devices,
callbacks=self.callbacks,
enable_checkpointing=self.config.save_top_k is not None,
enable_checkpointing=True,
max_epochs=self.config.max_epochs,
num_sanity_val_steps=self.config.num_sanity_val_steps,
strategy=self._get_strategy(),
Expand Down
25 changes: 25 additions & 0 deletions tests/unit_tests/test_runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Unit tests specifically for the model_runner module."""

from pathlib import Path

import pytest
import torch

Expand All @@ -10,6 +12,7 @@
def test_initialize_model(tmp_path, mgf_small):
"""Test initializing a new or existing model."""
config = Config()
config.model_save_folder_path = tmp_path
# No model filename given, so train from scratch.
ModelRunner(config=config).initialize_model(train=True)

Expand Down Expand Up @@ -149,3 +152,25 @@ def test_calculate_precision(tmp_path, mgf_small, tiny_config):

assert "valid_aa_precision" in runner.model.history.columns
assert "valid_pep_precision" in runner.model.history.columns


def test_save_final_model(tmp_path, mgf_small, tiny_config):
    """Test that a final model checkpoint is saved at the end of training.

    Covers two scenarios: (1) ``val_check_interval`` exceeds the total
    number of training steps, so no validation checkpoint is ever written
    and only the end-of-epoch checkpoint must exist; (2)
    ``val_check_interval`` is not a factor of the number of training
    steps, so both a validation checkpoint and the final end-of-epoch
    checkpoint must exist.
    """
    # Case 1: val_check_interval (50) is greater than the total training
    # steps, so only the train-epoch-end checkpoint should be produced.
    config = Config(tiny_config)
    config.val_check_interval = 50
    # Filename follows Lightning's default "epoch={n}-step={s}" pattern.
    model_file = tmp_path / "epoch=19-step=20.ckpt"
    with ModelRunner(config) as runner:
        runner.train([mgf_small], [mgf_small])

    assert model_file.exists()
    # Remove the checkpoint so the next run's assertion is meaningful.
    model_file.unlink()

    # Case 2: val_check_interval (15) does not evenly divide the training
    # steps, so a mid-training validation checkpoint is written in
    # addition to the final train-epoch-end checkpoint.
    config.val_check_interval = 15
    validation_file = tmp_path / "epoch=14-step=15.ckpt"
    with ModelRunner(config) as runner:
        runner.train([mgf_small], [mgf_small])

    assert model_file.exists()
    assert validation_file.exists()
Loading