From 0878cbfc8e4727074f9c0e47f7f8ca077c746416 Mon Sep 17 00:00:00 2001
From: Simon
Date: Mon, 16 Oct 2023 15:47:00 -0400
Subject: [PATCH] change logging to be dependent on current step (#238)

* remove unused accelerate kwargs import

* fix logging to use global_step instead of epoch

* change loss logging frequency from 10 to 50
---
 src/dnadiffusion/utils/train_util.py | 20 +++++++++++---------
 train.py                             |  4 ++--
 2 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/src/dnadiffusion/utils/train_util.py b/src/dnadiffusion/utils/train_util.py
index ac218135..25bf7278 100644
--- a/src/dnadiffusion/utils/train_util.py
+++ b/src/dnadiffusion/utils/train_util.py
@@ -21,7 +21,7 @@ def __init__(
         model: torch.nn.Module,
         accelerator: Accelerator,
         epochs: int = 10000,
-        loss_show_epoch: int = 10,
+        log_step_show: int = 50,
         sample_epoch: int = 500,
         save_epoch: int = 500,
         model_name: str = "model_48k_sequences_per_group_K562_hESCT0_HepG2_GM12878_12k",
@@ -34,7 +34,7 @@ def __init__(
         self.optimizer = Adam(self.model.parameters(), lr=1e-4)
         self.accelerator = accelerator
         self.epochs = epochs
-        self.loss_show_epoch = loss_show_epoch
+        self.log_step_show = log_step_show
         self.sample_epoch = sample_epoch
         self.save_epoch = save_epoch
         self.model_name = model_name
@@ -70,12 +70,14 @@ def train_loop(self):
             self.model.train()
 
             # Getting loss of current batch
-            for _, batch in enumerate(self.train_dl):
+            for step, batch in enumerate(self.train_dl):
+                self.global_step = epoch * len(self.train_dl) + step
+
                 loss = self.train_step(batch)
 
-            # Logging loss
-            if epoch % self.loss_show_epoch == 0 and self.accelerator.is_main_process:
-                self.log_step(loss, epoch)
+                # Logging loss
+                if self.global_step % self.log_step_show == 0 and self.accelerator.is_main_process:
+                    self.log_step(loss, epoch)
 
             # Sampling
             if epoch % self.sample_epoch == 0 and self.accelerator.is_main_process:
@@ -110,12 +112,12 @@ def log_step(self, loss, epoch):
                 "train": self.train_kl,
                 "test": self.test_kl,
                 "shuffle": self.shuffle_kl,
-                "loss": loss.item(),
+                "loss": loss.mean().item(),
+                "epoch": epoch,
                 "seq_similarity": self.seq_similarity,
             },
-            step=epoch,
+            step=self.global_step,
         )
-        print(f" Epoch {epoch} Loss:", loss.item())
 
     def sample(self):
         self.model.eval()
diff --git a/train.py b/train.py
index d664a5e8..b129c1cf 100644
--- a/train.py
+++ b/train.py
@@ -1,4 +1,4 @@
-from accelerate import Accelerator, DistributedDataParallelKwargs
+from accelerate import Accelerator
 
 from dnadiffusion.data.dataloader import load_data
 from dnadiffusion.models.diffusion import Diffusion
@@ -40,7 +40,7 @@ def train():
         model=diffusion,
         accelerator=accelerator,
         epochs=10000,
-        loss_show_epoch=10,
+        log_step_show=50,
         sample_epoch=500,
         save_epoch=500,
         model_name="model_48k_sequences_per_group_K562_hESCT0_HepG2_GM12878_12k",
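
Note on the pattern: the diff moves the logging check inside the batch loop and keys
it to a global step counter. The cadence reduces to the minimal, self-contained
sketch below; train_loop, fake_dl, and the placeholder loss are hypothetical
stand-ins for the trainer attributes in train_util.py, not the project's actual code.

    # Sketch: log every `log_step_show` batches against a strictly increasing step,
    # assuming every epoch yields the same number of batches from the dataloader.
    def train_loop(epochs: int, train_dl, log_step_show: int = 50) -> None:
        for epoch in range(epochs):
            for step, batch in enumerate(train_dl):
                # Unique step index across epochs: never repeats, never decreases.
                global_step = epoch * len(train_dl) + step
                loss = sum(batch) / len(batch)  # placeholder for train_step(batch)
                if global_step % log_step_show == 0:
                    print(f"step={global_step} epoch={epoch} loss={loss:.4f}")

    if __name__ == "__main__":
        fake_dl = [[0.1, 0.2], [0.3, 0.4]] * 5  # 10 tiny "batches" per epoch
        train_loop(epochs=3, train_dl=fake_dl, log_step_show=5)

The monotonic counter is the point of the fix: with the check now running once per
batch, trackers such as wandb typically ignore entries whose step is not greater
than the last one logged, so keeping step=epoch would silently drop most records
within an epoch, while step=self.global_step preserves every logged point.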