Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

313 save final model #340

Merged
merged 11 commits into from
Jun 28, 2024
15 changes: 10 additions & 5 deletions casanovo/denovo/model_runner.py
Original file line number Diff line number Diff line change
@@ -55,17 +55,22 @@ def __init__(
self.writer = None

# Configure checkpoints.
self.callbacks = [
ModelCheckpoint(
dirpath=config.model_save_folder_path,
save_on_train_epoch_end=True,
)
]

if config.save_top_k is not None:
self.callbacks = [
self.callbacks.append(
ModelCheckpoint(
dirpath=config.model_save_folder_path,
monitor="valid_CELoss",
mode="min",
save_top_k=config.save_top_k,
)
]
else:
self.callbacks = None
)

def __enter__(self):
"""Enter the context manager"""
@@ -187,7 +192,7 @@ def initialize_trainer(self, train: bool) -> None:
additional_cfg = dict(
devices=devices,
callbacks=self.callbacks,
enable_checkpointing=self.config.save_top_k is not None,
enable_checkpointing=True,
max_epochs=self.config.max_epochs,
num_sanity_val_steps=self.config.num_sanity_val_steps,
strategy=self._get_strategy(),
36 changes: 32 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
import psims
import pytest
import yaml
import math
bittremieux marked this conversation as resolved.
Show resolved Hide resolved
from pyteomics.mass import calculate_mass


@@ -184,10 +185,9 @@ def _create_mzml(peptides, mzml_file, random_state=42):
return mzml_file


@pytest.fixture
def tiny_config(tmp_path):
"""A config file for a tiny model."""
cfg = {
def _get_default_config(tmp_path):
"""Get default test config (dictionary)"""
return {
"n_head": 2,
"dim_feedforward": 10,
"n_layers": 1,
@@ -255,8 +255,36 @@ def tiny_config(tmp_path):
},
}


def _write_config_file(cfg, tmp_path):
    """Serialize a config dictionary as YAML into the temp directory.

    Parameters
    ----------
    cfg : dict
        The configuration options to serialize.
    tmp_path : Path
        Directory in which to create the config file.

    Returns
    -------
    Path
        The path of the written ``config.yml`` file.
    """
    out_path = tmp_path / "config.yml"
    with open(out_path, "w+") as handle:
        yaml.dump(cfg, handle)
    return out_path


@pytest.fixture
def tiny_config(tmp_path):
    """Path to a YAML config file describing a tiny model."""
    return _write_config_file(_get_default_config(tmp_path), tmp_path)


@pytest.fixture
def tiny_config_interval_greater(tmp_path):
    """Config file whose val_check_interval (50) exceeds the number of
    training steps, so validation never triggers during training."""
    cfg = _get_default_config(tmp_path)
    # 50 > total training steps for the tiny test run.
    cfg["val_check_interval"] = 50
    return _write_config_file(cfg, tmp_path)


@pytest.fixture
def tiny_config_not_factor(tmp_path):
    """Config file whose val_check_interval (15) is not a factor of the
    number of training steps."""
    cfg = _get_default_config(tmp_path)
    # 15 does not evenly divide the tiny run's training step count.
    cfg["val_check_interval"] = 15
    return _write_config_file(cfg, tmp_path)
46 changes: 45 additions & 1 deletion tests/test_integration.py
Original file line number Diff line number Diff line change
@@ -8,7 +8,13 @@


def test_train_and_run(
mgf_small, mzml_small, tiny_config, tmp_path, monkeypatch
mgf_small,
mzml_small,
tiny_config,
tmp_path,
tiny_config_interval_greater,
tiny_config_not_factor,
monkeypatch,
):
# We can use this to explicitly test different versions.
monkeypatch.setattr(casanovo, "__version__", "3.0.1")
@@ -86,6 +92,44 @@ def test_train_and_run(
assert psms.loc[4, "sequence"] == "PEPTLDEK"
assert psms.loc[4, "spectra_ref"] == "ms_run[2]:scan=111"

# Test checkpoint saving when val_check_interval is greater than training steps
bittremieux marked this conversation as resolved.
Show resolved Hide resolved
Path.unlink(model_file)
result = run(
[
"train",
"--validation_peak_path",
str(mgf_small),
"--config",
tiny_config_interval_greater,
"--output",
str(tmp_path / "train"),
str(mgf_small),
]
)

assert result.exit_code == 0
assert model_file.exists()

# Test checkpoint saving when val_check_interval is not a factor of training steps
Path.unlink(model_file)
validation_file = tmp_path / "epoch=14-step=15.ckpt"
result = run(
[
"train",
"--validation_peak_path",
str(mgf_small),
"--config",
tiny_config_not_factor,
"--output",
str(tmp_path / "train"),
str(mgf_small),
]
)

assert result.exit_code == 0
assert model_file.exists()
assert validation_file.exists()


def test_auxilliary_cli(tmp_path, monkeypatch):
"""Test the secondary CLI commands"""
1 change: 1 addition & 0 deletions tests/unit_tests/test_runner.py
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@
def test_initialize_model(tmp_path, mgf_small):
"""Test initializing a new or existing model."""
config = Config()
config.model_save_folder_path = tmp_path
# No model filename given, so train from scratch.
ModelRunner(config=config).initialize_model(train=True)