From a034f1f59cbae24a958fee6cb714c5634aa23a6a Mon Sep 17 00:00:00 2001 From: Lilferrit Date: Tue, 25 Jun 2024 10:51:55 -0700 Subject: [PATCH] save final model using ModelCheckpoint callback --- casanovo/denovo/model_runner.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index 46f086e9..7e3804ef 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -55,17 +55,22 @@ def __init__( self.writer = None # Configure checkpoints. + self.callbacks = [ + ModelCheckpoint( + dirpath=config.model_save_folder_path, + save_on_train_epoch_end=True, + ) + ] + if config.save_top_k is not None: - self.callbacks = [ + self.callbacks.append( ModelCheckpoint( dirpath=config.model_save_folder_path, monitor="valid_CELoss", mode="min", save_top_k=config.save_top_k, ) - ] - else: - self.callbacks = None + ) def __enter__(self): """Enter the context manager""" @@ -111,15 +116,6 @@ def train( self.loaders.val_dataloader(), ) - # Always save final model weights at the end of training - if self.config.model_save_folder_path is not None: - self.trainer.save_checkpoint( - os.path.join( - self.config.model_save_folder_path, - f"train-run-final-{self.trainer.current_epoch}.ckpt", - ) - ) - def evaluate(self, peak_path: Iterable[str]) -> None: """Evaluate peptide sequence preditions from a trained Casanovo model.