Added Resume Functionality for Full Finetuning Script (full.py) (#788)
Co-authored-by: awaelchli <[email protected]>
windprak and awaelchli authored Jan 19, 2024
1 parent 0f021f3 commit 4bd370c
Showing 1 changed file with 40 additions and 31 deletions.
71 changes: 40 additions & 31 deletions finetune/full.py
@@ -4,7 +4,7 @@
 import sys
 import time
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

 import lightning as L
 import torch
@@ -56,6 +56,7 @@ def setup(
     checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
     out_dir: Path = Path("out/full/alpaca"),
     precision: Optional[str] = None,
+    resume: Union[bool, Path] = False,
 ) -> None:
     precision = precision or get_default_supported_precision(training=True)

@@ -74,12 +75,11 @@ def setup(
     logger = CSVLogger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval)
     fabric = L.Fabric(devices=fabric_devices, strategy=strategy, precision=precision, loggers=logger)
     fabric.print(hparams)
-    fabric.launch(main, data_dir, checkpoint_dir, out_dir)
+    fabric.launch(main, data_dir, checkpoint_dir, out_dir, resume=resume)


-def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path) -> None:
+def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path, resume: Union[bool, Path]) -> None:
     check_valid_checkpoint_dir(checkpoint_dir)
-
     fabric.seed_everything(1337) # same seed for every process to init model (FSDP)

     if fabric.global_rank == 0:
@@ -99,31 +99,41 @@ def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path)
     model = fabric.setup_module(model)
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
     optimizer = fabric.setup_optimizers(optimizer)

-    load_checkpoint(fabric, model, checkpoint_path)
+    state = {
+        "model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0, "total_lengths": 0
+    }
+
+    if resume is True:
+        resume = max(out_dir.glob("*.pth"), key=(lambda p: int(p.name.split("-")[1])))
+    if resume:
+        fabric.print(f"Resuming training from {resume}")
+        fabric.load(resume, state)
+    else:
+        load_checkpoint(fabric, state["model"], checkpoint_path)

     fabric.seed_everything(1337 + fabric.global_rank)

     train_time = time.perf_counter()
-    train(fabric, model, optimizer, train_data, val_data, checkpoint_dir, out_dir)
-    fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
+    train(fabric, state, train_data, val_data, checkpoint_dir, out_dir)
+    fabric.print(f"Training time: {(time.perf_counter() - train_time):.2f}s")
     if fabric.device.type == "cuda":
         fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

     # Save the final checkpoint at the end of training
     save_path = out_dir / "lit_model_finetuned.pth"
-    save_checkpoint(fabric, model, save_path)
+    save_checkpoint(fabric, {"model": state["model"]}, save_path)


 def train(
     fabric: L.Fabric,
-    model: GPT,
-    optimizer: torch.optim.Optimizer,
+    state: Dict,
     train_data: List[Dict],
     val_data: List[Dict],
     checkpoint_dir: Path,
     out_dir: Path,
 ) -> None:
+    model = state["model"]
+    optimizer = state["optimizer"]
     tokenizer = Tokenizer(checkpoint_dir)
     longest_seq_length, longest_seq_ix = get_longest_seq_length(train_data)
     model.max_seq_length = min(longest_seq_length, max_seq_length or float("inf"))
@@ -135,22 +145,20 @@ def train(
     validate(fabric, model, val_data, tokenizer, max_iters=2) # sanity check

     throughput = ThroughputMonitor(fabric, window_size=50)
-    step_count = 0
-    total_lengths = 0
     total_t0 = time.perf_counter()

-    for iter_num in range(1, max_iters + 1):
-        if step_count <= warmup_steps:
+    for state["iter_num"] in range(state["iter_num"] + 1, max_iters + 1):
+        if state["step_count"] <= warmup_steps:
             # linear warmup
-            lr = learning_rate * step_count / warmup_steps
+            lr = learning_rate * state["step_count"] / warmup_steps
             for param_group in optimizer.param_groups:
                 param_group["lr"] = lr

         iter_t0 = time.perf_counter()

-        input_ids, targets = get_batch(fabric, train_data, longest_seq_ix if iter_num == 1 else None)
+        input_ids, targets = get_batch(fabric, train_data, longest_seq_ix if state["iter_num"] == 1 else None)

-        is_accumulating = iter_num % gradient_accumulation_iters != 0
+        is_accumulating = state["iter_num"] % gradient_accumulation_iters != 0
         with fabric.no_backward_sync(model, enabled=is_accumulating):
             logits = model(input_ids)
             # shift the targets such that output n predicts token n+1
@@ -160,30 +168,31 @@ def train(
         if not is_accumulating:
             optimizer.step()
             optimizer.zero_grad()
-            step_count += 1
+            state["step_count"] += 1

-        total_lengths += input_ids.numel()
-        if iter_num % log_interval == 0:
+        state['total_lengths'] += input_ids.numel()
+        if state["iter_num"] % log_interval == 0:
             loss_item = loss.item() # expensive device-to-host synchronization
             t1 = time.perf_counter()
             throughput.update(
-                time=t1 - total_t0, batches=iter_num, samples=iter_num * micro_batch_size, lengths=total_lengths
+                time=t1 - total_t0, batches=state["iter_num"], samples=state["iter_num"] * micro_batch_size,
+                lengths=state['total_lengths']
             )
-            throughput.compute_and_log(step=iter_num)
+            throughput.compute_and_log(step=state["iter_num"])
             fabric.print(
-                f"iter {iter_num} step {step_count}: loss {loss_item:.4f}, iter time:"
+                f"iter {state['iter_num']} step {state['step_count']}: loss {loss_item:.4f}, iter time:"
                 f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}"
             )

-        if not is_accumulating and step_count % eval_interval == 0:
+        if not is_accumulating and state['step_count'] % eval_interval == 0:
             t0 = time.perf_counter()
             val_loss = validate(fabric, model, val_data, tokenizer, max_iters=eval_iters)
             t1 = time.perf_counter() - t0
-            fabric.print(f"step {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f}ms")
+            fabric.print(f"step {state['iter_num']}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f}ms")
             fabric.barrier()
-        if not is_accumulating and step_count % save_interval == 0:
-            checkpoint_path = out_dir / f"iter-{iter_num:06d}-ckpt.pth"
-            save_checkpoint(fabric, model, checkpoint_path)
+        if not is_accumulating and state['step_count'] % save_interval == 0:
+            checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth"
+            save_checkpoint(fabric, state, checkpoint_path)


 # FSDP has issues with `inference_mode`
@@ -260,9 +269,9 @@ def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
     return longest_seq_length, longest_seq_ix


-def save_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:
+def save_checkpoint(fabric, state, file_path: Path):
     fabric.print(f"Saving weights to {str(file_path)!r}")
-    fabric.save(file_path, {"model": model})
+    fabric.save(file_path, state)


 if __name__ == "__main__":
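
With this change, a run can be resumed either from the most recent checkpoint in out_dir or from an explicit checkpoint path. Assuming the script is launched through the command-line entry point in the `if __name__ == "__main__":` block (not shown in this diff), and that the CLI exposes the `setup()` parameters as flags, an invocation might look roughly like the following; the exact flag syntax and the checkpoint filename are illustrative, not taken from the diff:

    # start fine-tuning as before
    python finetune/full.py --checkpoint_dir checkpoints/stabilityai/stablelm-base-alpha-3b --out_dir out/full/alpaca

    # resume from the newest iter-*-ckpt.pth found in out_dir
    python finetune/full.py --out_dir out/full/alpaca --resume true

    # resume from a specific checkpoint file
    python finetune/full.py --out_dir out/full/alpaca --resume out/full/alpaca/iter-000600-ckpt.pth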

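When resume is passed as true rather than as a path, the new code in main() picks the newest checkpoint by parsing the iteration number out of the iter-{iter_num:06d}-ckpt.pth filenames that the training loop writes. A minimal, self-contained sketch of that selection expression, using hypothetical filenames purely for illustration:

    from pathlib import Path

    out_dir = Path("out/full/alpaca")  # hypothetical output directory
    out_dir.mkdir(parents=True, exist_ok=True)

    # suppose the training loop has already written these checkpoints
    for name in ("iter-000200-ckpt.pth", "iter-000400-ckpt.pth", "iter-000600-ckpt.pth"):
        (out_dir / name).touch()

    # same expression as in main(): split the filename on "-" and use the
    # zero-padded iteration number to decide which checkpoint is newest
    latest = max(out_dir.glob("*.pth"), key=lambda p: int(p.name.split("-")[1]))
    print(latest.name)  # iter-000600-ckpt.pth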