Skip to content

Commit

Permalink
Update 3_difflinker_train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hyunp2 authored Oct 3, 2023
1 parent 08c6588 commit e7eb2bf
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions 3_difflinker_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from utils.src.lightning import DDPM
from utils.src.utils import disable_rdkit_logging, Logger

from pytorch_lightning.callbacks import TQDMProgressBar

def find_last_checkpoint(checkpoints_dir):
epoch2fname = [
Expand Down Expand Up @@ -92,12 +93,12 @@ def main(args):
anchors_context=anchors_context,
)
print(args.test_epochs)
checkpoint_callback = callbacks.ModelCheckpoint(
dirpath=checkpoints_dir,
filename=experiment + '_{epoch:02d}',
monitor='loss/val',
save_top_k=10,
)
checkpoint_callback = [callbacks.ModelCheckpoint(
dirpath=checkpoints_dir,
filename=experiment + '_{epoch:02d}',
monitor='loss/val',
save_top_k=10),
TQDMProgressBar()]
trainer = Trainer(
max_epochs=args.n_epochs,
# logger=wandb_logger,
Expand Down

0 comments on commit e7eb2bf

Please sign in to comment.