Fixed wrong lr initialization when loading checkpoints
AleHD committed Nov 20, 2024
1 parent 7b7ead9 commit 5c4c0c6
Showing 1 changed file with 7 additions and 0 deletions: src/nanotron/trainer.py
@@ -209,6 +209,13 @@ def __init__(
     parallel_context=self.parallel_context,
     root_folder=self.init_checkpoint_path,
 )
+# Update optimizer learning rate because otherwise it is set to zero in the first iteration.
+param_groups = self.optimizer.get_base_optimizer().param_groups
+last_lrs = self.lr_scheduler.get_last_lr()
+assert len(param_groups) == len(last_lrs)
+for group, last_lr in zip(param_groups, last_lrs):
+    assert "lr" in group
+    group["lr"] = last_lr
+
 # Define iteration start state
 if self.init_checkpoint_path is not None:
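The sketch below illustrates why this sync is needed, using plain PyTorch rather than nanotron's wrapped optimizer (`get_base_optimizer()` and the actual resume path are nanotron-specific; the `Linear` model, the 100-step linear-warmup `LambdaLR`, and the 50-step checkpoint are illustrative assumptions). Restoring a scheduler state dict does not rewrite the optimizer's param-group `lr` values, so until the next `scheduler.step()` the groups keep the value computed at scheduler construction, which for a warmup schedule is zero.

```python
# Minimal sketch, not nanotron's code: shows the stale lr after restoring a
# scheduler state dict, and the fix of copying get_last_lr() into the groups.
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Illustrative linear warmup over 100 steps.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(1.0, step / 100))

# Simulate a run that saved a checkpoint after 50 steps.
for _ in range(50):
    optimizer.step()
    scheduler.step()
sched_state = scheduler.state_dict()

# "Resume": fresh optimizer and scheduler, then load the scheduler state.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(1.0, step / 100))
scheduler.load_state_dict(sched_state)

# Without the fix the param group still holds lambda(0) * base_lr == 0.0,
# even though the scheduler reports the checkpointed lr.
print("before sync:", optimizer.param_groups[0]["lr"])   # 0.0

# The fix mirrors the committed change: copy the scheduler's last lr into each group.
for group, last_lr in zip(optimizer.param_groups, scheduler.get_last_lr()):
    group["lr"] = last_lr
print("after sync:", optimizer.param_groups[0]["lr"])     # 5e-4, step 50 of the warmup
```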
