From 695dbd4091430b9a0318a868a68b8c95cdf7b182 Mon Sep 17 00:00:00 2001 From: rasbt Date: Wed, 22 May 2024 21:09:50 +0000 Subject: [PATCH] refactor instantiation --- litgpt/finetune/adapter.py | 22 ++++------------------ litgpt/finetune/adapter_v2.py | 21 ++++----------------- litgpt/finetune/full.py | 9 ++------- litgpt/finetune/lora.py | 21 ++++----------------- litgpt/pretrain.py | 12 +++--------- litgpt/utils.py | 24 ++++++++++++++++++++++++ 6 files changed, 41 insertions(+), 68 deletions(-) diff --git a/litgpt/finetune/adapter.py b/litgpt/finetune/adapter.py index 9e73140274..3f1030a229 100644 --- a/litgpt/finetune/adapter.py +++ b/litgpt/finetune/adapter.py @@ -9,11 +9,9 @@ import lightning as L import torch -from lightning.pytorch.cli import instantiate_class from lightning.fabric.plugins import BitsandbytesPrecision from lightning.fabric.strategies import FSDPStrategy from lightning.fabric.utilities import ThroughputMonitor -from torch.optim import Optimizer from torch.utils.data import DataLoader from torchmetrics import RunningMean @@ -29,9 +27,10 @@ choose_logger, chunked_cross_entropy, copy_config_files, - get_argument_names, get_default_supported_precision, init_out_dir, + instantiate_torch_optimizer, + instantiate_bnb_optimizer, load_checkpoint, num_parameters, parse_devices, @@ -151,22 +150,9 @@ def main( model = fabric.setup_module(model) if isinstance(fabric.strategy.precision, BitsandbytesPrecision): - if (isinstance(optimizer, str) and "AdamW" not in optimizer) or (isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "")): - raise ValueError("The chosen quantization format only supports the AdamW optimizer.") - - import bitsandbytes as bnb - if isinstance(optimizer, str): - optimizer = bnb.optim.PagedAdamW(model.parameters()) - else: - optim_args = get_argument_names(bnb.optim.PagedAdamW) - allowed_kwargs = {key: optimizer["init_args"][key] for key in optim_args & optimizer["init_args"].keys()} - optimizer = bnb.optim.PagedAdamW(model.parameters(), **allowed_kwargs) + optimizer = instantiate_bnb_optimizer(optimizer, model.parameters()) else: - if isinstance(optimizer, str): - optimizer_cls = getattr(torch.optim, optimizer) - optimizer = optimizer_cls(model.parameters()) - else: - optimizer = instantiate_class(model.parameters(), optimizer) + optimizer = instantiate_torch_optimizer(optimizer, model.parameters()) optimizer = fabric.setup_optimizers(optimizer) scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps) diff --git a/litgpt/finetune/adapter_v2.py b/litgpt/finetune/adapter_v2.py index 777905d447..785668939e 100644 --- a/litgpt/finetune/adapter_v2.py +++ b/litgpt/finetune/adapter_v2.py @@ -9,7 +9,6 @@ import lightning as L import torch -from lightning.pytorch.cli import instantiate_class from lightning.fabric.plugins import BitsandbytesPrecision from lightning.fabric.strategies import FSDPStrategy from lightning.fabric.utilities import ThroughputMonitor @@ -28,9 +27,10 @@ choose_logger, chunked_cross_entropy, copy_config_files, - get_argument_names, get_default_supported_precision, init_out_dir, + instantiate_torch_optimizer, + instantiate_bnb_optimizer, load_checkpoint, num_parameters, parse_devices, @@ -150,22 +150,9 @@ def main( model = fabric.setup_module(model) if isinstance(fabric.strategy.precision, BitsandbytesPrecision): - if (isinstance(optimizer, str) and "AdamW" not in optimizer) or (isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "")): - raise ValueError("The chosen quantization format only supports the AdamW optimizer.") - - import bitsandbytes as bnb - if isinstance(optimizer, str): - optimizer = bnb.optim.PagedAdamW(model.parameters()) - else: - optim_args = get_argument_names(bnb.optim.PagedAdamW) - allowed_kwargs = {key: optimizer["init_args"][key] for key in optim_args & optimizer["init_args"].keys()} - optimizer = bnb.optim.PagedAdamW(model.parameters(), **allowed_kwargs) + optimizer = instantiate_bnb_optimizer(optimizer, model.parameters()) else: - if isinstance(optimizer, str): - optimizer_cls = getattr(torch.optim, optimizer) - optimizer = optimizer_cls(model.parameters()) - else: - optimizer = instantiate_class(model.parameters(), optimizer) + optimizer = instantiate_torch_optimizer(optimizer, model.parameters()) optimizer = fabric.setup_optimizers(optimizer) scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps) diff --git a/litgpt/finetune/full.py b/litgpt/finetune/full.py index 4b162c232b..cf32ae501d 100644 --- a/litgpt/finetune/full.py +++ b/litgpt/finetune/full.py @@ -9,7 +9,6 @@ import lightning as L import torch -from lightning.pytorch.cli import instantiate_class from lightning.fabric.strategies import FSDPStrategy from torch.utils.data import DataLoader from torchmetrics import RunningMean @@ -29,6 +28,7 @@ get_default_supported_precision, load_checkpoint, init_out_dir, + instantiate_torch_optimizer, num_parameters, parse_devices, save_hyperparameters, @@ -136,12 +136,7 @@ def main( model = fabric.setup(model) - if isinstance(optimizer, str): - optimizer_cls = getattr(torch.optim, optimizer) - optimizer = optimizer_cls(model.parameters()) - else: - optimizer = instantiate_class(model.parameters(), optimizer) - + optimizer = instantiate_torch_optimizer(optimizer, model.parameters()) optimizer = fabric.setup_optimizers(optimizer) scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps) state = {"model": model, "optimizer": optimizer, "scheduler": scheduler, "iter_num": 0, "step_count": 0} diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py index 37ccc100f4..5f5e12dcf9 100644 --- a/litgpt/finetune/lora.py +++ b/litgpt/finetune/lora.py @@ -9,7 +9,6 @@ import lightning as L import torch -from lightning.pytorch.cli import instantiate_class from lightning.fabric.plugins import BitsandbytesPrecision from lightning.fabric.strategies import FSDPStrategy from lightning.fabric.utilities import ThroughputMonitor @@ -29,10 +28,11 @@ choose_logger, chunked_cross_entropy, copy_config_files, - get_argument_names, get_default_supported_precision, load_checkpoint, init_out_dir, + instantiate_torch_optimizer, + instantiate_bnb_optimizer, num_parameters, parse_devices, save_hyperparameters, @@ -180,22 +180,9 @@ def main( model = fabric.setup_module(model) if isinstance(fabric.strategy.precision, BitsandbytesPrecision): - if (isinstance(optimizer, str) and "AdamW" not in optimizer) or (isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "")): - raise ValueError("The chosen quantization format only supports the AdamW optimizer.") - - import bitsandbytes as bnb - if isinstance(optimizer, str): - optimizer = bnb.optim.PagedAdamW(model.parameters()) - else: - optim_args = get_argument_names(bnb.optim.PagedAdamW) - allowed_kwargs = {key: optimizer["init_args"][key] for key in optim_args & optimizer["init_args"].keys()} - optimizer = bnb.optim.PagedAdamW(model.parameters(), **allowed_kwargs) + optimizer = instantiate_bnb_optimizer(optimizer, model.parameters()) else: - if isinstance(optimizer, str): - optimizer_cls = getattr(torch.optim, optimizer) - optimizer = optimizer_cls(model.parameters()) - else: - optimizer = instantiate_class(model.parameters(), optimizer) + optimizer = instantiate_torch_optimizer(optimizer, model.parameters()) optimizer = fabric.setup_optimizers(optimizer) scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps) diff --git a/litgpt/pretrain.py b/litgpt/pretrain.py index 2480c6ab6b..8bc36239eb 100644 --- a/litgpt/pretrain.py +++ b/litgpt/pretrain.py @@ -11,7 +11,6 @@ import lightning as L import torch import torch.nn as nn -from lightning.pytorch.cli import instantiate_class from lightning.fabric.strategies import FSDPStrategy from lightning.fabric.utilities.throughput import ThroughputMonitor, measure_flops from torch.utils.data import DataLoader @@ -24,7 +23,6 @@ from litgpt.data import DataModule, TinyLlama from litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP from litgpt.utils import ( - CLI, CycleIterator, capture_hparams, choose_logger, @@ -32,6 +30,7 @@ copy_config_files, get_default_supported_precision, init_out_dir, + instantiate_torch_optimizer, num_parameters, parse_devices, reset_parameters, @@ -175,13 +174,8 @@ def main( model = torch.compile(model) model = fabric.setup(model) - if isinstance(optimizer, str): - optimizer_cls = getattr(torch.optim, optimizer) - optimizer = optimizer_cls(model.parameters(), fused=fabric.device.type == "cuda") - else: - #optimizer["fused"] = fabric.device.type == "cuda" - optimizer = instantiate_class(model.parameters(), optimizer) - + # TODO: add fused=fabric.device.type == "cuda" + optimizer = instantiate_torch_optimizer(optimizer, model.parameters()) optimizer = fabric.setup_optimizers(optimizer) train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train, model.max_seq_length) diff --git a/litgpt/utils.py b/litgpt/utils.py index 517bcf7efe..5629871937 100644 --- a/litgpt/utils.py +++ b/litgpt/utils.py @@ -21,6 +21,7 @@ from lightning.fabric.strategies import FSDPStrategy from lightning.fabric.utilities.load import _lazy_load as lazy_load from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.cli import instantiate_class from torch.serialization import normalize_storage_type from typing_extensions import Self @@ -492,3 +493,26 @@ def get_argument_names(cls): sig = inspect.signature(cls.__init__) return {name for name, param in sig.parameters.items() if param.kind in [inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY]} + + +def instantiate_bnb_optimizer(optimizer, model_parameters): + if (isinstance(optimizer, str) and "AdamW" not in optimizer) or (isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "")): + raise ValueError("The chosen quantization format only supports the AdamW optimizer.") + + import bitsandbytes as bnb + if isinstance(optimizer, str): + optimizer = bnb.optim.PagedAdamW(model_parameters) + else: + optim_args = get_argument_names(bnb.optim.PagedAdamW) + allowed_kwargs = {key: optimizer["init_args"][key] for key in optim_args & optimizer["init_args"].keys()} + optimizer = bnb.optim.PagedAdamW(model_parameters, **allowed_kwargs) + return optimizer + + +def instantiate_torch_optimizer(optimizer, model_parameters): + if isinstance(optimizer, str): + optimizer_cls = getattr(torch.optim, optimizer) + optimizer = optimizer_cls(model_parameters) + else: + optimizer = instantiate_class(model_parameters, optimizer) + return optimizer