From 4bd370c7cb67caccf748bd4cc5a666b46c892442 Mon Sep 17 00:00:00 2001
From: Sebastian
Date: Fri, 19 Jan 2024 17:26:59 +0100
Subject: [PATCH] Added Resume Functionality for Full Finetuning Script
 (full.py) (#788)

Co-authored-by: awaelchli
---
 finetune/full.py | 71 +++++++++++++++++++++++++++---------------------
 1 file changed, 40 insertions(+), 31 deletions(-)

diff --git a/finetune/full.py b/finetune/full.py
index 28d2e765bd..c0ed41528f 100644
--- a/finetune/full.py
+++ b/finetune/full.py
@@ -4,7 +4,7 @@
 import sys
 import time
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import lightning as L
 import torch
@@ -56,6 +56,7 @@ def setup(
     checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
     out_dir: Path = Path("out/full/alpaca"),
     precision: Optional[str] = None,
+    resume: Union[bool, Path] = False,
 ) -> None:
     precision = precision or get_default_supported_precision(training=True)
 
@@ -74,12 +75,11 @@ def setup(
     logger = CSVLogger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval)
     fabric = L.Fabric(devices=fabric_devices, strategy=strategy, precision=precision, loggers=logger)
     fabric.print(hparams)
-    fabric.launch(main, data_dir, checkpoint_dir, out_dir)
+    fabric.launch(main, data_dir, checkpoint_dir, out_dir, resume=resume)
 
 
-def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path) -> None:
+def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path, resume: Union[bool, Path]) -> None:
     check_valid_checkpoint_dir(checkpoint_dir)
-
     fabric.seed_everything(1337)  # same seed for every process to init model (FSDP)
 
     if fabric.global_rank == 0:
@@ -99,31 +99,41 @@ def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path)
     model = fabric.setup_module(model)
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
     optimizer = fabric.setup_optimizers(optimizer)
-
-    load_checkpoint(fabric, model, checkpoint_path)
+    state = {
+        "model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0, "total_lengths": 0
+    }
+
+    if resume is True:
+        resume = max(out_dir.glob("*.pth"), key=(lambda p: int(p.name.split("-")[1])))
+    if resume:
+        fabric.print(f"Resuming training from {resume}")
+        fabric.load(resume, state)
+    else:
+        load_checkpoint(fabric, state["model"], checkpoint_path)
 
     fabric.seed_everything(1337 + fabric.global_rank)
 
     train_time = time.perf_counter()
-    train(fabric, model, optimizer, train_data, val_data, checkpoint_dir, out_dir)
-    fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
+    train(fabric, state, train_data, val_data, checkpoint_dir, out_dir)
+    fabric.print(f"Training time: {(time.perf_counter() - train_time):.2f}s")
     if fabric.device.type == "cuda":
         fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
 
     # Save the final checkpoint at the end of training
     save_path = out_dir / "lit_model_finetuned.pth"
-    save_checkpoint(fabric, model, save_path)
+    save_checkpoint(fabric, {"model": state["model"]}, save_path)
 
 
 def train(
     fabric: L.Fabric,
-    model: GPT,
-    optimizer: torch.optim.Optimizer,
+    state: Dict,
     train_data: List[Dict],
     val_data: List[Dict],
     checkpoint_dir: Path,
     out_dir: Path,
 ) -> None:
+    model = state["model"]
+    optimizer = state["optimizer"]
     tokenizer = Tokenizer(checkpoint_dir)
     longest_seq_length, longest_seq_ix = get_longest_seq_length(train_data)
     model.max_seq_length = min(longest_seq_length, max_seq_length or float("inf"))
@@ -135,22 +145,20 @@ def train(
     validate(fabric, model, val_data, tokenizer, max_iters=2)  # sanity check
 
     throughput = ThroughputMonitor(fabric, window_size=50)
-    step_count = 0
-    total_lengths = 0
     total_t0 = time.perf_counter()
 
-    for iter_num in range(1, max_iters + 1):
-        if step_count <= warmup_steps:
+    for state["iter_num"] in range(state["iter_num"] + 1, max_iters + 1):
+        if state["step_count"] <= warmup_steps:
             # linear warmup
-            lr = learning_rate * step_count / warmup_steps
+            lr = learning_rate * state["step_count"] / warmup_steps
             for param_group in optimizer.param_groups:
                 param_group["lr"] = lr
 
         iter_t0 = time.perf_counter()
 
-        input_ids, targets = get_batch(fabric, train_data, longest_seq_ix if iter_num == 1 else None)
+        input_ids, targets = get_batch(fabric, train_data, longest_seq_ix if state["iter_num"] == 1 else None)
 
-        is_accumulating = iter_num % gradient_accumulation_iters != 0
+        is_accumulating = state["iter_num"] % gradient_accumulation_iters != 0
         with fabric.no_backward_sync(model, enabled=is_accumulating):
             logits = model(input_ids)
             # shift the targets such that output n predicts token n+1
@@ -160,30 +168,31 @@ def train(
         if not is_accumulating:
             optimizer.step()
             optimizer.zero_grad()
-            step_count += 1
+            state["step_count"] += 1
 
-        total_lengths += input_ids.numel()
-        if iter_num % log_interval == 0:
+        state['total_lengths'] += input_ids.numel()
+        if state["iter_num"] % log_interval == 0:
            loss_item = loss.item()  # expensive device-to-host synchronization
            t1 = time.perf_counter()
            throughput.update(
-                time=t1 - total_t0, batches=iter_num, samples=iter_num * micro_batch_size, lengths=total_lengths
+                time=t1 - total_t0, batches=state["iter_num"], samples=state["iter_num"] * micro_batch_size,
+                lengths=state['total_lengths']
            )
-            throughput.compute_and_log(step=iter_num)
+            throughput.compute_and_log(step=state["iter_num"])
            fabric.print(
-                f"iter {iter_num} step {step_count}: loss {loss_item:.4f}, iter time:"
+                f"iter {state['iter_num']} step {state['step_count']}: loss {loss_item:.4f}, iter time:"
                f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}"
            )
 
-        if not is_accumulating and step_count % eval_interval == 0:
+        if not is_accumulating and state['step_count'] % eval_interval == 0:
            t0 = time.perf_counter()
            val_loss = validate(fabric, model, val_data, tokenizer, max_iters=eval_iters)
            t1 = time.perf_counter() - t0
-            fabric.print(f"step {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f}ms")
+            fabric.print(f"step {state['iter_num']}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f}ms")
            fabric.barrier()
-        if not is_accumulating and step_count % save_interval == 0:
-            checkpoint_path = out_dir / f"iter-{iter_num:06d}-ckpt.pth"
-            save_checkpoint(fabric, model, checkpoint_path)
+        if not is_accumulating and state['step_count'] % save_interval == 0:
+            checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth"
+            save_checkpoint(fabric, state, checkpoint_path)
 
 
 # FSDP has issues with `inference_mode`
@@ -260,9 +269,9 @@ def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
     return longest_seq_length, longest_seq_ix
 
 
-def save_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:
+def save_checkpoint(fabric, state, file_path: Path):
     fabric.print(f"Saving weights to {str(file_path)!r}")
-    fabric.save(file_path, {"model": model})
+    fabric.save(file_path, state)
 
 
 if __name__ == "__main__":
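
Usage sketch (not part of the patch, illustration only): the new resume argument added to setup() can be exercised directly from Python. The import path and the concrete checkpoint filename below are assumptions; in the repository the script is normally launched from the command line through the CLI wrapper under the "if __name__ == '__main__':" block.

    # assumption: running from inside the finetune/ directory so full.py is importable
    from pathlib import Path
    from full import setup

    # resume=True selects the checkpoint in out_dir with the highest iteration number
    # encoded in its filename; passing a Path resumes from that exact file instead
    setup(
        checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
        out_dir=Path("out/full/alpaca"),
        resume=True,  # or resume=Path("out/full/alpaca/iter-000600-ckpt.pth"), a hypothetical file
    )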