Skip to content

Commit

Permalink
Minor usability improvements to tinyllama pretraining script (#749)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <[email protected]>
  • Loading branch information
awaelchli and carmocca authored Nov 21, 2023
1 parent a2400f9 commit 21c1c59
Showing 1 changed file with 22 additions and 19 deletions.
41 changes: 22 additions & 19 deletions pretrain/tinyllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from lit_gpt.model import GPT, Block, Config, LLaMAMLP
from lit_gpt.model import GPT, Block, Config, LLaMAMLP, CausalSelfAttention
from lit_gpt.packed_dataset import CombinedDataset
from lit_gpt.utils import chunked_cross_entropy, num_parameters

Expand All @@ -38,9 +38,9 @@
global_batch_size = 512
learning_rate = 4e-4
micro_batch_size = 8
max_steps = 715256 * 2
max_tokens = int(3e12) # 3 trillion
warmup_steps = 2000
log_step_interval = 2
log_step_interval = 1
eval_iters = 100
save_step_interval = 1000
eval_step_interval = 1000
Expand All @@ -56,9 +56,6 @@
gradient_accumulation_steps = batch_size // micro_batch_size
assert gradient_accumulation_steps > 0
warmup_iters = warmup_steps * gradient_accumulation_steps

max_iters = max_steps * gradient_accumulation_steps
lr_decay_iters = max_iters
log_iter_interval = log_step_interval * gradient_accumulation_steps


Expand All @@ -67,9 +64,9 @@

def setup(resume: Union[bool, Path] = False):
if use_wandb:
logger = WandbLogger(project="tinyllama", name="training", resume=(resume is not False))
logger = WandbLogger(project="tinyllama", name=name, resume=(resume is not False))
else:
logger = CSVLogger(root_dir="logs", name="tinyllama")
logger = CSVLogger(root_dir="logs", name=name)

if devices > 1:
strategy = FSDPStrategy(
Expand All @@ -78,6 +75,7 @@ def setup(resume: Union[bool, Path] = False):
state_dict_type="full",
limit_all_gathers=True,
cpu_offload=False,
sharding_strategy="HYBRID_SHARD",
)
else:
strategy = "auto"
Expand Down Expand Up @@ -107,7 +105,7 @@ def main(fabric, resume):
t0 = time.perf_counter()
with fabric.init_module(empty_init=False):
model = GPT(config)
model.apply(partial(init_weights, n_layer=config.n_layer))
model.apply(partial(init_weights, n_layer=config.n_layer, n_embd=config.n_embd))

fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
fabric.print(f"Total parameters: {num_parameters(model):,}")
Expand Down Expand Up @@ -150,11 +148,19 @@ def train(fabric, state, train_dataloader, val_dataloader, resume):
fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
del meta_model, x

max_tokens_per_device = max_tokens // fabric.world_size
tokens_per_iter = micro_batch_size * model.config.block_size
max_iters = max_tokens_per_device // tokens_per_iter

total_t0 = time.perf_counter()
initial_iter = state["iter_num"]
curr_iter = 0

for train_data in train_dataloader:

if state["iter_num"] >= max_iters:
break

# resume data loader state by fast-forwarding through all seen batches
# drop this once streaming dataset supports proper resuming
if resume:
Expand All @@ -169,11 +175,8 @@ def train(fabric, state, train_dataloader, val_dataloader, resume):
f"Took {time.perf_counter() - total_t0:.1f} seconds to reach iteration {initial_iter}."
)

if state["iter_num"] >= max_iters:
break

# determine and set the learning rate for this iteration
lr = get_lr(state["iter_num"]) if decay_lr else learning_rate
lr = get_lr(state["iter_num"], max_iters) if decay_lr else learning_rate
for param_group in optimizer.param_groups:
param_group["lr"] = lr

Expand Down Expand Up @@ -302,7 +305,7 @@ def create_dataloaders(batch_size: int, block_size: int) -> Tuple[DataLoader, Da


# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
def get_lr(it: int, lr_decay_iters: int) -> int:
# 1) linear warmup for warmup_iters steps
if it < warmup_iters:
return learning_rate * it / warmup_iters
Expand All @@ -316,17 +319,17 @@ def get_lr(it):
return min_lr + coeff * (learning_rate - min_lr)


def init_weights(module: nn.Module, n_layer: int):
def init_weights(module: nn.Module, n_layer: int, n_embd: int):
# Follows GPT-NeoX: https://arxiv.org/abs/2204.06745
if isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / module.weight.shape[1]))
nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / n_embd))
elif isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / module.weight.shape[1]))
nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / n_embd))
if module.bias is not None:
nn.init.zeros_(module.bias)
for name, param in module.named_parameters():
if name == "proj.weight" and isinstance(module, LLaMAMLP):
nn.init.normal_(param, mean=0.0, std=(1 / math.sqrt(param.shape[-1]) / n_layer))
if name == "proj.weight" and isinstance(module, (LLaMAMLP, CausalSelfAttention)):
nn.init.normal_(param, mean=0.0, std=(1 / math.sqrt(n_embd) / n_layer))


if __name__ == "__main__":
Expand Down

0 comments on commit 21c1c59

Please sign in to comment.