From b4257ca645b8f85955da6b7f3bc0ecd523231336 Mon Sep 17 00:00:00 2001
From: rasbt
Date: Tue, 9 Apr 2024 20:52:35 +0000
Subject: [PATCH] implement GaLoreArgs

---
 litgpt/args.py          | 19 +++++++++++++++-
 litgpt/finetune/full.py | 45 +++++++++++++++-----------------------
 litgpt/finetune/lora.py | 48 ++++++++++++++++-------------------------
 3 files changed, 53 insertions(+), 59 deletions(-)

diff --git a/litgpt/args.py b/litgpt/args.py
index d6ce527d36..0df4226b89 100644
--- a/litgpt/args.py
+++ b/litgpt/args.py
@@ -1,7 +1,7 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 
 from dataclasses import dataclass
-from typing import Optional
+from typing import Literal, Optional
 
 
 @dataclass
@@ -61,3 +61,20 @@ class EvalArgs:
     """Number of tokens to generate"""
     max_iters: int = 100
     """Number of iterations"""
+
+@dataclass
+class GaLoreArgs:
+    """GaLore-related arguments"""
+
+    use_galore: bool = False
+    """Whether to enable GaLore (GaLore is applied to all linear layers)."""
+    galore_8bit: bool = False
+    """Whether to use the 8-bit GaLore AdamW optimizer instead of the GaLore AdamW optimizer."""
+    galore_r: int = 128
+    """GaLore rank"""
+    galore_update_proj_gap: int = 200
+    """Number of steps between GaLore projection updates"""
+    galore_scale: float = 0.25
+    """GaLore scale factor"""
+    galore_proj_type: Literal["std", "reverse_std"] = "std"
+    """GaLore projection type"""
diff --git a/litgpt/finetune/full.py b/litgpt/finetune/full.py
index d2ca67b98b..606443521d 100644
--- a/litgpt/finetune/full.py
+++ b/litgpt/finetune/full.py
@@ -13,7 +13,7 @@
 from torch.utils.data import DataLoader
 from torchmetrics import RunningMean
 
-from litgpt.args import EvalArgs, TrainArgs
+from litgpt.args import EvalArgs, TrainArgs, GaLoreArgs
 from litgpt.data import Alpaca, DataModule
 from litgpt.generate.base import generate
 from litgpt.model import GPT, Block, Config
@@ -53,12 +53,13 @@ def setup(
         max_seq_length=None,
     ),
     eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100),
-    use_galore: bool = False,
-    galore_8bit: bool = False,
-    galore_r: int = 128,
-    galore_update_proj_gap: int = 200,
-    galore_scale: float = 0.25,
-    galore_proj_type: Literal["std", "reverse_std"] = "std",
+    galore: GaLoreArgs = GaLoreArgs(
+        galore_8bit=False,
+        galore_r=128,
+        galore_update_proj_gap=200,
+        galore_scale=0.25,
+        galore_proj_type="std",
+    ),
     logger_name: Literal["wandb", "tensorboard", "csv"] = "csv",
     seed: int = 1337,
 ) -> None:
@@ -74,13 +75,7 @@ def setup(
         data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
         train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.
         eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.
-        use_galore: Whether to enable GaLore (GaLore is applied to all linear layers).
-        use_galore_8bit: Whether to use the 8-bit GaLore AdamW optimizer
-            instead of the Galore AdamW optimizer.
-        galore_r: GaLore rank,
-        galore_update_proj_gap: GaLore hyperparameter,
-        galore_scale: GaLore scale factor,
-        galore_proj_type: GaLore projection type,
+        galore: GaLore-related arguments. See ``litgpt.args.GaLoreArgs`` for details.
         logger_name: The name of the logger to send metrics to.
         seed: The random seed to use for reproducibility.
""" @@ -110,8 +105,7 @@ def setup( fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger) fabric.launch( - main, devices, resume, seed, config, data, checkpoint_dir, out_dir, train, eval, - use_galore, galore_8bit, galore_r, galore_update_proj_gap, galore_scale, galore_proj_type + main, devices, resume, seed, config, data, checkpoint_dir, out_dir, train, eval, galore ) @@ -126,12 +120,7 @@ def main( out_dir: Path, train: TrainArgs, eval: EvalArgs, - use_galore: bool, - galore_8bit: bool, - galore_r: int, - galore_update_proj_gap: int, - galore_scale: float, - galore_proj_type: str, + galore: GaLoreArgs, ) -> None: validate_args(train, eval) @@ -153,7 +142,7 @@ def main( model = fabric.setup(model) - if use_galore: + if galore.use_galore: linear_params, nonlinear_params = get_linear_nonlinear_params(model) # Currently apply galore to all parameters; might add options to target specific layers later) @@ -161,13 +150,13 @@ def main( {'params': nonlinear_params}, { 'params': linear_params, - 'rank': galore_r, - 'update_proj_gap': galore_update_proj_gap, - 'scale': galore_scale, - 'proj_type': galore_proj_type + 'rank': galore.galore_r, + 'update_proj_gap': galore.galore_update_proj_gap, + 'scale': galore.galore_scale, + 'proj_type': galore.galore_proj_type } ] - if galore_8bit: + if galore.galore_8bit: from galore_torch import GaLoreAdamW8bit optimizer = GaLoreAdamW8bit( param_groups, lr=train.learning_rate, weight_decay=train.weight_decay, betas=(train.beta1, train.beta2) diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py index 3c8bf16197..1b4a57c591 100644 --- a/litgpt/finetune/lora.py +++ b/litgpt/finetune/lora.py @@ -15,7 +15,7 @@ from torch.utils.data import DataLoader from torchmetrics import RunningMean -from litgpt.args import EvalArgs, TrainArgs +from litgpt.args import EvalArgs, TrainArgs, GaLoreArgs from litgpt.data import Alpaca, DataModule from litgpt.generate.base import generate from litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable @@ -53,12 +53,13 @@ def setup( lora_projection: bool = False, lora_mlp: bool = False, lora_head: bool = False, - use_galore: bool = False, - galore_8bit: bool = False, - galore_r: int = 128, - galore_update_proj_gap: int = 200, - galore_scale: float = 0.25, - galore_proj_type: Literal["std", "reverse_std"] = "std", + galore: GaLoreArgs = GaLoreArgs( + galore_8bit=False, + galore_r=128, + galore_update_proj_gap=200, + galore_scale=0.25, + galore_proj_type="std", + ), data: Optional[DataModule] = None, train: TrainArgs = TrainArgs( save_interval=1000, @@ -91,13 +92,7 @@ def setup( lora_projection: Whether to apply LoRA to the output projection in the attention block. lora_mlp: Whether to apply LoRA to the weights of the MLP in the attention block. lora_head: Whether to apply LoRA to output head in GPT. - use_galore: Whether to enable GaLore (GaLore is applied to all linear layers). - use_galore_8bit: Whether to use the 8-bit GaLore AdamW optimizer - instead of the Galore AdamW optimizer. - galore_r: GaLore rank, - galore_update_proj_gap: GaLore hyperparameter, - galore_scale: GaLore scale factor, - galore_proj_type: GaLore projection type, + galore: GaLore-related arguments. See ``litgpt.args.GaLoreArgs`` for details. data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details. eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details. 
@@ -152,12 +147,10 @@ def setup(
 
     fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins)
     fabric.launch(
-        main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval,
-        use_galore, galore_8bit, galore_r, galore_update_proj_gap, galore_scale, galore_proj_type
+        main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, galore,
    )
 
 
-
 def main(
     fabric: L.Fabric,
     devices: int,
@@ -168,12 +161,7 @@ def main(
     out_dir: Path,
     train: TrainArgs,
     eval: EvalArgs,
-    use_galore: bool,
-    galore_8bit: bool,
-    galore_r: int,
-    galore_update_proj_gap: int,
-    galore_scale: float,
-    galore_proj_type: str,
+    galore: GaLoreArgs,
 ) -> None:
     validate_args(train, eval)
 
@@ -202,12 +190,12 @@ def main(
     if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
         import bitsandbytes as bnb
 
-        if use_galore:
+        if galore.use_galore:
             raise ValueError("The combination of QLoRA and GaLore is currently not supported.")
 
         optimizer_cls = bnb.optim.PagedAdamW
 
-    elif use_galore:
+    elif galore.use_galore:
         linear_params, nonlinear_params = get_linear_nonlinear_params(model)
 
         # Currently applies GaLore to all linear layers; might add options to target specific layers later
@@ -215,13 +203,13 @@ def main(
             {'params': nonlinear_params},
             {
                 'params': linear_params,
-                'rank': galore_r,
-                'update_proj_gap': galore_update_proj_gap,
-                'scale': galore_scale,
-                'proj_type': galore_proj_type
+                'rank': galore.galore_r,
+                'update_proj_gap': galore.galore_update_proj_gap,
+                'scale': galore.galore_scale,
+                'proj_type': galore.galore_proj_type
             }
         ]
-        if galore_8bit:
+        if galore.galore_8bit:
             from galore_torch import GaLoreAdamW8bit
             optimizer_cls = GaLoreAdamW8bit
         else:
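
A minimal usage sketch, assuming the ``GaLoreArgs`` dataclass and the updated ``setup()`` signature introduced above; the checkpoint path and the non-default hyperparameter choices below are placeholders, not values prescribed by this patch.

from pathlib import Path

from litgpt.args import GaLoreArgs
from litgpt.finetune.full import setup

# Enable GaLore for all linear layers; the remaining fields mirror the dataclass defaults.
galore_args = GaLoreArgs(
    use_galore=True,
    galore_8bit=False,           # set True to use the 8-bit GaLore AdamW optimizer
    galore_r=128,                # GaLore rank
    galore_update_proj_gap=200,  # steps between projection updates
    galore_scale=0.25,
    galore_proj_type="std",
)

setup(
    checkpoint_dir=Path("checkpoints/EleutherAI/pythia-160m"),  # placeholder checkpoint directory
    galore=galore_args,
)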