Support quantization with adapter v1 and v2 finetuning (#694)
Co-authored-by: Carlos Mocholí <[email protected]>
safurrier and carmocca authored Jan 19, 2024
1 parent dbbbe3a commit 0f021f3
Showing 7 changed files with 544 additions and 84 deletions.
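With this change, finetune/adapter.py and finetune/adapter_v2.py accept a quantize option alongside precision. A minimal sketch of calling the updated entry point directly (paths are placeholders, data_dir is assumed to be a parameter defined above the hunk shown below, and the same values can be passed as CLI flags):

    # Hedged sketch: launch adapter finetuning with bitsandbytes NF4 quantization.
    # Assumes the repository root is on sys.path; all paths are placeholders.
    from pathlib import Path

    from finetune.adapter import setup

    setup(
        data_dir=Path("data/alpaca"),
        checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
        out_dir=Path("out/adapter/alpaca"),
        precision="bf16-true",  # a "-true" precision; mixed precision is rejected when quantizing
        quantize="bnb.nf4",     # any of the bnb.* choices added in this commit
    )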
44 changes: 32 additions & 12 deletions finetune/adapter.py
@@ -4,11 +4,12 @@
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Literal, Optional, Tuple

import lightning as L
import torch
from lightning.fabric.loggers import CSVLogger
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities import ThroughputMonitor

@@ -23,7 +24,7 @@
check_valid_checkpoint_dir,
chunked_cross_entropy,
get_default_supported_precision,
lazy_load,
load_checkpoint,
num_parameters,
)
from scripts.prepare_alpaca import generate_prompt
@@ -56,11 +57,24 @@ def setup(
checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
out_dir: Path = Path("out/adapter/alpaca"),
precision: Optional[str] = None,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
) -> None:
precision = precision or get_default_supported_precision(training=True)

fabric_devices = devices
if fabric_devices > 1:
plugins = None
if quantize is not None and quantize.startswith("bnb."):
if "mixed" in precision:
raise ValueError("Quantization and mixed precision is not supported.")
dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
plugins = BitsandbytesPrecision(quantize[4:], dtype)
precision = None

if devices > 1:
if quantize:
raise NotImplementedError(
"Quantization is currently not supported for multi-GPU training. Please set devices=1 when using the"
" --quantize flag."
)
strategy = FSDPStrategy(
auto_wrap_policy={Block},
activation_checkpointing_policy={Block},
@@ -72,7 +86,7 @@ def setup(
strategy = "auto"

logger = CSVLogger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval)
fabric = L.Fabric(devices=fabric_devices, strategy=strategy, precision=precision, loggers=logger)
fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins)
fabric.print(hparams)
fabric.launch(main, data_dir, checkpoint_dir, out_dir)

@@ -91,20 +105,26 @@ def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path)
config = Config.from_name(name=checkpoint_dir.name)
checkpoint_path = checkpoint_dir / "lit_model.pth"
fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")
with fabric.init_module(empty_init=False):
with fabric.init_module(empty_init=(devices > 1)):
model = GPT(config)
checkpoint = lazy_load(checkpoint_path)
# strict=False because missing keys due to adapter weights not contained in state dict
model.load_state_dict(checkpoint, strict=False)

mark_only_adapter_as_trainable(model)

fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
fabric.print(f"Number of non trainable parameters: {num_parameters(model, requires_grad=False):,}")

model = fabric.setup_module(model)

trainable_params = [p for p in model.parameters() if p.requires_grad]
if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
import bitsandbytes as bnb

optimizer = bnb.optim.PagedAdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
else:
optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
optimizer = fabric.setup_optimizers(optimizer)

optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
model, optimizer = fabric.setup(model, optimizer)
# strict=False because missing keys due to Adapter weights not contained in state dict
load_checkpoint(fabric, model, checkpoint_path, strict=False)

fabric.seed_everything(1337 + fabric.global_rank)

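For reference, the plugin selection added to setup() boils down to the following standalone sketch. It mirrors the hunk above and is not a drop-in replacement; the example values are arbitrary:

    # The "bnb." prefix is stripped before handing the mode to the precision plugin.
    import torch
    from lightning.fabric.plugins import BitsandbytesPrecision

    precision = "bf16-true"
    quantize = "bnb.nf4-dq"

    plugins = None
    if quantize is not None and quantize.startswith("bnb."):
        if "mixed" in precision:
            raise ValueError("Quantization and mixed precision is not supported.")
        dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
        plugins = BitsandbytesPrecision(quantize[4:], dtype)  # e.g. "nf4-dq"
        precision = None  # the plugin now owns precision handling, so Fabric's precision stays unset

Only the *-true precisions map to a dtype here, which is why mixed precision is rejected up front.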
44 changes: 32 additions & 12 deletions finetune/adapter_v2.py
@@ -4,11 +4,12 @@
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Literal, Optional, Tuple

import lightning as L
import torch
from lightning.fabric.loggers import CSVLogger
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities import ThroughputMonitor

@@ -23,7 +24,7 @@
check_valid_checkpoint_dir,
chunked_cross_entropy,
get_default_supported_precision,
lazy_load,
load_checkpoint,
num_parameters,
)
from scripts.prepare_alpaca import generate_prompt
@@ -56,11 +57,24 @@ def setup(
checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
out_dir: Path = Path("out/adapter_v2/alpaca"),
precision: Optional[str] = None,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
) -> None:
precision = precision or get_default_supported_precision(training=True)

fabric_devices = devices
if fabric_devices > 1:
plugins = None
if quantize is not None and quantize.startswith("bnb."):
if "mixed" in precision:
raise ValueError("Quantization and mixed precision is not supported.")
dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
plugins = BitsandbytesPrecision(quantize[4:], dtype)
precision = None

if devices > 1:
if quantize:
raise NotImplementedError(
"Quantization is currently not supported for multi-GPU training. Please set devices=1 when using the"
" --quantize flag."
)
strategy = FSDPStrategy(
auto_wrap_policy={Block},
activation_checkpointing_policy={Block},
@@ -72,7 +86,7 @@ def setup(
strategy = "auto"

logger = CSVLogger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval)
fabric = L.Fabric(devices=fabric_devices, strategy=strategy, precision=precision, loggers=logger)
fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins)
fabric.print(hparams)
fabric.launch(main, data_dir, checkpoint_dir, out_dir)

@@ -91,20 +105,26 @@ def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path)
config = Config.from_name(name=checkpoint_dir.name)
checkpoint_path = checkpoint_dir / "lit_model.pth"
fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")
with fabric.init_module(empty_init=False):
with fabric.init_module(empty_init=(devices > 1)):
model = GPT(config)
checkpoint = lazy_load(checkpoint_path)
# strict=False because missing keys due to adapter weights not contained in state dict
model.load_state_dict(checkpoint, strict=False)

mark_only_adapter_v2_as_trainable(model)

fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
fabric.print(f"Number of non trainable parameters: {num_parameters(model, requires_grad=False):,}")

model = fabric.setup_module(model)

trainable_params = [p for p in model.parameters() if p.requires_grad]
if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
import bitsandbytes as bnb

optimizer = bnb.optim.PagedAdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
else:
optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
optimizer = fabric.setup_optimizers(optimizer)

optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
model, optimizer = fabric.setup(model, optimizer)
# strict=False because missing keys due to Adapter weights not contained in state dict
load_checkpoint(fabric, model, checkpoint_path, strict=False)

fabric.seed_everything(1337 + fabric.global_rank)

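The optimizer selection is identical in both scripts: when the Bitsandbytes precision plugin is active, the paged AdamW from bitsandbytes replaces torch.optim.AdamW. A self-contained sketch of that branch (the configure_optimizer name is illustrative, not part of the commit):

    import lightning as L
    import torch
    from lightning.fabric.plugins import BitsandbytesPrecision


    def configure_optimizer(fabric: L.Fabric, model: torch.nn.Module, learning_rate: float, weight_decay: float):
        # Only adapter parameters remain trainable after mark_only_adapter_as_trainable().
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
            import bitsandbytes as bnb

            # Paged AdamW keeps optimizer state in pageable memory, easing peak GPU memory use.
            optimizer = bnb.optim.PagedAdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
        else:
            optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
        return fabric.setup_optimizers(optimizer)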
113 changes: 110 additions & 3 deletions tests/test_adapter.py
@@ -1,13 +1,14 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

from contextlib import redirect_stdout
from dataclasses import asdict
from io import StringIO
from unittest.mock import Mock

import pytest
import torch
from conftest import RunIf
from lightning import Fabric
from lightning.fabric.wrappers import _FabricOptimizer


def test_config_identical():
@@ -67,8 +68,7 @@ def test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch):
model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0)
monkeypatch.setitem(name_to_config, "tmp", model_config)

monkeypatch.setattr(module, "lazy_load", Mock())
monkeypatch.setattr(module.GPT, "load_state_dict", Mock())
monkeypatch.setattr(module, "load_checkpoint", Mock())

tokenizer_mock = Mock()
tokenizer_mock.return_value = tokenizer_mock
@@ -129,3 +129,110 @@ def test_adapter_compile():
assert isinstance(explanation, debugging.ExplainOutput)
assert explanation.graph_count == 1
assert explanation.graph_break_count == 0


@RunIf(min_cuda_gpus=1)
# platform dependent cuda issue: libbitsandbytes_cpu.so: undefined symbol: cquantize_blockwise_fp16_nf4
@pytest.mark.xfail(raises=AttributeError, strict=False)
def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir):
from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision

if not _BITSANDBYTES_AVAILABLE:
pytest.skip("BNB not available")

from bitsandbytes.optim import PagedAdamW

import finetune.adapter as module

data = []
torch.save(data, tmp_path / "train.pt")
torch.save(data, tmp_path / "test.pt")

from lit_gpt.config import name_to_config

model_config = dict(
block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0, bias=True
)
monkeypatch.setitem(name_to_config, "tmp", model_config)

monkeypatch.setattr(module, "load_checkpoint", Mock())
train_mock = Mock()
monkeypatch.setattr(module, "train", train_mock)

stdout = StringIO()
with redirect_stdout(stdout):
module.setup(
data_dir=tmp_path,
checkpoint_dir=fake_checkpoint_dir,
out_dir=tmp_path,
precision="16-true",
quantize="bnb.nf4-dq",
)

args, kwargs = train_mock.call_args
fabric, model, optimizer, *_ = args
assert isinstance(fabric.strategy.precision, BitsandbytesPrecision)
assert isinstance(optimizer, _FabricOptimizer)
assert isinstance(optimizer._optimizer, PagedAdamW)

dtype_to_name = {"torch.uint8": set(), "torch.float16": set()}
for name, layer in model.named_parameters():
name = name[len("_forward_module.") :]
dtype_to_name[str(layer.dtype)].add(name)
assert dtype_to_name == {
"torch.float16": {
"transformer.wte.weight",
"transformer.h.0.norm_1.weight",
"transformer.h.0.norm_1.bias",
"transformer.h.0.attn.gating_factor",
"transformer.h.0.attn.attn.bias",
"transformer.h.0.attn.proj.bias",
"transformer.h.0.attn.adapter_wte.weight",
"transformer.h.0.norm_2.weight",
"transformer.h.0.norm_2.bias",
"transformer.h.0.mlp.fc.bias",
"transformer.h.0.mlp.proj.bias",
"transformer.h.1.norm_1.weight",
"transformer.h.1.norm_1.bias",
"transformer.h.1.attn.gating_factor",
"transformer.h.1.attn.attn.bias",
"transformer.h.1.attn.proj.bias",
"transformer.h.1.attn.adapter_wte.weight",
"transformer.h.1.norm_2.weight",
"transformer.h.1.norm_2.bias",
"transformer.h.1.mlp.fc.bias",
"transformer.h.1.mlp.proj.bias",
"transformer.ln_f.weight",
"transformer.ln_f.bias",
},
"torch.uint8": {
"lm_head.weight",
"transformer.h.0.attn.attn.weight",
"transformer.h.0.attn.proj.weight",
"transformer.h.0.mlp.fc.weight",
"transformer.h.0.mlp.proj.weight",
"transformer.h.1.attn.attn.weight",
"transformer.h.1.attn.proj.weight",
"transformer.h.1.mlp.fc.weight",
"transformer.h.1.mlp.proj.weight",
},
}

assert {p.name for p in tmp_path.glob("*.pth")} == {"lit_model_adapter_finetuned.pth"}
state_dict = torch.load(tmp_path / "lit_model_adapter_finetuned.pth")
assert len(state_dict) == 1
dtype_to_name = {"torch.float16": set()}
for name, layer in state_dict["model"].items():
dtype_to_name[str(layer.dtype)].add(name)
assert dtype_to_name == {
"torch.float16": {
"transformer.h.0.attn.adapter_wte.weight",
"transformer.h.0.attn.gating_factor",
"transformer.h.1.attn.adapter_wte.weight",
"transformer.h.1.attn.gating_factor",
}
}

logs = stdout.getvalue()
assert "of trainable parameters: 168" in logs
assert "of non trainable parameters: 1,888" in logs