Skip to content

Commit

Permalink
Automated formatting
Browse files (browse the repository at this point in the history)
  • Loading branch information
carmocca committed Jan 31, 2024
1 parent 3a0a7c2 commit cca2986
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions pretrain/tinyllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ def main(fabric, resume):
optimizer = fabric.setup_optimizers(optimizer)

state = {
- "model": model,
- "optimizer": optimizer,
+ "model": model,
+ "optimizer": optimizer,
"train_dataloader": train_dataloader,
- "hparams": hparams,
- "iter_num": 0,
+ "hparams": hparams,
+ "iter_num": 0,
"step_count": 0,
}

Expand Down Expand Up @@ -166,8 +166,8 @@ def train(fabric, state, train_dataloader, val_dataloader, resume):
state["iter_num"] += 1
iter_t0 = time.perf_counter()

- input_ids = train_data[:, 0:model.config.block_size].contiguous().long()
- targets = train_data[:, 1:(model.config.block_size + 1)].contiguous().long()
+ input_ids = train_data[:, 0 : model.config.block_size].contiguous().long()
+ targets = train_data[:, 1 : (model.config.block_size + 1)].contiguous().long()

is_accumulating = state["iter_num"] % gradient_accumulation_iters != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
Expand Down Expand Up @@ -254,7 +254,7 @@ def validate(fabric: L.Fabric, model: nn.Module, val_dataloader: DataLoader, max


def create_dataloaders(batch_size: int, block_size: int, num_workers: int = 8) -> Tuple[DataLoader, DataLoader]:
- from lightning.data import StreamingDataset, CombinedStreamingDataset, StreamingDataLoader
+ from lightning.data import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset
from lightning.data.streaming.item_loader import TokensLoader

# Increase by one because we need the next word as well
Expand Down Expand Up @@ -289,7 +289,9 @@ def create_dataloaders(batch_size: int, block_size: int, num_workers: int = 8) -
# Consider setting to False, but we would lose some samples due to truncation when world size > 1
drop_last=True,
)
- val_dataloader = DataLoader(val_dataset, batch_size=batch_size, pin_memory=True, num_workers=num_workers, drop_last=True)
+ val_dataloader = DataLoader(
+ val_dataset, batch_size=batch_size, pin_memory=True, num_workers=num_workers, drop_last=True
+ )
return train_dataloader, val_dataloader


Expand Down

0 comments on commit cca2986

Please sign in to comment.