monitor time, inspired by finegrain-ai#278
piercus committed Feb 13, 2024
1 parent 52f0fec commit a32b023
Showing 1 changed file with 69 additions and 0 deletions.
src/refiners/training_utils/callback.py (+69, -0)
@@ -3,6 +3,7 @@
from loguru import logger
from torch import tensor
from torch.nn import Parameter
import time

if TYPE_CHECKING:
from refiners.training_utils.config import BaseConfig
@@ -177,6 +178,74 @@ def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
    def on_lr_scheduler_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
        trainer.log(data={"learning_rate": trainer.optimizer.param_groups[0]["lr"]})

# Ported from open-muse
class AverageTimeMeter:
    """Measures start/stop intervals and keeps a running average of their durations."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.avg: float = 0.0
        self.sum: float = 0.0
        self.count: int = 0
        # Initialized here so that stop() before start() cannot raise AttributeError.
        self.start_time: float | None = None

    def update(self, val: float) -> None:
        self.sum += val
        self.count += 1
        self.avg = self.sum / self.count

    def start(self) -> None:
        self.start_time = time.time()

    def cancel(self) -> None:
        self.start_time = None

    def stop(self) -> None:
        # No-op if the timer was never started or was cancelled.
        if self.start_time is None:
            return
        spent = time.time() - self.start_time
        self.update(spent)
        self.start_time = None
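
# Usage sketch for the meter above (illustrative only, using just the
# methods defined in this commit):
#
#     meter = AverageTimeMeter()
#     meter.start()
#     ...                # timed work goes here
#     meter.stop()       # folds the elapsed seconds into the running average
#     print(meter.avg)   # mean duration over all completed start/stop cycles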


class MonitorTime(Callback["Trainer[BaseConfig, Any]"]):
    """Logs the average forward, backward and data-loading time per item."""

    def __init__(self) -> None:
        self.forward = AverageTimeMeter()
        self.data = AverageTimeMeter()
        self.backward = AverageTimeMeter()

        super().__init__()

    def on_compute_loss_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
        self.forward.start()

    def on_compute_loss_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
        self.forward.stop()

    def on_backward_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
        self.backward.start()

    def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
        self.backward.stop()

    def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
        # Everything since the previous on_batch_end counts as data-loading time.
        self.data.stop()

    def on_batch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
        trainer.log(
            data={
                "timing/forward_time_per_item": self.forward.avg / trainer.clock.batch_size,
                "timing/backward_time_per_item": self.backward.avg / trainer.clock.batch_size,
                "timing/data_time_per_item": self.data.avg / trainer.clock.batch_size,
            }
        )
        self.data.start()

    def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
        self.data.start()

    def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
        # Discard the data timer restarted at the last on_batch_end of the epoch.
        self.data.cancel()

class GradientNormClipping(Callback["Trainer[BaseConfig, Any]"]):
    def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
        ...
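
For a quick sanity check outside a real training run, the hooks can be driven by hand. The sketch below is illustrative only: fake_trainer is a hypothetical stand-in exposing just the two attributes MonitorTime touches (clock.batch_size and log), and the import path follows the file shown above.

    import time
    from types import SimpleNamespace

    from refiners.training_utils.callback import MonitorTime

    # Hypothetical stand-in for the Trainer: only the attributes used by MonitorTime.
    fake_trainer = SimpleNamespace(
        clock=SimpleNamespace(batch_size=4),
        log=lambda data: print(data),
    )

    monitor = MonitorTime()
    monitor.on_epoch_begin(fake_trainer)    # starts the data-loading timer
    time.sleep(0.01)                        # stand-in for dataloader work
    monitor.on_batch_begin(fake_trainer)    # stops the data-loading timer
    monitor.on_compute_loss_begin(fake_trainer)
    time.sleep(0.02)                        # stand-in for the forward pass
    monitor.on_compute_loss_end(fake_trainer)
    monitor.on_backward_begin(fake_trainer)
    time.sleep(0.02)                        # stand-in for the backward pass
    monitor.on_backward_end(fake_trainer)
    monitor.on_batch_end(fake_trainer)      # prints the timing/... per-item averages
    monitor.on_epoch_end(fake_trainer)      # cancels the timer restarted by on_batch_end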
