From 203005275b4e746f554da2668a67feca45d3b746 Mon Sep 17 00:00:00 2001
From: Wout Bittremieux
Date: Mon, 25 Dec 2023 12:18:44 +0100
Subject: [PATCH] Remove `train_from_scratch` config option

Instead of having to specify `train_from_scratch` in the config file,
training will proceed from an existing model weights file if this is
given as an argument to `casanovo train`.

Fixes #263.
---
 CHANGELOG.md                    |  4 ++++
 casanovo/config.py              |  1 -
 casanovo/config.yaml            |  2 --
 casanovo/denovo/model_runner.py | 20 ++++++++++----------
 tests/conftest.py               |  1 -
 tests/unit_tests/test_runner.py | 36 ++++++++++++++++++++----------------
 6 files changed, 34 insertions(+), 30 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index bbc9284e..84fadfec 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 
 ## [Unreleased]
 
+### Changed
+
+- Instead of having to specify `train_from_scratch` in the config file, training will proceed from an existing model weights file if this is given as an argument to `casanovo train`.
+
 ## [4.0.0] - 2023-12-22
 
 ### Added
diff --git a/casanovo/config.py b/casanovo/config.py
index 0b5a1e4d..22924018 100644
--- a/casanovo/config.py
+++ b/casanovo/config.py
@@ -64,7 +64,6 @@ class Config:
         top_match=int,
         max_epochs=int,
         num_sanity_val_steps=int,
-        train_from_scratch=bool,
         save_top_k=int,
         model_save_folder_path=str,
         val_check_interval=int,
diff --git a/casanovo/config.yaml b/casanovo/config.yaml
index 896f67bc..24bf4623 100644
--- a/casanovo/config.yaml
+++ b/casanovo/config.yaml
@@ -99,8 +99,6 @@ train_batch_size: 32
 max_epochs: 30
 # Number of validation steps to run before training begins
 num_sanity_val_steps: 0
-# Set to "False" to further train a pre-trained Casanovo model
-train_from_scratch: True
 # Calculate peptide and amino acid precision during training. this
 # is expensive, so we recommend against it.
 calculate_precision: False
diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py
index c7a9cab6..b632227b 100644
--- a/casanovo/denovo/model_runner.py
+++ b/casanovo/denovo/model_runner.py
@@ -251,16 +251,16 @@ def initialize_model(self, train: bool) -> None:
             calculate_precision=self.config.calculate_precision,
         )
 
-        from_scratch = (
-            self.config.train_from_scratch,
-            self.model_filename is None,
-        )
-        if train and any(from_scratch):
-            self.model = Spec2Pep(**model_params)
-            return
-        elif self.model_filename is None:
-            logger.error("A model file must be provided")
-            raise ValueError("A model file must be provided")
+        if self.model_filename is None:
+            # Train a model from scratch if no model file is provided.
+            if train:
+                self.model = Spec2Pep(**model_params)
+                return
+            # Else we're not training, so a model file must be provided.
+            else:
+                logger.error("A model file must be provided")
+                raise ValueError("A model file must be provided")
+        # Else a model file is provided (to continue training or for inference).
 
         if not Path(self.model_filename).exists():
             logger.error(
diff --git a/tests/conftest.py b/tests/conftest.py
index a690bd8a..f1918300 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -222,7 +222,6 @@ def tiny_config(tmp_path):
         "weight_decay": 1e-5,
         "train_batch_size": 32,
         "num_sanity_val_steps": 0,
-        "train_from_scratch": True,
         "calculate_precision": False,
         "residues": {
             "G": 57.021464,
diff --git a/tests/unit_tests/test_runner.py b/tests/unit_tests/test_runner.py
index 6be91831..a670acad 100644
--- a/tests/unit_tests/test_runner.py
+++ b/tests/unit_tests/test_runner.py
@@ -6,35 +6,39 @@
 from casanovo.denovo.model_runner import ModelRunner
 
 
-def test_initialize_model(tmp_path):
-    """Test that"""
+def test_initialize_model(tmp_path, mgf_small):
+    """Test initializing a new or existing model."""
     config = Config()
-    config.train_from_scratch = False
+    # No model filename given, so train from scratch.
     ModelRunner(config=config).initialize_model(train=True)
 
+    # No model filename given during inference = error.
     with pytest.raises(ValueError):
         ModelRunner(config=config).initialize_model(train=False)
 
-    with pytest.raises(FileNotFoundError):
-        runner = ModelRunner(config=config, model_filename="blah")
-        runner.initialize_model(train=True)
-
+    # Non-existing model filename given during inference = error.
     with pytest.raises(FileNotFoundError):
         runner = ModelRunner(config=config, model_filename="blah")
         runner.initialize_model(train=False)
 
-    # This should work now:
-    config.train_from_scratch = True
-    runner = ModelRunner(config=config, model_filename="blah")
+    # Train a quick model.
+    config.max_epochs = 1
+    config.n_layers = 1
+    ckpt = tmp_path / "existing.ckpt"
+    with ModelRunner(config=config) as runner:
+        runner.train([mgf_small], [mgf_small])
+        runner.trainer.save_checkpoint(ckpt)
+
+    # Resume training from previous model.
+    runner = ModelRunner(config=config, model_filename=str(ckpt))
     runner.initialize_model(train=True)
 
-    # But this should still fail:
-    with pytest.raises(FileNotFoundError):
-        runner = ModelRunner(config=config, model_filename="blah")
-        runner.initialize_model(train=False)
+    # Inference with previous model.
+    runner = ModelRunner(config=config, model_filename=str(ckpt))
+    runner.initialize_model(train=False)
 
     # If the model initialization throws and EOFError, then the Spec2Pep model
-    # has tried to load the weights:
+    # has tried to load the weights.
     weights = tmp_path / "blah"
     weights.touch()
     with pytest.raises(EOFError):
@@ -43,7 +47,7 @@ def test_initialize_model(tmp_path):
 
 
 def test_save_and_load_weights(tmp_path, mgf_small, tiny_config):
-    """Test saving aloading weights"""
+    """Test saving and loading weights"""
     config = Config(tiny_config)
     config.max_epochs = 1
     config.n_layers = 1
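
Note (not part of the patch): the snippet below is a minimal sketch of the behavior after this change, pieced together from the updated unit test above. The `Config`/`ModelRunner` calls mirror the test; the peak file and checkpoint paths are placeholders, not files shipped with Casanovo.

from casanovo.config import Config
from casanovo.denovo.model_runner import ModelRunner

config = Config()  # the `train_from_scratch` option no longer exists

# Without a model filename, training builds a fresh Spec2Pep model from scratch.
with ModelRunner(config=config) as runner:
    runner.train(["train.mgf"], ["validate.mgf"])  # placeholder peak files
    runner.trainer.save_checkpoint("existing.ckpt")

# With a model filename, training now resumes from those weights instead.
with ModelRunner(config=config, model_filename="existing.ckpt") as runner:
    runner.train(["train.mgf"], ["validate.mgf"])

# Inference still requires a model file: initialize_model(train=False) raises
# a ValueError when no model filename was provided.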