implement GaLoreArgs
rasbt committed Apr 9, 2024
1 parent 8a47d07 commit b4257ca
Showing 3 changed files with 53 additions and 59 deletions.
19 changes: 18 additions & 1 deletion litgpt/args.py
@@ -1,7 +1,7 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

from dataclasses import dataclass
from typing import Optional
from typing import Literal, Optional


@dataclass
@@ -61,3 +61,20 @@ class EvalArgs:
"""Number of tokens to generate"""
max_iters: int = 100
"""Number of iterations"""

@dataclass
class GaLoreArgs:
"""GaLore-related arguments"""

use_galore: bool = False
"""Whether to enable GaLore (GaLore is applied to all linear layers)."""
galore_8bit: bool = False
"""Whether to use the 8-bit GaLore AdamW optimizer instead of the GaLore AdamW optimizer."""
galore_r: int = 128
"""GaLore rank"""
galore_update_proj_gap: int = 200
"""Number of optimizer steps between GaLore projection updates"""
galore_scale: float = 0.25
"""GaLore scale factor"""
galore_proj_type: Literal["std", "reverse_std"] = "std"
"""GaLore projection type"""
45 changes: 17 additions & 28 deletions litgpt/finetune/full.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@
from torch.utils.data import DataLoader
from torchmetrics import RunningMean

from litgpt.args import EvalArgs, TrainArgs
from litgpt.args import EvalArgs, TrainArgs, GaLoreArgs
from litgpt.data import Alpaca, DataModule
from litgpt.generate.base import generate
from litgpt.model import GPT, Block, Config
@@ -53,12 +53,13 @@ def setup(
max_seq_length=None,
),
eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100),
use_galore: bool = False,
galore_8bit: bool = False,
galore_r: int = 128,
galore_update_proj_gap: int = 200,
galore_scale: float = 0.25,
galore_proj_type: Literal["std", "reverse_std"] = "std",
galore: GaLoreArgs = GaLoreArgs(
galore_8bit=False,
galore_r=128,
galore_update_proj_gap=200,
galore_scale=0.25,
galore_proj_type="std",
),
logger_name: Literal["wandb", "tensorboard", "csv"] = "csv",
seed: int = 1337,
) -> None:
@@ -74,13 +74,7 @@ def setup(
data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.
eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.
use_galore: Whether to enable GaLore (GaLore is applied to all linear layers).
use_galore_8bit: Whether to use the 8-bit GaLore AdamW optimizer
instead of the Galore AdamW optimizer.
galore_r: GaLore rank,
galore_update_proj_gap: GaLore hyperparameter,
galore_scale: GaLore scale factor,
galore_proj_type: GaLore projection type,
galore: GaLore-related arguments. See ``litgpt.args.GaLoreArgs`` for details.
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
"""
@@ -110,8 +105,7 @@ def setup(

fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger)
fabric.launch(
main, devices, resume, seed, config, data, checkpoint_dir, out_dir, train, eval,
use_galore, galore_8bit, galore_r, galore_update_proj_gap, galore_scale, galore_proj_type
main, devices, resume, seed, config, data, checkpoint_dir, out_dir, train, eval, galore
)


@@ -126,12 +120,7 @@ def main(
out_dir: Path,
train: TrainArgs,
eval: EvalArgs,
use_galore: bool,
galore_8bit: bool,
galore_r: int,
galore_update_proj_gap: int,
galore_scale: float,
galore_proj_type: str,
galore: GaLoreArgs,
) -> None:
validate_args(train, eval)

@@ -153,21 +142,21 @@ def main(

model = fabric.setup(model)

if use_galore:
if galore.use_galore:

linear_params, nonlinear_params = get_linear_nonlinear_params(model)
# Currently applies GaLore to all linear layers; might add options to target specific layers later
param_groups = [
{'params': nonlinear_params},
{
'params': linear_params,
'rank': galore_r,
'update_proj_gap': galore_update_proj_gap,
'scale': galore_scale,
'proj_type': galore_proj_type
'rank': galore.galore_r,
'update_proj_gap': galore.galore_update_proj_gap,
'scale': galore.galore_scale,
'proj_type': galore.galore_proj_type
}
]
if galore_8bit:
if galore.galore_8bit:
from galore_torch import GaLoreAdamW8bit
optimizer = GaLoreAdamW8bit(
param_groups, lr=train.learning_rate, weight_decay=train.weight_decay, betas=(train.beta1, train.beta2)
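
The hunk above is truncated after the 8-bit branch. As a rough, self-contained sketch of the param-group pattern used here (assuming the galore-torch package is installed and that GaLore is attached only to the 2D linear weight matrices; litgpt's actual get_linear_nonlinear_params helper may split parameters differently):

import torch.nn as nn
from galore_torch import GaLoreAdamW  # the 8-bit variant would be GaLoreAdamW8bit

# Toy stand-in for the GPT model; illustrative only.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 8), nn.LayerNorm(8))

# Assumed split: GaLore hyperparameters go only on the linear weight matrices.
linear_params = [m.weight for m in model.modules() if isinstance(m, nn.Linear)]
linear_ids = {id(p) for p in linear_params}
nonlinear_params = [p for p in model.parameters() if id(p) not in linear_ids]

param_groups = [
    {"params": nonlinear_params},
    {"params": linear_params, "rank": 128, "update_proj_gap": 200, "scale": 0.25, "proj_type": "std"},
]
optimizer = GaLoreAdamW(param_groups, lr=3e-4, weight_decay=0.02, betas=(0.9, 0.95))
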
48 changes: 18 additions & 30 deletions litgpt/finetune/lora.py
@@ -15,7 +15,7 @@
from torch.utils.data import DataLoader
from torchmetrics import RunningMean

from litgpt.args import EvalArgs, TrainArgs
from litgpt.args import EvalArgs, TrainArgs, GaLoreArgs
from litgpt.data import Alpaca, DataModule
from litgpt.generate.base import generate
from litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable
@@ -53,12 +53,13 @@ def setup(
lora_projection: bool = False,
lora_mlp: bool = False,
lora_head: bool = False,
use_galore: bool = False,
galore_8bit: bool = False,
galore_r: int = 128,
galore_update_proj_gap: int = 200,
galore_scale: float = 0.25,
galore_proj_type: Literal["std", "reverse_std"] = "std",
galore: GaLoreArgs = GaLoreArgs(
galore_8bit=False,
galore_r=128,
galore_update_proj_gap=200,
galore_scale=0.25,
galore_proj_type="std",
),
data: Optional[DataModule] = None,
train: TrainArgs = TrainArgs(
save_interval=1000,
@@ -91,13 +92,7 @@ def setup(
lora_projection: Whether to apply LoRA to the output projection in the attention block.
lora_mlp: Whether to apply LoRA to the weights of the MLP in the attention block.
lora_head: Whether to apply LoRA to output head in GPT.
use_galore: Whether to enable GaLore (GaLore is applied to all linear layers).
use_galore_8bit: Whether to use the 8-bit GaLore AdamW optimizer
instead of the Galore AdamW optimizer.
galore_r: GaLore rank,
galore_update_proj_gap: GaLore hyperparameter,
galore_scale: GaLore scale factor,
galore_proj_type: GaLore projection type,
galore: GaLore-related arguments. See ``litgpt.args.GaLoreArgs`` for details.
data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.
eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.
@@ -152,12 +147,10 @@ def setup(

fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins)
fabric.launch(
main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval,
use_galore, galore_8bit, galore_r, galore_update_proj_gap, galore_scale, galore_proj_type
main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, galore,
)



def main(
fabric: L.Fabric,
devices: int,
@@ -168,12 +161,7 @@ def main(
out_dir: Path,
train: TrainArgs,
eval: EvalArgs,
use_galore: bool,
galore_8bit: bool,
galore_r: int,
galore_update_proj_gap: int,
galore_scale: float,
galore_proj_type: str,
galore: GaLoreArgs,
) -> None:
validate_args(train, eval)

@@ -202,26 +190,26 @@ def main(
if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
import bitsandbytes as bnb

if use_galore:
if galore.use_galore:
raise ValueError("The combination of QLoRA and GaLore is currently not supported.")

optimizer_cls = bnb.optim.PagedAdamW

elif use_galore:
elif galore.use_galore:

linear_params, nonlinear_params = get_linear_nonlinear_params(model)
# Currently applies GaLore to all linear layers; might add options to target specific layers later
trainable_params = [
{'params': nonlinear_params},
{
'params': linear_params,
'rank': galore_r,
'update_proj_gap': galore_update_proj_gap,
'scale': galore_scale,
'proj_type': galore_proj_type
'rank': galore.galore_r,
'update_proj_gap': galore.galore_update_proj_gap,
'scale': galore.galore_scale,
'proj_type': galore.galore_proj_type
}
]
if galore_8bit:
if galore.galore_8bit:
from galore_torch import GaLoreAdamW8bit
optimizer_cls = GaLoreAdamW8bit
else:
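
With the arguments grouped, an illustrative programmatic call of the LoRA entry point might look as follows (a sketch only: it assumes setup accepts checkpoint_dir as in litgpt's other finetune scripts, and the path is a placeholder):

from pathlib import Path
from litgpt.args import GaLoreArgs
from litgpt.finetune.lora import setup

# Hypothetical invocation; checkpoint_dir is a placeholder path.
setup(
    checkpoint_dir=Path("checkpoints/some-org/some-model"),
    galore=GaLoreArgs(use_galore=True, galore_8bit=True, galore_r=128),
)
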
