Skip to content

Commit

Permalink
save final model using ModelCheckpoint callback
Browse files Browse the repository at this point in the history
  • Loading branch information
Lilferrit committed Jun 25, 2024
1 parent b57ea7d commit a034f1f
Showing 1 changed file with 9 additions and 13 deletions.
22 changes: 9 additions & 13 deletions casanovo/denovo/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,22 @@ def __init__(
self.writer = None

# Configure checkpoints.
self.callbacks = [
ModelCheckpoint(
dirpath=config.model_save_folder_path,
save_on_train_epoch_end=True,
)
]

if config.save_top_k is not None:
self.callbacks = [
self.callbacks.append(
ModelCheckpoint(
dirpath=config.model_save_folder_path,
monitor="valid_CELoss",
mode="min",
save_top_k=config.save_top_k,
)
]
else:
self.callbacks = None
)

def __enter__(self):
"""Enter the context manager"""
Expand Down Expand Up @@ -111,15 +116,6 @@ def train(
self.loaders.val_dataloader(),
)

# Always save final model weights at the end of training
if self.config.model_save_folder_path is not None:
self.trainer.save_checkpoint(
os.path.join(
self.config.model_save_folder_path,
f"train-run-final-{self.trainer.current_epoch}.ckpt",
)
)

def evaluate(self, peak_path: Iterable[str]) -> None:
"""Evaluate peptide sequence preditions from a trained Casanovo model.
Expand Down

0 comments on commit a034f1f

Please sign in to comment.