From 379e54e4cccff2105d2811c1eca9effde44968cf Mon Sep 17 00:00:00 2001
From: rasbt
Date: Wed, 22 May 2024 22:10:45 +0000
Subject: [PATCH] fix thunder test

---
 extensions/thunder/pretrain.py | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/extensions/thunder/pretrain.py b/extensions/thunder/pretrain.py
index 6aa77a745f..93dea65927 100644
--- a/extensions/thunder/pretrain.py
+++ b/extensions/thunder/pretrain.py
@@ -8,7 +8,7 @@
 from datetime import timedelta
 from functools import partial
 from pathlib import Path
-from typing import Any, Callable, Optional, Tuple, Union, List
+from typing import Any, Callable, Optional, Tuple, Union, List, Dict
 
 import lightning as L
 import torch
@@ -30,6 +30,7 @@
     choose_logger,
     chunked_cross_entropy,
     copy_config_files,
+    instantiate_torch_optimizer,
     num_parameters,
     parse_devices,
     reset_parameters,
@@ -55,16 +56,13 @@ def setup(
         global_batch_size=512,
         micro_batch_size=4,
         max_tokens=int(3e12),  # 3 trillion
-        learning_rate=4e-4,
-        weight_decay=1e-1,
-        beta1=0.9,
-        beta2=0.95,
         max_norm=1.0,
         min_lr=4e-5,
         lr_warmup_steps=2000,
         tie_embeddings=False,
     ),
     eval: EvalArgs = EvalArgs(interval=1000, max_iters=100),
+    optimizer: Union[str, Dict] = "AdamW",
     devices: Union[int, str] = "auto",
     tokenizer_dir: Optional[Path] = None,
     logger_name: Literal["wandb", "tensorboard", "csv"] = "tensorboard",
@@ -93,6 +91,7 @@ def setup(
         tokenizer_dir: Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data
             module require this.
         logger_name: The name of the logger to send metrics to.
+        optimizer: An optimizer name (such as "AdamW") or config.
         seed: The random seed to use for reproducibility.
         compiler: If desired, the compiler/JIT to use.
         executors: If using Thunder, the executors to enable.
@@ -157,6 +156,7 @@ def setup(
         tokenizer,
         train,
         eval,
+        optimizer,
         compiler,
     )
 
@@ -174,6 +174,7 @@ def main(
     tokenizer: Optional[Tokenizer],
     train: TrainArgs,
     eval: EvalArgs,
+    optimizer: Union[str, Dict],
     compiler: Optional[Literal["thunder", "torch"]],
 ) -> None:
     validate_args(train, eval, initial_checkpoint_dir, resume)
@@ -201,13 +202,7 @@ def main(
     if compiler == "thunder":
         # avoid `Tensor.register_hook` which is unsupported
         model._register_backward_hook = lambda *_: None
-    optimizer = torch.optim.AdamW(
-        model.parameters(),
-        lr=train.learning_rate,
-        weight_decay=train.weight_decay,
-        betas=(train.beta1, train.beta2),
-        fused=True,
-    )
+    optimizer = instantiate_torch_optimizer(optimizer, model.parameters())
     optimizer = fabric.setup_optimizers(optimizer)
 
     train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train, model.max_seq_length)
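
Note: the patch delegates optimizer construction to `instantiate_torch_optimizer`, imported from the project's utils, which accepts either an optimizer name such as "AdamW" or a config and binds it to the model parameters. The snippet below is only a minimal sketch of what such a helper could look like; the function name `instantiate_torch_optimizer_sketch` and the `class_path`/`init_args` dict schema are assumptions for illustration, not litgpt's actual API.

# Hypothetical helper sketch (NOT litgpt's actual implementation of
# `instantiate_torch_optimizer`): resolve an optimizer from either a name
# such as "AdamW" or an assumed config dict, and bind it to the parameters.
from typing import Any, Dict, Iterable, Union

import torch


def instantiate_torch_optimizer_sketch(
    optimizer: Union[str, Dict[str, Any]],
    parameters: Iterable[torch.nn.Parameter],
    **kwargs: Any,
) -> torch.optim.Optimizer:
    if isinstance(optimizer, str):
        # e.g. "AdamW" -> torch.optim.AdamW(parameters, **kwargs)
        return getattr(torch.optim, optimizer)(parameters, **kwargs)
    # Assumed dict schema, e.g. {"class_path": "torch.optim.AdamW", "init_args": {"lr": 4e-4}}
    cls_name = optimizer["class_path"].rsplit(".", 1)[-1]
    init_args = {**optimizer.get("init_args", {}), **kwargs}
    return getattr(torch.optim, cls_name)(parameters, **init_args)


# Example usage mirroring the hyperparameters that the patch removes from TrainArgs:
# opt = instantiate_torch_optimizer_sketch(
#     "AdamW", model.parameters(), lr=4e-4, weight_decay=1e-1, betas=(0.9, 0.95), fused=True
# )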