allow specifying LR schedule in terms of tokens
epwalsh committed Jan 17, 2024
1 parent 45ed078 commit 319fe5b
Showing 2 changed files with 64 additions and 9 deletions.
6 changes: 6 additions & 0 deletions olmo/config.py
@@ -480,9 +480,15 @@ class SchedulerType(StrEnum):
     constant = "constant"
 
 
+class SchedulerUnits(StrEnum):
+    steps = "steps"
+    tokens = "tokens"
+
+
 @dataclass
 class SchedulerConfig(BaseConfig):
     name: SchedulerType = SchedulerType.cosine_with_warmup
+    units: SchedulerUnits = SchedulerUnits.steps
     t_warmup: int = 100
     t_max: Optional[int] = None
     alpha_f: float = 0.1
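For orientation, a minimal sketch of how the new field could be set when building a scheduler config in Python (in a YAML training config this should correspond to `scheduler.units: tokens`). The values below are illustrative, and the assumption, not spelled out in this diff, is that with token units the `t_warmup`/`t_max` thresholds are interpreted as token counts rather than optimizer steps:

from olmo.config import SchedulerConfig, SchedulerType, SchedulerUnits

# Illustrative values: warm up over the first 10B tokens, then follow the cosine
# decay until 2T tokens have been seen. With units=tokens, the trainer drives the
# schedule with global_train_tokens_seen / max_tokens (see olmo/train.py below).
scheduler_cfg = SchedulerConfig(
    name=SchedulerType.cosine_with_warmup,
    units=SchedulerUnits.tokens,
    t_warmup=10_000_000_000,
    t_max=2_000_000_000_000,
)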
67 changes: 58 additions & 9 deletions olmo/train.py
@@ -26,6 +26,7 @@
 from .checkpoint import Checkpointer, FullCheckpointer, build_sharded_checkpointer
 from .config import (
     CheckpointType,
+    SchedulerUnits,
     ShardedCheckpointerType,
     SpeedMonitorConfig,
     TrainConfig,
@@ -122,6 +123,14 @@ def dataset(self) -> IterableDataset:
         assert isinstance(self.train_loader.dataset, IterableDataset)
         return self.train_loader.dataset
 
+    @property
+    def tokens_per_batch(self) -> int:
+        return self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length
+
+    @property
+    def batches_per_epoch(self) -> int:
+        return self.dataset.total_size // self.cfg.global_train_batch_size
+
     @property
     def max_epochs(self) -> int:
         if isinstance(self.cfg.max_duration, str) and self.cfg.max_duration.endswith("ep"):
@@ -138,20 +147,58 @@ def max_steps(self) -> int:
                 # convert to float *first* to handle scientific notation
                 max_tokens = int(float(self.cfg.max_duration[:-1].strip()))
                 tokens_remaining = max_tokens - self.global_train_tokens_seen
-                tokens_per_batch = self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length
-                steps_remaining = tokens_remaining // tokens_per_batch
+                steps_remaining = tokens_remaining // self.tokens_per_batch
                 return self.global_step + steps_remaining
             elif self.cfg.max_duration.endswith("ep"):
                 max_epochs = int(self.cfg.max_duration[:-2].strip())
-                examples_per_epoch = self.dataset.total_size
-                steps_per_epoch = examples_per_epoch // self.cfg.global_train_batch_size
-                return max_epochs * steps_per_epoch
+                return max_epochs * self.batches_per_epoch
             else:
                 # convert to float *first* to handle scientific notation
                 return int(float(self.cfg.max_duration))
         else:
             raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")
 
+    @property
+    def max_tokens(self) -> int:
+        if isinstance(self.cfg.max_duration, int):
+            return (
+                self.global_train_tokens_seen
+                + max(self.cfg.max_duration - self.global_step, 0) * self.tokens_per_batch
+            )
+        elif isinstance(self.cfg.max_duration, str):
+            if self.cfg.max_duration.endswith("T"):
+                # convert to float *first* to handle scientific notation
+                return int(float(self.cfg.max_duration[:-1].strip()))
+            elif self.cfg.max_duration.endswith("ep"):
+                max_epochs = int(self.cfg.max_duration[:-2].strip())
+                return max_epochs * self.batches_per_epoch * self.tokens_per_batch
+            else:
+                # convert to float *first* to handle scientific notation
+                return (
+                    self.global_train_tokens_seen
+                    + max(int(float(self.cfg.max_duration)) - self.global_step, 0) * self.tokens_per_batch
+                )
+        else:
+            raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")
+
+    @property
+    def scheduler_current(self) -> int:
+        if self.cfg.scheduler.units == SchedulerUnits.steps:
+            return self.global_step
+        elif self.cfg.scheduler.units == SchedulerUnits.tokens:
+            return self.global_train_tokens_seen
+        else:
+            raise NotImplementedError(self.cfg.scheduler.units)
+
+    @property
+    def scheduler_max(self) -> int:
+        if self.cfg.scheduler.units == SchedulerUnits.steps:
+            return self.max_steps
+        elif self.cfg.scheduler.units == SchedulerUnits.tokens:
+            return self.max_tokens
+        else:
+            raise NotImplementedError(self.cfg.scheduler.units)
+
     def trainer_state_dict(self) -> Dict[str, Any]:
         return {
             "epoch": self.epoch,
@@ -233,7 +280,7 @@ def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None:
         # Reset learning rate and weight decay to the values from the config, not the checkpoint.
         log.info("Resetting learning rate...")
         new_learning_rate = self.scheduler.get_lr(
-            self.cfg.optimizer.learning_rate, self.global_step, self.max_steps
+            self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
         )
         for group in self.optim.param_groups:
             group["lr"] = new_learning_rate
@@ -572,12 +619,14 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
             # TODO (epwalsh): if we want to enable different LRs or gradient clipping settings per group
             # we should pass `group["initial_lr"]` or `group["initial_max_grad_norm"]` here instead of
             # the corresponding values from `self.cfg`.
-            group["lr"] = self.scheduler.get_lr(self.cfg.optimizer.learning_rate, self.global_step, self.max_steps)
+            group["lr"] = self.scheduler.get_lr(
+                self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
+            )
             group["max_grad_norm"] = self.scheduler.get_max_grad_norm(
-                self.cfg.max_grad_norm, self.global_step, self.max_steps
+                self.cfg.max_grad_norm, self.scheduler_current, self.scheduler_max
             )
             group["max_grad_norm_ratio"] = self.scheduler.get_max_grad_norm(
-                self.cfg.max_grad_norm_ratio, self.global_step, self.max_steps
+                self.cfg.max_grad_norm_ratio, self.scheduler_current, self.scheduler_max
             )
 
         # Optimizer step.
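To make the units switch concrete, here is a self-contained toy that mimics the `get_lr(initial_lr, current, max)` call pattern used above. The `ToyCosineWithWarmup` class and all numbers are illustrative stand-ins, not OLMo's actual scheduler; the point is that scaling `current`, `t_warmup`, and the max value by the same tokens-per-batch factor leaves the curve unchanged, so a token-based schedule describes the same shape regardless of how many steps it takes to reach it:

import math
from dataclasses import dataclass

@dataclass
class ToyCosineWithWarmup:
    # Stand-in for OLMo's scheduler; t_warmup and alpha_f mirror SchedulerConfig
    # fields and are interpreted in whatever units the trainer passes to get_lr.
    t_warmup: int
    alpha_f: float = 0.1

    def get_lr(self, initial_lr: float, current: int, t_max: int) -> float:
        # Linear warmup to initial_lr, then cosine decay to alpha_f * initial_lr at t_max.
        if current < self.t_warmup:
            return initial_lr * current / max(1, self.t_warmup)
        progress = min(1.0, (current - self.t_warmup) / max(1, t_max - self.t_warmup))
        min_lr = self.alpha_f * initial_lr
        return min_lr + (initial_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

TOKENS_PER_BATCH = 2048 * 2048  # hypothetical global_train_batch_size * max_sequence_length

# units = "steps": the trainer passes (global_step, max_steps).
by_steps = ToyCosineWithWarmup(t_warmup=5_000)
print(by_steps.get_lr(4e-4, current=10_000, t_max=500_000))

# units = "tokens": the trainer now passes (global_train_tokens_seen, max_tokens);
# t_warmup must be given in tokens for the warmup to cover the same data.
by_tokens = ToyCosineWithWarmup(t_warmup=5_000 * TOKENS_PER_BATCH)
print(by_tokens.get_lr(4e-4, current=10_000 * TOKENS_PER_BATCH, t_max=500_000 * TOKENS_PER_BATCH))

Both calls print the same learning rate; presumably the practical benefit of token units is that the schedule stays pinned to the token budget even if batch size or sequence length differ between configs.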
