Add ability to restart on new epoch #383

Merged
3 commits merged on Nov 27, 2023
Changes from 2 commits
5 changes: 5 additions & 0 deletions olmo/config.py
@@ -659,6 +659,11 @@ class TrainConfig(BaseConfig):
Used to seed all initial RNG states.
"""

epoch: int = 0
"""
Increment this when starting a new epoch.
"""

dry_run: bool = False
"""
If ``True``, don't actually train.
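A minimal sketch (not the actual Trainer code) of what bumping this field accomplishes, based on the restart logic added to olmo/train.py below: when the epoch in the config differs from the epoch recorded in a checkpoint, the data loader position resets to the start of the re-shuffled dataset. The checkpoint layout here is simplified for illustration.

def resolve_start_index(checkpoint_state: dict, cfg_epoch: int) -> int:
    # Epoch recorded when the checkpoint was written (0 for older checkpoints).
    checkpoint_epoch = checkpoint_state.get("epoch", 0)
    examples_seen = checkpoint_state.get("global_train_examples_seen_this_epoch", 0)
    if checkpoint_epoch != cfg_epoch:
        examples_seen = 0  # new epoch: restart the data loader from the beginning
    return examples_seen

# Resuming mid-epoch (config epoch matches the checkpoint): pick up where training left off.
print(resolve_start_index({"epoch": 0, "global_train_examples_seen_this_epoch": 4096}, cfg_epoch=0))  # 4096
# Restarting on a new epoch (config epoch bumped to 1): start from index 0 of the re-shuffled data.
print(resolve_start_index({"epoch": 0, "global_train_examples_seen_this_epoch": 4096}, cfg_epoch=1))  # 0
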
2 changes: 1 addition & 1 deletion olmo/data/__init__.py
@@ -93,7 +93,7 @@ def build_train_dataloader(train_config: TrainConfig) -> DataLoader:
IterableDataset(
dataset, # type: ignore
train_config.global_train_batch_size,
seed=train_config.seed,
seed=train_config.seed + train_config.epoch,
shuffle=True,
drop_last=train_config.data.drop_last,
max_examples=train_config.global_train_batch_size * train_config.max_duration,
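Adding the config's epoch to the base seed is what gives each epoch a fresh shuffle order while keeping the order within an epoch reproducible across restarts. A standalone illustration of the effect (not the OLMo IterableDataset internals; the seed value and index count are made up):

import numpy as np

base_seed = 6198          # stands in for train_config.seed
indices = np.arange(10)   # stands in for dataset instance indices

for epoch in (0, 1, 2):
    # Mirrors seed=train_config.seed + train_config.epoch above.
    rng = np.random.default_rng(base_seed + epoch)
    print(f"epoch {epoch}:", rng.permutation(indices))

Each epoch prints a different permutation of the same indices, and re-running with the same epoch value reproduces the same order.
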
64 changes: 33 additions & 31 deletions olmo/train.py
@@ -101,12 +101,11 @@ class Trainer:
train_loader: DataLoader
device: torch.device
evaluators: List[Evaluator]
epoch: int = 0
global_step: int = 0
global_data_step: int = 0
"""This is now redundant since adding 'global_train_examples_seen'."""
global_train_examples_seen: int = 0
"""Tracks the global number of training examples seen for the purpose of restoring the dataset
position on restarts."""
global_train_examples_seen_this_epoch: int = 0
"""Tracks the global number of training examples seen in the current epoch for the purpose of restoring
the data loader position on restarts."""
global_train_tokens_seen: int = 0
"""Tracks the global total number of tokens trained on."""
checkpoints: List[Path] = field(default_factory=list)
@@ -118,9 +117,9 @@ class Trainer:

def trainer_state_dict(self) -> Dict[str, Any]:
return {
"epoch": self.epoch,
"global_step": self.global_step,
"global_data_step": self.global_data_step,
"global_train_examples_seen": self.global_train_examples_seen,
"global_train_examples_seen_this_epoch": self.global_train_examples_seen_this_epoch,
"global_train_tokens_seen": self.global_train_tokens_seen,
"world_size": get_world_size(),
"checkpoints": self.checkpoints,
@@ -147,40 +146,44 @@ def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None:
]

# Dataset / dataloader position.
checkpoint_epoch = state_dict.get("epoch", 0)
self.global_step = state_dict["global_step"]
self.global_data_step = state_dict["global_data_step"]
self.global_train_examples_seen = state_dict.get( # newer addition
"global_train_examples_seen", self.global_data_step * self.cfg.global_train_batch_size
self.global_train_examples_seen_this_epoch = state_dict.get(
"global_train_examples_seen_this_epoch",
state_dict.get( # for backwards compatibility
"global_train_examples_seen",
state_dict.get("global_data_step", 0) * self.cfg.global_train_batch_size,
),
)
self.global_train_tokens_seen = state_dict.get( # newer addition
self.global_train_tokens_seen = state_dict.get(
"global_train_tokens_seen",
self.global_data_step * self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length,
state_dict.get("global_data_step", 0) # for backwards compatibility

Collaborator (review comment on the line above):
This will result in throughput/total_tokens being reset to 0 if global_train_tokens_seen and global_data_step are both not present. Maybe state_dict.get("global_data_step", self.global_step) is safer?

Member (author) reply:
Absolutely, good catch: b8ca94d

* self.cfg.global_train_batch_size
* self.cfg.model.max_sequence_length,
)

if not self.cfg.restore_dataloader:
self.global_data_step = 0
self.global_train_examples_seen = 0
self.epoch = 0
self.global_train_tokens_seen = 0
elif self.cfg.fast_forward_batches:
self.global_data_step += self.cfg.fast_forward_batches
self.global_train_examples_seen_this_epoch = 0
elif checkpoint_epoch != self.epoch:
log.info(f"Starting new epoch (epoch = {self.epoch})")
self.global_train_examples_seen_this_epoch = 0

if self.cfg.fast_forward_batches:
log.info(f"Fast-forwarding data loader by {self.cfg.fast_forward_batches:,d} steps")
# Technically we don't "see" these batches that we fast-forward through, but we use
# this variable to update the position of the dataset so we need to include them here.
self.global_train_examples_seen += self.cfg.fast_forward_batches * self.cfg.global_train_batch_size
self.global_train_examples_seen_this_epoch += (
self.cfg.fast_forward_batches * self.cfg.global_train_batch_size
)
# NOTE: on the other hand we don't add anything to 'self.global_train_tokens_seen' here because
# that variable is meant to track the actual number of tokens trained on.

if self.global_data_step > 0:
if self.global_data_step > self.global_step:
log.info(
f"Fast-forwarding data loader to step {self.global_step:,d}+{self.global_data_step-self.global_step:,d} "
f"({self.global_train_examples_seen:,d} examples)"
)
else:
log.info(
f"Fast-forwarding data loader to step {self.global_data_step:,d} "
f"({self.global_train_examples_seen:,d} examples)"
)
if self.global_train_examples_seen_this_epoch > 0:
assert isinstance(self.train_loader.dataset, IterableDataset)
self.train_loader.dataset.start_index = self.global_train_examples_seen
log.info(f"Data loader will start at instance index {self.global_train_examples_seen_this_epoch:,d}")
self.train_loader.dataset.start_index = self.global_train_examples_seen_this_epoch

# Reset learning rate and weight decay to the values from the config, not the checkpoint.
log.info("Resetting learning rate...")
@@ -789,8 +792,7 @@ def on_trace_ready(p):
assert batch_size == self.cfg.device_train_batch_size
global_batch_size = batch_size * get_world_size() # assumes batch size equal across ranks
self.global_step += 1
self.global_data_step += 1
self.global_train_examples_seen += global_batch_size
self.global_train_examples_seen_this_epoch += global_batch_size
self.global_train_tokens_seen += global_batch_size * seq_len
speed_monitor.batch_start(
self.global_train_tokens_seen,
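The nested state_dict.get() fallbacks in load_trainer_state_dict above are what keep checkpoints written before this change loadable. A minimal standalone sketch of the token-count restore (not the actual Trainer method; the batch size and sequence length defaults are made up), including the fallback to global_step suggested in the review comment and applied in b8ca94d:

def restore_tokens_seen(
    state_dict: dict,
    global_step: int,
    global_train_batch_size: int = 2048,
    max_sequence_length: int = 2048,
) -> int:
    # Newer checkpoints store the token count directly; older ones only have
    # 'global_data_step', so the count is reconstructed from it. Falling back to
    # `global_step` (rather than 0) when 'global_data_step' is also missing keeps
    # throughput/total_tokens from resetting to zero on resume.
    return state_dict.get(
        "global_train_tokens_seen",
        state_dict.get("global_data_step", global_step)
        * global_train_batch_size
        * max_sequence_length,
    )

print(restore_tokens_seen({"global_train_tokens_seen": 1_000_000}, global_step=500))  # 1000000
print(restore_tokens_seen({"global_data_step": 400}, global_step=500))                # 400 * 2048 * 2048
print(restore_tokens_seen({}, global_step=500))                                       # 500 * 2048 * 2048
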
1 change: 1 addition & 0 deletions scripts/train.py
@@ -165,6 +165,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
# Consolidate components into `Trainer` object.
with Trainer(
cfg=cfg,
epoch=cfg.epoch,
model=olmo_model,
fsdp_model=fsdp_model,
optim=optim,