Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor usability improvements to tinyllama pretraining script #749

Merged
merged 8 commits into from
Nov 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 22 additions & 19 deletions pretrain/tinyllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from lit_gpt.model import GPT, Block, Config, LLaMAMLP
from lit_gpt.model import GPT, Block, Config, LLaMAMLP, CausalSelfAttention
from lit_gpt.packed_dataset import CombinedDataset
from lit_gpt.utils import chunked_cross_entropy, num_parameters

Expand All @@ -38,9 +38,9 @@
global_batch_size = 512
learning_rate = 4e-4
micro_batch_size = 8
max_steps = 715256 * 2
max_tokens = int(3e12) # 3 trillion
warmup_steps = 2000
log_step_interval = 2
log_step_interval = 1
eval_iters = 100
save_step_interval = 1000
eval_step_interval = 1000
Expand All @@ -56,9 +56,6 @@
gradient_accumulation_steps = batch_size // micro_batch_size
assert gradient_accumulation_steps > 0
warmup_iters = warmup_steps * gradient_accumulation_steps

max_iters = max_steps * gradient_accumulation_steps
lr_decay_iters = max_iters
log_iter_interval = log_step_interval * gradient_accumulation_steps


Expand All @@ -67,9 +64,9 @@

def setup(resume: Union[bool, Path] = False):
if use_wandb:
logger = WandbLogger(project="tinyllama", name="training", resume=(resume is not False))
logger = WandbLogger(project="tinyllama", name=name, resume=(resume is not False))
else:
logger = CSVLogger(root_dir="logs", name="tinyllama")
logger = CSVLogger(root_dir="logs", name=name)

if devices > 1:
strategy = FSDPStrategy(
Expand All @@ -78,6 +75,7 @@ def setup(resume: Union[bool, Path] = False):
state_dict_type="full",
limit_all_gathers=True,
cpu_offload=False,
sharding_strategy="HYBRID_SHARD",
)
else:
strategy = "auto"
Expand Down Expand Up @@ -107,7 +105,7 @@ def main(fabric, resume):
t0 = time.perf_counter()
with fabric.init_module(empty_init=False):
model = GPT(config)
model.apply(partial(init_weights, n_layer=config.n_layer))
model.apply(partial(init_weights, n_layer=config.n_layer, n_embd=config.n_embd))

fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
fabric.print(f"Total parameters: {num_parameters(model):,}")
Expand Down Expand Up @@ -150,11 +148,19 @@ def train(fabric, state, train_dataloader, val_dataloader, resume):
fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
del meta_model, x

max_tokens_per_device = max_tokens // fabric.world_size
tokens_per_iter = micro_batch_size * model.config.block_size
carmocca marked this conversation as resolved.
Show resolved Hide resolved
max_iters = max_tokens_per_device // tokens_per_iter

total_t0 = time.perf_counter()
initial_iter = state["iter_num"]
curr_iter = 0

for train_data in train_dataloader:

if state["iter_num"] >= max_iters:
break

# resume data loader state by fast-forwarding through all seen batches
# drop this once streaming dataset supports proper resuming
if resume:
Expand All @@ -170,11 +176,8 @@ def train(fabric, state, train_dataloader, val_dataloader, resume):
f"Took {time.perf_counter() - total_t0:.1f} seconds to reach iteration {initial_iter}."
)

if state["iter_num"] >= max_iters:
break

# determine and set the learning rate for this iteration
lr = get_lr(state["iter_num"]) if decay_lr else learning_rate
lr = get_lr(state["iter_num"], max_iters) if decay_lr else learning_rate
for param_group in optimizer.param_groups:
param_group["lr"] = lr

Expand Down Expand Up @@ -303,7 +306,7 @@ def create_dataloaders(batch_size: int, block_size: int) -> Tuple[DataLoader, Da


# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
def get_lr(it: int, lr_decay_iters: int) -> int:
# 1) linear warmup for warmup_iters steps
if it < warmup_iters:
return learning_rate * it / warmup_iters
Expand All @@ -317,17 +320,17 @@ def get_lr(it):
return min_lr + coeff * (learning_rate - min_lr)


def init_weights(module: nn.Module, n_layer: int):
def init_weights(module: nn.Module, n_layer: int, n_embd: int):
# Follows GPT-NeoX: https://arxiv.org/abs/2204.06745
if isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / module.weight.shape[1]))
nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / n_embd))
elif isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / module.weight.shape[1]))
nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / n_embd))
if module.bias is not None:
nn.init.zeros_(module.bias)
for name, param in module.named_parameters():
if name == "proj.weight" and isinstance(module, LLaMAMLP):
nn.init.normal_(param, mean=0.0, std=(1 / math.sqrt(param.shape[-1]) / n_layer))
if name == "proj.weight" and isinstance(module, (LLaMAMLP, CausalSelfAttention)):
nn.init.normal_(param, mean=0.0, std=(1 / math.sqrt(n_embd) / n_layer))


if __name__ == "__main__":
Expand Down