diff --git a/pretrain/tinyllama.py b/pretrain/tinyllama.py
index 481f2fa367..81961eb433 100644
--- a/pretrain/tinyllama.py
+++ b/pretrain/tinyllama.py
@@ -107,11 +107,11 @@ def main(fabric, resume):
     optimizer = fabric.setup_optimizers(optimizer)
 
     state = {
-        "model": model,
-        "optimizer": optimizer,
+        "model": model,
+        "optimizer": optimizer,
         "train_dataloader": train_dataloader,
-        "hparams": hparams,
-        "iter_num": 0,
+        "hparams": hparams,
+        "iter_num": 0,
         "step_count": 0,
     }
 
@@ -166,8 +166,8 @@ def train(fabric, state, train_dataloader, val_dataloader, resume):
         state["iter_num"] += 1
         iter_t0 = time.perf_counter()
 
-        input_ids = train_data[:, 0:model.config.block_size].contiguous().long()
-        targets = train_data[:, 1:(model.config.block_size + 1)].contiguous().long()
+        input_ids = train_data[:, 0 : model.config.block_size].contiguous().long()
+        targets = train_data[:, 1 : (model.config.block_size + 1)].contiguous().long()
 
         is_accumulating = state["iter_num"] % gradient_accumulation_iters != 0
         with fabric.no_backward_sync(model, enabled=is_accumulating):
@@ -254,7 +254,7 @@ def validate(fabric: L.Fabric, model: nn.Module, val_dataloader: DataLoader, max
 
 
 def create_dataloaders(batch_size: int, block_size: int, num_workers: int = 8) -> Tuple[DataLoader, DataLoader]:
-    from lightning.data import StreamingDataset, CombinedStreamingDataset, StreamingDataLoader
+    from lightning.data import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset
     from lightning.data.streaming.item_loader import TokensLoader
 
     # Increase by one because we need the next word as well
@@ -289,7 +289,9 @@ def create_dataloaders(batch_size: int, block_size: int, num_workers: int = 8) -
         # Consider setting to False, but we would lose some samples due to truncation when world size > 1
         drop_last=True,
     )
-    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, pin_memory=True, num_workers=num_workers, drop_last=True)
+    val_dataloader = DataLoader(
+        val_dataset, batch_size=batch_size, pin_memory=True, num_workers=num_workers, drop_last=True
+    )
     return train_dataloader, val_dataloader