From e7eb2bf79c69bee1892f0b61e9957ab8f8f9563a Mon Sep 17 00:00:00 2001 From: hyunp2 <42776897+hyunp2@users.noreply.github.com> Date: Tue, 3 Oct 2023 09:14:47 -0500 Subject: [PATCH] Update 3_difflinker_train.py --- 3_difflinker_train.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/3_difflinker_train.py b/3_difflinker_train.py index 4e90e295..758f55b2 100644 --- a/3_difflinker_train.py +++ b/3_difflinker_train.py @@ -11,6 +11,7 @@ from utils.src.lightning import DDPM from utils.src.utils import disable_rdkit_logging, Logger +from pytorch_lightning.callbacks import TQDMProgressBar def find_last_checkpoint(checkpoints_dir): epoch2fname = [ @@ -92,12 +93,12 @@ def main(args): anchors_context=anchors_context, ) print(args.test_epochs) - checkpoint_callback = callbacks.ModelCheckpoint( - dirpath=checkpoints_dir, - filename=experiment + '_{epoch:02d}', - monitor='loss/val', - save_top_k=10, - ) + checkpoint_callback = [callbacks.ModelCheckpoint( + dirpath=checkpoints_dir, + filename=experiment + '_{epoch:02d}', + monitor='loss/val', + save_top_k=10), + TQDMProgressBar()] trainer = Trainer( max_epochs=args.n_epochs, # logger=wandb_logger,