diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 21251a32..9a01b7a1 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -209,6 +209,13 @@ def __init__(
                 parallel_context=self.parallel_context,
                 root_folder=self.init_checkpoint_path,
             )
+            # Update optimizer learning rate because otherwise it is set to zero in the first iteration.
+            param_groups = self.optimizer.get_base_optimizer().param_groups
+            last_lrs = self.lr_scheduler.get_last_lr()
+            assert len(param_groups) == len(last_lrs)
+            for group, last_lr in zip(param_groups, last_lrs):
+                assert "lr" in group
+                group["lr"] = last_lr
 
         # Define iteration start state
         if self.init_checkpoint_path is not None:
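
Why the copy-back is needed, as a minimal sketch in plain PyTorch (SGD and
LambdaLR below are hypothetical stand-ins for nanotron's wrapped optimizer and
lr scheduler): loading a scheduler's state_dict restores the scheduler's own
state but does not write the learning rate back into the optimizer's param
groups, so the first post-resume iteration would run with the
construction-time lr instead of the last scheduled one.

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]

    def make_pair():
        # Fresh optimizer + a warmup-style schedule: lr = 0.1 * (step + 1).
        opt = torch.optim.SGD(params, lr=1.0)
        sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda step: 0.1 * (step + 1))
        return opt, sched

    optimizer, lr_scheduler = make_pair()
    for _ in range(3):                      # train a few steps, then "checkpoint"
        optimizer.step()
        lr_scheduler.step()
    state = lr_scheduler.state_dict()

    optimizer, lr_scheduler = make_pair()   # "resume" with fresh objects
    lr_scheduler.load_state_dict(state)
    print(optimizer.param_groups[0]["lr"])  # 0.1: stale construction-time value

    # The patch's fix: copy the restored schedule back into the param groups.
    for group, last_lr in zip(optimizer.param_groups, lr_scheduler.get_last_lr()):
        group["lr"] = last_lr
    print(optimizer.param_groups[0]["lr"])  # 0.4: matches get_last_lr() at step 3

Without the copy, the lr only becomes correct after the scheduler's next
step() call, i.e. one iteration too late; with a warmup schedule that starts
at zero, that first iteration runs at lr 0, which is the failure mode the
patch comment describes.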