refactor instantiation
rasbt committed May 22, 2024
1 parent 76fc545 commit 695dbd4
Showing 6 changed files with 41 additions and 68 deletions.
22 changes: 4 additions & 18 deletions litgpt/finetune/adapter.py
@@ -9,11 +9,9 @@
 
 import lightning as L
 import torch
-from lightning.pytorch.cli import instantiate_class
 from lightning.fabric.plugins import BitsandbytesPrecision
 from lightning.fabric.strategies import FSDPStrategy
 from lightning.fabric.utilities import ThroughputMonitor
-from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 from torchmetrics import RunningMean
 
@@ -29,9 +27,10 @@
     choose_logger,
     chunked_cross_entropy,
     copy_config_files,
-    get_argument_names,
     get_default_supported_precision,
     init_out_dir,
+    instantiate_torch_optimizer,
+    instantiate_bnb_optimizer,
     load_checkpoint,
     num_parameters,
     parse_devices,
@@ -151,22 +150,9 @@ def main(
     model = fabric.setup_module(model)
 
     if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
-        if (isinstance(optimizer, str) and "AdamW" not in optimizer) or (isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "")):
-            raise ValueError("The chosen quantization format only supports the AdamW optimizer.")
-
-        import bitsandbytes as bnb
-        if isinstance(optimizer, str):
-            optimizer = bnb.optim.PagedAdamW(model.parameters())
-        else:
-            optim_args = get_argument_names(bnb.optim.PagedAdamW)
-            allowed_kwargs = {key: optimizer["init_args"][key] for key in optim_args & optimizer["init_args"].keys()}
-            optimizer = bnb.optim.PagedAdamW(model.parameters(), **allowed_kwargs)
+        optimizer = instantiate_bnb_optimizer(optimizer, model.parameters())
     else:
-        if isinstance(optimizer, str):
-            optimizer_cls = getattr(torch.optim, optimizer)
-            optimizer = optimizer_cls(model.parameters())
-        else:
-            optimizer = instantiate_class(model.parameters(), optimizer)
+        optimizer = instantiate_torch_optimizer(optimizer, model.parameters())
 
     optimizer = fabric.setup_optimizers(optimizer)
     scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)
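Note: the optimizer value these scripts pass in is either a plain string naming a torch.optim class or a jsonargparse-style dict with "class_path" and "init_args", which is what the new helper resolves. A minimal sketch of both forms (the lr and weight_decay values below are illustrative, not taken from this commit):

import torch.nn as nn

from litgpt.utils import instantiate_torch_optimizer

# Toy module standing in for the model prepared by fabric.setup_module(...).
model = nn.Linear(8, 8)

# String form: resolved via getattr(torch.optim, "AdamW") with default arguments.
opt_a = instantiate_torch_optimizer("AdamW", model.parameters())

# Dict form: resolved through Lightning's instantiate_class.
opt_b = instantiate_torch_optimizer(
    {"class_path": "torch.optim.AdamW", "init_args": {"lr": 3e-4, "weight_decay": 0.02}},
    model.parameters(),
)

Either way the result then flows into fabric.setup_optimizers(optimizer), so the refactor only centralizes where the optimizer object is constructed.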
21 changes: 4 additions & 17 deletions litgpt/finetune/adapter_v2.py
@@ -9,7 +9,6 @@
 
 import lightning as L
 import torch
-from lightning.pytorch.cli import instantiate_class
 from lightning.fabric.plugins import BitsandbytesPrecision
 from lightning.fabric.strategies import FSDPStrategy
 from lightning.fabric.utilities import ThroughputMonitor
@@ -28,9 +27,10 @@
     choose_logger,
     chunked_cross_entropy,
     copy_config_files,
-    get_argument_names,
     get_default_supported_precision,
     init_out_dir,
+    instantiate_torch_optimizer,
+    instantiate_bnb_optimizer,
     load_checkpoint,
     num_parameters,
     parse_devices,
@@ -150,22 +150,9 @@ def main(
     model = fabric.setup_module(model)
 
     if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
-        if (isinstance(optimizer, str) and "AdamW" not in optimizer) or (isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "")):
-            raise ValueError("The chosen quantization format only supports the AdamW optimizer.")
-
-        import bitsandbytes as bnb
-        if isinstance(optimizer, str):
-            optimizer = bnb.optim.PagedAdamW(model.parameters())
-        else:
-            optim_args = get_argument_names(bnb.optim.PagedAdamW)
-            allowed_kwargs = {key: optimizer["init_args"][key] for key in optim_args & optimizer["init_args"].keys()}
-            optimizer = bnb.optim.PagedAdamW(model.parameters(), **allowed_kwargs)
+        optimizer = instantiate_bnb_optimizer(optimizer, model.parameters())
     else:
-        if isinstance(optimizer, str):
-            optimizer_cls = getattr(torch.optim, optimizer)
-            optimizer = optimizer_cls(model.parameters())
-        else:
-            optimizer = instantiate_class(model.parameters(), optimizer)
+        optimizer = instantiate_torch_optimizer(optimizer, model.parameters())
 
     optimizer = fabric.setup_optimizers(optimizer)
     scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)
9 changes: 2 additions & 7 deletions litgpt/finetune/full.py
@@ -9,7 +9,6 @@
 
 import lightning as L
 import torch
-from lightning.pytorch.cli import instantiate_class
 from lightning.fabric.strategies import FSDPStrategy
 from torch.utils.data import DataLoader
 from torchmetrics import RunningMean
@@ -29,6 +28,7 @@
     get_default_supported_precision,
     load_checkpoint,
     init_out_dir,
+    instantiate_torch_optimizer,
     num_parameters,
     parse_devices,
     save_hyperparameters,
@@ -136,12 +136,7 @@ def main(
 
     model = fabric.setup(model)
 
-    if isinstance(optimizer, str):
-        optimizer_cls = getattr(torch.optim, optimizer)
-        optimizer = optimizer_cls(model.parameters())
-    else:
-        optimizer = instantiate_class(model.parameters(), optimizer)
-
+    optimizer = instantiate_torch_optimizer(optimizer, model.parameters())
     optimizer = fabric.setup_optimizers(optimizer)
     scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)
     state = {"model": model, "optimizer": optimizer, "scheduler": scheduler, "iter_num": 0, "step_count": 0}
21 changes: 4 additions & 17 deletions litgpt/finetune/lora.py
@@ -9,7 +9,6 @@
 
 import lightning as L
 import torch
-from lightning.pytorch.cli import instantiate_class
 from lightning.fabric.plugins import BitsandbytesPrecision
 from lightning.fabric.strategies import FSDPStrategy
 from lightning.fabric.utilities import ThroughputMonitor
@@ -29,10 +28,11 @@
     choose_logger,
     chunked_cross_entropy,
     copy_config_files,
-    get_argument_names,
     get_default_supported_precision,
     load_checkpoint,
     init_out_dir,
+    instantiate_torch_optimizer,
+    instantiate_bnb_optimizer,
     num_parameters,
     parse_devices,
     save_hyperparameters,
@@ -180,22 +180,9 @@ def main(
     model = fabric.setup_module(model)
 
     if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
-        if (isinstance(optimizer, str) and "AdamW" not in optimizer) or (isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "")):
-            raise ValueError("The chosen quantization format only supports the AdamW optimizer.")
-
-        import bitsandbytes as bnb
-        if isinstance(optimizer, str):
-            optimizer = bnb.optim.PagedAdamW(model.parameters())
-        else:
-            optim_args = get_argument_names(bnb.optim.PagedAdamW)
-            allowed_kwargs = {key: optimizer["init_args"][key] for key in optim_args & optimizer["init_args"].keys()}
-            optimizer = bnb.optim.PagedAdamW(model.parameters(), **allowed_kwargs)
+        optimizer = instantiate_bnb_optimizer(optimizer, model.parameters())
     else:
-        if isinstance(optimizer, str):
-            optimizer_cls = getattr(torch.optim, optimizer)
-            optimizer = optimizer_cls(model.parameters())
-        else:
-            optimizer = instantiate_class(model.parameters(), optimizer)
+        optimizer = instantiate_torch_optimizer(optimizer, model.parameters())
 
     optimizer = fabric.setup_optimizers(optimizer)
     scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)
12 changes: 3 additions & 9 deletions litgpt/pretrain.py
@@ -11,7 +11,6 @@
 import lightning as L
 import torch
 import torch.nn as nn
-from lightning.pytorch.cli import instantiate_class
 from lightning.fabric.strategies import FSDPStrategy
 from lightning.fabric.utilities.throughput import ThroughputMonitor, measure_flops
 from torch.utils.data import DataLoader
@@ -24,14 +23,14 @@
 from litgpt.data import DataModule, TinyLlama
 from litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP
 from litgpt.utils import (
-    CLI,
     CycleIterator,
     capture_hparams,
     choose_logger,
     chunked_cross_entropy,
     copy_config_files,
     get_default_supported_precision,
     init_out_dir,
+    instantiate_torch_optimizer,
     num_parameters,
     parse_devices,
     reset_parameters,
@@ -175,13 +174,8 @@ def main(
     model = torch.compile(model)
     model = fabric.setup(model)
 
-    if isinstance(optimizer, str):
-        optimizer_cls = getattr(torch.optim, optimizer)
-        optimizer = optimizer_cls(model.parameters(), fused=fabric.device.type == "cuda")
-    else:
-        #optimizer["fused"] = fabric.device.type == "cuda"
-        optimizer = instantiate_class(model.parameters(), optimizer)
-
+    # TODO: add fused=fabric.device.type == "cuda"
+    optimizer = instantiate_torch_optimizer(optimizer, model.parameters())
     optimizer = fabric.setup_optimizers(optimizer)
 
     train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train, model.max_seq_length)
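Note: one way the TODO above could later be addressed, sketched here only and not part of this commit, is to fold the fused flag into the dict form of the config before calling instantiate_torch_optimizer. The helper name with_fused_flag below is hypothetical:

# Hypothetical follow-up for the TODO; not part of this commit or of litgpt.
def with_fused_flag(optimizer, fabric):
    # `optimizer` is either a string such as "AdamW" or a dict with
    # "class_path"/"init_args"; `fabric` is the Lightning Fabric instance.
    use_fused = fabric.device.type == "cuda"
    if isinstance(optimizer, str):
        # Normalize the string form into the dict form so the extra kwarg fits.
        optimizer = {"class_path": f"torch.optim.{optimizer}", "init_args": {}}
    optimizer.setdefault("init_args", {})["fused"] = use_fused
    return optimizer

# optimizer = instantiate_torch_optimizer(with_fused_flag(optimizer, fabric), model.parameters())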
24 changes: 24 additions & 0 deletions litgpt/utils.py
@@ -21,6 +21,7 @@
 from lightning.fabric.strategies import FSDPStrategy
 from lightning.fabric.utilities.load import _lazy_load as lazy_load
 from lightning.pytorch.loggers import WandbLogger
+from lightning.pytorch.cli import instantiate_class
 from torch.serialization import normalize_storage_type
 from typing_extensions import Self
 
@@ -492,3 +493,26 @@ def get_argument_names(cls):
     sig = inspect.signature(cls.__init__)
     return {name for name, param in sig.parameters.items()
             if param.kind in [inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY]}
+
+
+def instantiate_bnb_optimizer(optimizer, model_parameters):
+    if (isinstance(optimizer, str) and "AdamW" not in optimizer) or (isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "")):
+        raise ValueError("The chosen quantization format only supports the AdamW optimizer.")
+
+    import bitsandbytes as bnb
+    if isinstance(optimizer, str):
+        optimizer = bnb.optim.PagedAdamW(model_parameters)
+    else:
+        optim_args = get_argument_names(bnb.optim.PagedAdamW)
+        allowed_kwargs = {key: optimizer["init_args"][key] for key in optim_args & optimizer["init_args"].keys()}
+        optimizer = bnb.optim.PagedAdamW(model_parameters, **allowed_kwargs)
+    return optimizer
+
+
+def instantiate_torch_optimizer(optimizer, model_parameters):
+    if isinstance(optimizer, str):
+        optimizer_cls = getattr(torch.optim, optimizer)
+        optimizer = optimizer_cls(model_parameters)
+    else:
+        optimizer = instantiate_class(model_parameters, optimizer)
+    return optimizer
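Note: the dict branch of instantiate_bnb_optimizer keeps only the init_args that the target optimizer's __init__ accepts, via get_argument_names. The same filtering step can be exercised without bitsandbytes by pointing it at a regular torch optimizer; this is an illustration only, with torch.optim.AdamW standing in for bnb.optim.PagedAdamW:

import torch
import torch.nn as nn

from litgpt.utils import get_argument_names

# Illustrative config; "foo" is deliberately not an AdamW argument.
init_args = {"lr": 3e-4, "weight_decay": 0.02, "foo": 1}

# Same filtering as in instantiate_bnb_optimizer, with torch.optim.AdamW
# standing in for bnb.optim.PagedAdamW so this runs without a GPU.
optim_args = get_argument_names(torch.optim.AdamW)
allowed_kwargs = {key: init_args[key] for key in optim_args & init_args.keys()}

model = nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), **allowed_kwargs)
print(sorted(allowed_kwargs))  # ['lr', 'weight_decay']; "foo" was dropped

The user's init_args are written for the class named in class_path, while PagedAdamW is substituted under bitsandbytes quantization, which is presumably why the helper filters the kwargs rather than forwarding them unchanged.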