From 206f04afb9c5a468db1a9d8e0a46cd10434c4ae5 Mon Sep 17 00:00:00 2001
From: Jai
Date: Sun, 21 Apr 2024 21:26:29 +0100
Subject: [PATCH] got DDP + wandb logging working locally!

---
 ddp.py | 233 ++++++++++++++++++++++++++++++++-------------------------
 1 file changed, 130 insertions(+), 103 deletions(-)

diff --git a/ddp.py b/ddp.py
index 841fa4a..4421caa 100644
--- a/ddp.py
+++ b/ddp.py
@@ -1,6 +1,6 @@
 """Runs distributed training of NanoGPTs across multiple GPUs using PyTorch's DDP."""
 
-import argparse
+import argparse  # noqa: I001
 import os
 import time
 from itertools import product
@@ -23,17 +23,22 @@
 LR_SET = [5e-2, 1e-3, 1e-4]  # learning rate set
 OPTIM_SET = [Adam, AdamW, NAdam]  # optimizer set
 ARCH_SET = [  # model architecture set
-    {"ctx_len": 1024, "emb_dim": 768, "n_heads": 12, "head_sz": 64},
-    {"ctx_len": 2048, "emb_dim": 1024, "n_heads": 16, "head_sz": 64},
-    {"ctx_len": 2048, "emb_dim": 1024, "n_heads": 20, "head_sz": 80},
+    {"ctx_len": 256, "emb_dim": 256, "n_heads": 8, "head_sz": 32, "n_blocks": 8},
+    {"ctx_len": 2048, "emb_dim": 1024, "n_heads": 16, "head_sz": 64, "n_blocks": 12},
+    {"ctx_len": 2048, "emb_dim": 1024, "n_heads": 20, "head_sz": 80, "n_blocks": 12},
 ]
 
-def setup(rank, world_size, master_addr, master_port):
+def setup(
+    rank: int,  # rank of current process
+    world_size: int,  # number of processes
+    master_addr: str,  # master machine address (IP or hostname)
+    master_port: str,  # master machine port
+):
     """Sets up the DDP environment."""
     os.environ["MASTER_ADDR"] = master_addr
     os.environ["MASTER_PORT"] = master_port
     # Create distributed process group.
-    init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    init_process_group(backend="gloo", rank=rank, world_size=world_size)
 
 def cleanup():
     """Cleans up and kills DDP environment."""
@@ -48,31 +53,30 @@ def train(
     loss_fn: nn.modules.loss,  # loss function
     rank: int,  # rank of current process
     max_epochs: int = 5,  # max n training epochs
-    max_batches: int = 1e9,  # max n batches to train
-    val_chk_interval: int = 200,  # check val loss every `val_chk_interval` batches and print losses
+    max_batches: int = 500,  # max n batches to train
+    val_chk_interval: int = 200,  # check val loss every `val_chk_interval` batches & print losses
     val_iter: int = 5,  # number of batches on val_loader to run and avg when computing val loss
     patience_thresh: int = 1e9,  # consecutive batches without val loss decrease for early stopping
     save_chkpt_dir: str = "",  # dir to save model checkpoint
-    save_chkpt_thresh: float = 0.5,  # save model checkpoint every `save_chkpt_interval` loss decrease
+    save_chkpt_thresh: float = 0.5,  # save model chkpnt every `save_chkpt_interval` loss decrease
 ) -> tuple[torch.Tensor, np.ndarray, np.ndarray]:  # -> loss, train_losses, val_losses
     """Trains a model, returns loss."""
-    model = DDP(model, device_ids=[rank])
             if val_i >= (val_iter - 1):
                 break
         val_losses_avg.append(np.mean(val_losses[-val_iter:]))
         train_losses_avg.append(np.mean(train_losses[-val_chk_interval:]))
         model.train()
-    # /s>
 
     def apply_gradient_centralization(optimizer):
         """Applies gradient centralization to the optimizer.
@@ -88,15 +92,17 @@ def apply_gradient_centralization(optimizer):
                 )
                 # Centralize the gradient
                 param.grad.data -= grad_mean
-
+    # /s>
 
             # logits: [batch_sz, ctx_len, n_tokens], but...
-            # must reshape to compare against batch_sz vector of targets for cross-entropy loss.
+            # must reshape to compare against batch_sz vector of targets for cross-entropy loss
             loss = loss_fn(logits.view(-1, n_tokens), y_train.to(rank).view(-1))
             loss.backward()
             apply_gradient_centralization(optimizer)
@@ -119,18 +125,23 @@ def apply_gradient_centralization(optimizer):
                 estimate_losses(
                     model, val_loader, val_losses, val_losses_avg, train_losses, train_losses_avg
                 )
-                wandb.log({"train_loss": train_losses_avg[-1], "val_loss": val_losses_avg[-1]})
-                # Patience check for early stopping.
+                if rank == 0:
+                    wandb.log({"train_loss": train_losses_avg[-1], "val_loss": val_losses_avg[-1]})
+                # Return if patience threshold reached (early stopping).
                 patience_ct = (
                     0 if val_losses_avg[-1] < best_val_loss else patience_ct + val_chk_interval
                 )
                 best_val_loss = min(best_val_loss, val_losses_avg[-1])
                 if patience_ct >= patience_thresh:
-                    wandb.log({"train_loss": train_losses_avg[-1], "val_loss": val_losses_avg[-1]})
+                    if rank == 0:
+                        wandb.log(
+                            {"train_loss": train_losses_avg[-1], "val_loss": val_losses_avg[-1]}
+                        )
                     return loss, train_losses_avg, val_losses_avg
-                # Max batch check.
+                # Return if max_batches reached.
                 if (batch_i + 1) * (epoch + 1) >= max_batches:
-                    wandb.log({"train_loss": train_losses_avg[-1], "val_loss": val_losses_avg[-1]})
+                    if rank == 0:
+                        wandb.log({"train_loss": train_losses_avg[-1], "val_loss": val_losses_avg[-1]})
                     return loss, train_losses_avg, val_losses_avg
                 # Save checkpoint check.
                 if (
@@ -145,21 +156,29 @@ def apply_gradient_centralization(optimizer):
                     init_loss = loss.item()
         # /ss>
     # /s>
-    # Finished.
-    wandb.log({"train_loss": train_losses_avg[-1], "val_loss": val_losses_avg[-1]})
+    # Return after max_epochs reached.
+    if rank == 0:
+        wandb.log(
+            {
+                "train_loss": train_losses_avg[-1],
+                "val_loss": val_losses_avg[-1],
+                "completed_batches": n_comp_batches,
+                "estimated_time_remaining": est_remaining_t,
+            }
+        )
     if Path(save_chkpt_dir).exists() and rank == 0:
         torch.save(
             model.module.state_dict(),
@@ -168,19 +187,61 @@ def apply_gradient_centralization(optimizer):
     return loss, train_losses_avg, val_losses_avg
 
 
 def main(
-    rank,
-    world_size,
-    master_addr,
-    master_port,
-    model,
-    train_loader,
-    val_loader,
-    optimizer,
-    loss_fn,
-    save_chkpt_dir
+    rank: int,  # rank of current process
+    world_size: int,  # number of processes
+    master_addr: str,  # master machine address (IP or hostname)
+    master_port: str,  # master machine port
+    text_file: str,  # path to text file to train on
+    train_config: tuple[float, optim.Optimizer, dict],  # lr, optimizer, model config
 ):
+    """Main function to run distributed training.
+
+    Sets up DDP env, creates dataset from text file, creates and trains model, cleans up DDP env.
+    """
+    # Set up DDP environment.
     setup(rank, world_size, master_addr, master_port)
+    # Set up dataset.
+    with open(text_file) as f:
+        text = f.read()
+    tokens = sorted(set(text))
+    X, Y = build_dataset(text_file, ctx_len=train_config[2]["ctx_len"])
+    dataset = TensorDataset(X, Y)
+    train_data, val_data = random_split(dataset, [0.9, 0.1])
+    train_loader = DataLoader(
+        train_data, batch_size=32, shuffle=False, sampler=DistributedSampler(train_data)
+    )
+    val_loader = DataLoader(
+        val_data, batch_size=32, shuffle=False, sampler=DistributedSampler(val_data)
+    )
+    # Set up model.
+    model = NanoGPT(n_tokens=len(tokens), **train_config[2])
+    model = DDP(model.to(rank), device_ids=[rank])
+    # Initialize wandb config and run.
+    param_bytes = 4  # 32-bit floats
+    bytes_in_gb = 1024**3
+    n_tot_params = sum(p.numel() for p in model.parameters())
+    n_tot_params_b = round(n_tot_params / 1e9, 3)
+    tot_sz_gb = n_tot_params * param_bytes / bytes_in_gb
+    run_name = f"{train_config[1].__name__}-{train_config[0]}_{n_tot_params_b}B"
+    if rank == 0:
+        wandb_config = {
+            "n_params_bil": n_tot_params_b,
+            "sz_gb": tot_sz_gb,
+            "lr": train_config[0],
+            "optim": train_config[1],
+            "completed_batches": 0,
+            "expected_total_batches": None,  # set in `train` function
+            "estimated_time_remaining": None,  # set in `train` function
+        }
+        wandb_config.update(train_config[2])
+        # name: <optim>-<lr>_<n_params_bil>B; e.g. Adam-0.005_0.122B
+        wandb.init(project="NanoGPT-DDP", entity="jkbhagatio", name=run_name, config=wandb_config)
+    # Run training.
+    optimizer = train_config[1](model.parameters(), lr=train_config[0])
+    loss_fn = nn.CrossEntropyLoss()
+    save_chkpt_dir = Path.home() / "nanogpt_ddp_runs" / "chkpts" / run_name
     train(model, train_loader, val_loader, optimizer, loss_fn, rank, save_chkpt_dir=save_chkpt_dir)
+    # Clean up DDP environment.
     cleanup()
 
 
 # Run training.
@@ -188,74 +249,40 @@ if __name__ == "__main__":
     # Parse args.
     parser = argparse.ArgumentParser(description="Run DDP distributed training of NanoGPTs.")
-    parser.add_argument("--config-idx", type=int, required=True, help="Index of config to run.")
+    parser.add_argument(
+        "--train-config-idx",
+        type=int,
+        required=True,
+        help="Index of train config to run. (See `train_configs` var)"
+    )
     parser.add_argument(
         "--world-size", type=int, required=True, help="Number of processes to use for DDP."
     )
-    parser.add_argument("--rank", type=int, required=True, help="Rank of current process.")
+    # parser.add_argument("--rank", type=int, required=True, help="Rank of current process.")
     parser.add_argument(
         "--master-addr", type=str, required=True, help="Master address (or hostname) for DDP."
    )
-    parser.add_argument("--master-port", type=str, default="91827", help="Master port for DDP.")
-    args = parser.parse_args()
-    # Set config.
-    configs = list(product(LR_SET, OPTIM_SET, ARCH_SET))
-    config = configs[args.config_idx]
-    # Set up dataset and model.
-    txtfile = Path.cwd() / "data/tiny_austen.txt"
-    with open(txtfile) as f:
-        text = f.read()
-    tokens = sorted(set(text))
-    X, Y = build_dataset(txtfile, ctx_len=config[2]["ctx_len"])
-    dataset = TensorDataset(X, Y)
-    train_data, val_data = random_split(dataset, splits=[0.9, 0.1])
-    train_loader = DataLoader(
-        train_data, batch_size=32, shuffle=False, sampler=DistributedSampler(train_data)
-    )
-    val_loader = DataLoader(
-        val_data, batch_size=32, shuffle=False, sampler=DistributedSampler(val_data)
-    )
-    model = NanoGPT(n_tokens=len(tokens), **config[2])
-    # Get model size.
-    param_bytes = 4  # 32-bit floats
-    bytes_in_gb = 1024 ** 3
-    n_tot_params = sum(p.numel() for p in model.parameters())
-    n_tot_params_b = round(n_tot_params / 1e9, 3)
-    tot_sz_gb = n_tot_params * param_bytes / bytes_in_gb
-    # Wandb config: model size, lr, optim, arch.
-    wandb_config = {
-        "n_params_bil": n_tot_params_b, "sz_gb": tot_sz_gb, "lr": config[0], "optim": config[1]
-    }
-    wandb_config.update(config[2])
-    # name: <optim>-<lr>_<n_params_bil>B; e.g. Adam-0.005_0.122B
-    run_name = f"{config[1].__name__}-{config[0]}_{n_tot_params_b}B"
-    wandb.init(
-        project="NanoGPT-DDP",
-        entity="jkbhagatio",
-        name=run_name,
-        config=wandb_config
+    parser.add_argument("--master-port", type=str, default="4444", help="Master port for DDP.")
+    parser.add_argument(
+        "--text-file",
+        type=str,
+        default=(Path.cwd() / "data/tiny_austen.txt"),
+        help="Path to text file to train on."
     )
-    # Setup DDP environment.
-    setup(rank=args.rank, world_size=args.world_size, master_addr=args.master_addr, master_port=args.master_port)
-    # Spawn and run training.
-    optimizer = config[1](model.parameters(), lr=config[0])
-    loss_fn = nn.CrossEntropyLoss()
-    mp.spawn(
+    args = parser.parse_args()
+    # Set training config.
+    train_configs = list(product(LR_SET, OPTIM_SET, ARCH_SET))
+    train_config = train_configs[args.train_config_idx]
+    # Run DDP training.
+    mp.spawn(  # passes `rank` to `main` as first arg automatically
         main,
         args=(
-            args.rank,
             args.world_size,
             args.master_addr,
             args.master_port,
-            model,
-            train_loader,
-            val_loader,
-            optimizer,
-            loss_fn,
-            run_name  # save_chkpt_dir
+            args.text_file,
+            train_config,
         ),
         nprocs=args.world_size,
-        join=True
+        join=True,
    )
-    # Cleanup DDP environment.
-    cleanup()
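
A quick way to check which `--train-config-idx` selects which (learning rate, optimizer,
architecture) combination: the minimal sketch below (not part of ddp.py or this patch)
rebuilds `train_configs` the same way the `__main__` block does, using the LR_SET /
OPTIM_SET / ARCH_SET values from this diff. A local run can then be launched with, e.g.,
`python ddp.py --train-config-idx 0 --world-size 2 --master-addr localhost`
(`--master-port` defaults to "4444" and `--text-file` to data/tiny_austen.txt).

    # Sketch only: mirrors the config enumeration in ddp.py's __main__ block.
    from itertools import product

    from torch.optim import Adam, AdamW, NAdam

    LR_SET = [5e-2, 1e-3, 1e-4]  # learning rate set
    OPTIM_SET = [Adam, AdamW, NAdam]  # optimizer set
    ARCH_SET = [  # model architecture set
        {"ctx_len": 256, "emb_dim": 256, "n_heads": 8, "head_sz": 32, "n_blocks": 8},
        {"ctx_len": 2048, "emb_dim": 1024, "n_heads": 16, "head_sz": 64, "n_blocks": 12},
        {"ctx_len": 2048, "emb_dim": 1024, "n_heads": 20, "head_sz": 80, "n_blocks": 12},
    ]

    train_configs = list(product(LR_SET, OPTIM_SET, ARCH_SET))
    for idx, (lr, optim_cls, arch) in enumerate(train_configs):
        # Print the index alongside the (optimizer, lr, arch) it selects.
        print(f"{idx}: {optim_cls.__name__}, lr={lr}, ctx_len={arch['ctx_len']}, n_blocks={arch['n_blocks']}")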