change logging to be dependent on current step (#238)
* remove unused accelerate kwargs helper

* fix logging to global_step not epoch

* changing loss freq to 50 from 10
ssenan authored Oct 16, 2023
1 parent 61fadc1 commit 0878cbf
Showing 2 changed files with 13 additions and 11 deletions.
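
For context before the diffs: the commit moves loss logging from an epoch cadence to a global-step cadence. The old check `epoch % loss_show_epoch == 0` fired on every batch of every 10th epoch; the new check fires once every 50 batches. A minimal runnable sketch of the new pattern (names mirror the diff below; the batch math is a stand-in, not the repo's code):

```python
def train_loop(train_dl, epochs: int = 3, log_step_show: int = 50):
    for epoch in range(epochs):
        for step, batch in enumerate(train_dl):
            # One counter that increases monotonically across all epochs,
            # so the logging cadence no longer depends on epoch length.
            global_step = epoch * len(train_dl) + step
            loss = sum(batch)  # stand-in for the real train_step(batch)
            if global_step % log_step_show == 0:
                print(f"step={global_step} epoch={epoch} loss={loss:.2f}")

train_loop([[0.1, 0.2]] * 100)  # 100 batches/epoch -> logs at steps 0, 50, ..., 250
```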
src/dnadiffusion/utils/train_util.py (11 additions, 9 deletions)

```diff
@@ -21,7 +21,7 @@ def __init__(
         model: torch.nn.Module,
         accelerator: Accelerator,
         epochs: int = 10000,
-        loss_show_epoch: int = 10,
+        log_step_show: int = 50,
         sample_epoch: int = 500,
         save_epoch: int = 500,
         model_name: str = "model_48k_sequences_per_group_K562_hESCT0_HepG2_GM12878_12k",
@@ -34,7 +34,7 @@ def __init__(
         self.optimizer = Adam(self.model.parameters(), lr=1e-4)
         self.accelerator = accelerator
         self.epochs = epochs
-        self.loss_show_epoch = loss_show_epoch
+        self.log_step_show = log_step_show
         self.sample_epoch = sample_epoch
         self.save_epoch = save_epoch
         self.model_name = model_name
@@ -70,12 +70,14 @@ def train_loop(self):
             self.model.train()
 
             # Getting loss of current batch
-            for _, batch in enumerate(self.train_dl):
+            for step, batch in enumerate(self.train_dl):
+                self.global_step = epoch * len(self.train_dl) + step
+
                 loss = self.train_step(batch)
 
-            # Logging loss
-            if epoch % self.loss_show_epoch == 0 and self.accelerator.is_main_process:
-                self.log_step(loss, epoch)
+                # Logging loss
+                if self.global_step % self.log_step_show == 0 and self.accelerator.is_main_process:
+                    self.log_step(loss, epoch)
 
             # Sampling
             if epoch % self.sample_epoch == 0 and self.accelerator.is_main_process:
@@ -110,12 +112,12 @@ def log_step(self, loss, epoch):
                 "train": self.train_kl,
                 "test": self.test_kl,
                 "shuffle": self.shuffle_kl,
-                "loss": loss.item(),
+                "loss": loss.mean().item(),
                 "epoch": epoch,
                 "seq_similarity": self.seq_similarity,
             },
-            step=epoch,
+            step=self.global_step,
         )
         print(f" Epoch {epoch} Loss:", loss.item())
 
     def sample(self):
         self.model.eval()
```
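A side change in `log_step` is worth a note: `loss.item()` becomes `loss.mean().item()`. My reading (an assumption, not stated in the commit) is that this guards against the loss arriving as a multi-element tensor, e.g. after gathering across devices, where `.item()` fails. A small illustration:

```python
import torch

scalar_loss = torch.tensor(0.51)            # usual single-process case
gathered_loss = torch.tensor([0.51, 0.48])  # hypothetical per-device values

print(scalar_loss.mean().item())    # ~0.51 (mean() is a no-op on a 0-d tensor)
print(gathered_loss.mean().item())  # ~0.495
# gathered_loss.item() would raise, since .item() requires a one-element tensor
```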
train.py (2 additions, 2 deletions)

```diff
@@ -1,4 +1,4 @@
-from accelerate import Accelerator, DistributedDataParallelKwargs
+from accelerate import Accelerator
 
 from dnadiffusion.data.dataloader import load_data
 from dnadiffusion.models.diffusion import Diffusion
@@ -40,7 +40,7 @@ def train():
         model=diffusion,
         accelerator=accelerator,
         epochs=10000,
-        loss_show_epoch=10,
+        log_step_show=50,
         sample_epoch=500,
         save_epoch=500,
         model_name="model_48k_sequences_per_group_K562_hESCT0_HepG2_GM12878_12k",
```
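The `train.py` hunk pairs with the first commit-message bullet: once the unused kwargs helper was deleted, the `DistributedDataParallelKwargs` import became dead code. For reference, a typical use of that handler in Accelerate looks like this (illustrative only; not the code this commit removed):

```python
from accelerate import Accelerator, DistributedDataParallelKwargs

# Common pattern for passing DDP options through Accelerate; train.py
# no longer constructs such a handler, so the import was dropped.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
```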
