From 0f021f3ad8cd8d6fe30b0ef721a7a0e2dac15898 Mon Sep 17 00:00:00 2001 From: Alex Furrier Date: Thu, 18 Jan 2024 18:33:08 -0700 Subject: [PATCH] Support quantization with adapter v1 and v2 finetuning (#694) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- finetune/adapter.py | 44 +++++++--- finetune/adapter_v2.py | 44 +++++++--- tests/test_adapter.py | 113 +++++++++++++++++++++++- tests/test_adapter_v2.py | 157 +++++++++++++++++++++++++++++++++- tests/test_lora.py | 121 +++++++++++++++++++++++++- tutorials/finetune_adapter.md | 16 ++++ tutorials/resource-tables.md | 133 ++++++++++++++++------------ 7 files changed, 544 insertions(+), 84 deletions(-) diff --git a/finetune/adapter.py b/finetune/adapter.py index ea16344fca..cab6d78408 100644 --- a/finetune/adapter.py +++ b/finetune/adapter.py @@ -4,11 +4,12 @@ import sys import time from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Literal, Optional, Tuple import lightning as L import torch from lightning.fabric.loggers import CSVLogger +from lightning.fabric.plugins import BitsandbytesPrecision from lightning.fabric.strategies import FSDPStrategy from lightning.fabric.utilities import ThroughputMonitor @@ -23,7 +24,7 @@ check_valid_checkpoint_dir, chunked_cross_entropy, get_default_supported_precision, - lazy_load, + load_checkpoint, num_parameters, ) from scripts.prepare_alpaca import generate_prompt @@ -56,11 +57,24 @@ def setup( checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), out_dir: Path = Path("out/adapter/alpaca"), precision: Optional[str] = None, + quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None, ) -> None: precision = precision or get_default_supported_precision(training=True) - fabric_devices = devices - if fabric_devices > 1: + plugins = None + if quantize is not None and quantize.startswith("bnb."): + if "mixed" in precision: + raise ValueError("Quantization and mixed precision is not supported.") + dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + + if devices > 1: + if quantize: + raise NotImplementedError( + "Quantization is currently not supported for multi-GPU training. Please set devices=1 when using the" + " --quantize flag." + ) strategy = FSDPStrategy( auto_wrap_policy={Block}, activation_checkpointing_policy={Block}, @@ -72,7 +86,7 @@ def setup( strategy = "auto" logger = CSVLogger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval) - fabric = L.Fabric(devices=fabric_devices, strategy=strategy, precision=precision, loggers=logger) + fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins) fabric.print(hparams) fabric.launch(main, data_dir, checkpoint_dir, out_dir) @@ -91,20 +105,26 @@ def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path) config = Config.from_name(name=checkpoint_dir.name) checkpoint_path = checkpoint_dir / "lit_model.pth" fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}") - with fabric.init_module(empty_init=False): + with fabric.init_module(empty_init=(devices > 1)): model = GPT(config) - checkpoint = lazy_load(checkpoint_path) - # strict=False because missing keys due to adapter weights not contained in state dict - model.load_state_dict(checkpoint, strict=False) - mark_only_adapter_as_trainable(model) fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}") fabric.print(f"Number of non trainable parameters: {num_parameters(model, requires_grad=False):,}") + + model = fabric.setup_module(model) + trainable_params = [p for p in model.parameters() if p.requires_grad] + if isinstance(fabric.strategy.precision, BitsandbytesPrecision): + import bitsandbytes as bnb + + optimizer = bnb.optim.PagedAdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay) + else: + optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay) + optimizer = fabric.setup_optimizers(optimizer) - optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay) - model, optimizer = fabric.setup(model, optimizer) + # strict=False because missing keys due to Adapter weights not contained in state dict + load_checkpoint(fabric, model, checkpoint_path, strict=False) fabric.seed_everything(1337 + fabric.global_rank) diff --git a/finetune/adapter_v2.py b/finetune/adapter_v2.py index 3095210ca9..89d16f790c 100644 --- a/finetune/adapter_v2.py +++ b/finetune/adapter_v2.py @@ -4,11 +4,12 @@ import sys import time from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Literal, Optional, Tuple import lightning as L import torch from lightning.fabric.loggers import CSVLogger +from lightning.fabric.plugins import BitsandbytesPrecision from lightning.fabric.strategies import FSDPStrategy from lightning.fabric.utilities import ThroughputMonitor @@ -23,7 +24,7 @@ check_valid_checkpoint_dir, chunked_cross_entropy, get_default_supported_precision, - lazy_load, + load_checkpoint, num_parameters, ) from scripts.prepare_alpaca import generate_prompt @@ -56,11 +57,24 @@ def setup( checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), out_dir: Path = Path("out/adapter_v2/alpaca"), precision: Optional[str] = None, + quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None, ) -> None: precision = precision or get_default_supported_precision(training=True) - fabric_devices = devices - if fabric_devices > 1: + plugins = None + if quantize is not None and quantize.startswith("bnb."): + if "mixed" in precision: + raise ValueError("Quantization and mixed precision is not supported.") + dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + + if devices > 1: + if quantize: + raise NotImplementedError( + "Quantization is currently not supported for multi-GPU training. Please set devices=1 when using the" + " --quantize flag." + ) strategy = FSDPStrategy( auto_wrap_policy={Block}, activation_checkpointing_policy={Block}, @@ -72,7 +86,7 @@ def setup( strategy = "auto" logger = CSVLogger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval) - fabric = L.Fabric(devices=fabric_devices, strategy=strategy, precision=precision, loggers=logger) + fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins) fabric.print(hparams) fabric.launch(main, data_dir, checkpoint_dir, out_dir) @@ -91,20 +105,26 @@ def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path) config = Config.from_name(name=checkpoint_dir.name) checkpoint_path = checkpoint_dir / "lit_model.pth" fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}") - with fabric.init_module(empty_init=False): + with fabric.init_module(empty_init=(devices > 1)): model = GPT(config) - checkpoint = lazy_load(checkpoint_path) - # strict=False because missing keys due to adapter weights not contained in state dict - model.load_state_dict(checkpoint, strict=False) - mark_only_adapter_v2_as_trainable(model) fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}") fabric.print(f"Number of non trainable parameters: {num_parameters(model, requires_grad=False):,}") + + model = fabric.setup_module(model) + trainable_params = [p for p in model.parameters() if p.requires_grad] + if isinstance(fabric.strategy.precision, BitsandbytesPrecision): + import bitsandbytes as bnb + + optimizer = bnb.optim.PagedAdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay) + else: + optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay) + optimizer = fabric.setup_optimizers(optimizer) - optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay) - model, optimizer = fabric.setup(model, optimizer) + # strict=False because missing keys due to Adapter weights not contained in state dict + load_checkpoint(fabric, model, checkpoint_path, strict=False) fabric.seed_everything(1337 + fabric.global_rank) diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 64da985741..c182c5d0de 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -1,13 +1,14 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. - from contextlib import redirect_stdout from dataclasses import asdict from io import StringIO from unittest.mock import Mock +import pytest import torch from conftest import RunIf from lightning import Fabric +from lightning.fabric.wrappers import _FabricOptimizer def test_config_identical(): @@ -67,8 +68,7 @@ def test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch): model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0) monkeypatch.setitem(name_to_config, "tmp", model_config) - monkeypatch.setattr(module, "lazy_load", Mock()) - monkeypatch.setattr(module.GPT, "load_state_dict", Mock()) + monkeypatch.setattr(module, "load_checkpoint", Mock()) tokenizer_mock = Mock() tokenizer_mock.return_value = tokenizer_mock @@ -129,3 +129,110 @@ def test_adapter_compile(): assert isinstance(explanation, debugging.ExplainOutput) assert explanation.graph_count == 1 assert explanation.graph_break_count == 0 + + +@RunIf(min_cuda_gpus=1) +# platform dependent cuda issue: libbitsandbytes_cpu.so: undefined symbol: cquantize_blockwise_fp16_nf4 +@pytest.mark.xfail(raises=AttributeError, strict=False) +def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir): + from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision + + if not _BITSANDBYTES_AVAILABLE: + pytest.skip("BNB not available") + + from bitsandbytes.optim import PagedAdamW + + import finetune.adapter as module + + data = [] + torch.save(data, tmp_path / "train.pt") + torch.save(data, tmp_path / "test.pt") + + from lit_gpt.config import name_to_config + + model_config = dict( + block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0, bias=True + ) + monkeypatch.setitem(name_to_config, "tmp", model_config) + + monkeypatch.setattr(module, "load_checkpoint", Mock()) + train_mock = Mock() + monkeypatch.setattr(module, "train", train_mock) + + stdout = StringIO() + with redirect_stdout(stdout): + module.setup( + data_dir=tmp_path, + checkpoint_dir=fake_checkpoint_dir, + out_dir=tmp_path, + precision="16-true", + quantize="bnb.nf4-dq", + ) + + args, kwargs = train_mock.call_args + fabric, model, optimizer, *_ = args + assert isinstance(fabric.strategy.precision, BitsandbytesPrecision) + assert isinstance(optimizer, _FabricOptimizer) + assert isinstance(optimizer._optimizer, PagedAdamW) + + dtype_to_name = {"torch.uint8": set(), "torch.float16": set()} + for name, layer in model.named_parameters(): + name = name[len("_forward_module.") :] + dtype_to_name[str(layer.dtype)].add(name) + assert dtype_to_name == { + "torch.float16": { + "transformer.wte.weight", + "transformer.h.0.norm_1.weight", + "transformer.h.0.norm_1.bias", + "transformer.h.0.attn.gating_factor", + "transformer.h.0.attn.attn.bias", + "transformer.h.0.attn.proj.bias", + "transformer.h.0.attn.adapter_wte.weight", + "transformer.h.0.norm_2.weight", + "transformer.h.0.norm_2.bias", + "transformer.h.0.mlp.fc.bias", + "transformer.h.0.mlp.proj.bias", + "transformer.h.1.norm_1.weight", + "transformer.h.1.norm_1.bias", + "transformer.h.1.attn.gating_factor", + "transformer.h.1.attn.attn.bias", + "transformer.h.1.attn.proj.bias", + "transformer.h.1.attn.adapter_wte.weight", + "transformer.h.1.norm_2.weight", + "transformer.h.1.norm_2.bias", + "transformer.h.1.mlp.fc.bias", + "transformer.h.1.mlp.proj.bias", + "transformer.ln_f.weight", + "transformer.ln_f.bias", + }, + "torch.uint8": { + "lm_head.weight", + "transformer.h.0.attn.attn.weight", + "transformer.h.0.attn.proj.weight", + "transformer.h.0.mlp.fc.weight", + "transformer.h.0.mlp.proj.weight", + "transformer.h.1.attn.attn.weight", + "transformer.h.1.attn.proj.weight", + "transformer.h.1.mlp.fc.weight", + "transformer.h.1.mlp.proj.weight", + }, + } + + assert {p.name for p in tmp_path.glob("*.pth")} == {"lit_model_adapter_finetuned.pth"} + state_dict = torch.load(tmp_path / "lit_model_adapter_finetuned.pth") + assert len(state_dict) == 1 + dtype_to_name = {"torch.float16": set()} + for name, layer in state_dict["model"].items(): + dtype_to_name[str(layer.dtype)].add(name) + assert dtype_to_name == { + "torch.float16": { + "transformer.h.0.attn.adapter_wte.weight", + "transformer.h.0.attn.gating_factor", + "transformer.h.1.attn.adapter_wte.weight", + "transformer.h.1.attn.gating_factor", + } + } + + logs = stdout.getvalue() + assert "of trainable parameters: 168" in logs + assert "of non trainable parameters: 1,888" in logs diff --git a/tests/test_adapter_v2.py b/tests/test_adapter_v2.py index 214040ba9f..6c961927ca 100644 --- a/tests/test_adapter_v2.py +++ b/tests/test_adapter_v2.py @@ -10,6 +10,7 @@ import torch from conftest import RunIf from lightning import Fabric +from lightning.fabric.wrappers import _FabricOptimizer # support running without installing as a package wd = Path(__file__).parent.parent.resolve() @@ -91,8 +92,7 @@ def test_adapter_v2_script(tmp_path, fake_checkpoint_dir, monkeypatch): model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0) monkeypatch.setitem(name_to_config, "tmp", model_config) - monkeypatch.setattr(module, "lazy_load", Mock()) - monkeypatch.setattr(module.GPT, "load_state_dict", Mock()) + monkeypatch.setattr(module, "load_checkpoint", Mock()) tokenizer_mock = Mock() tokenizer_mock.return_value = tokenizer_mock @@ -219,3 +219,156 @@ def test_against_hf_mixtral(): ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) + + +@RunIf(min_cuda_gpus=1) +# platform dependent cuda issue: libbitsandbytes_cpu.so: undefined symbol: cquantize_blockwise_fp16_nf4 +@pytest.mark.xfail(raises=AttributeError, strict=False) +def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir): + from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision + + if not _BITSANDBYTES_AVAILABLE: + pytest.skip("BNB not available") + + from bitsandbytes.optim import PagedAdamW + + import finetune.adapter_v2 as module + + data = [] + torch.save(data, tmp_path / "train.pt") + torch.save(data, tmp_path / "test.pt") + + from lit_gpt.config import name_to_config + + model_config = dict( + block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0, bias=True + ) + monkeypatch.setitem(name_to_config, "tmp", model_config) + + monkeypatch.setattr(module, "load_checkpoint", Mock()) + train_mock = Mock() + monkeypatch.setattr(module, "train", train_mock) + + stdout = StringIO() + with redirect_stdout(stdout): + module.setup( + data_dir=tmp_path, + checkpoint_dir=fake_checkpoint_dir, + out_dir=tmp_path, + precision="16-true", + quantize="bnb.nf4-dq", + ) + + args, kwargs = train_mock.call_args + fabric, model, optimizer, *_ = args + assert isinstance(fabric.strategy.precision, BitsandbytesPrecision) + assert isinstance(optimizer, _FabricOptimizer) + assert isinstance(optimizer._optimizer, PagedAdamW) + + dtype_to_name = {"torch.uint8": set(), "torch.float16": set()} + for name, layer in model.named_parameters(): + name = name[len("_forward_module.") :] + dtype_to_name[str(layer.dtype)].add(name) + assert dtype_to_name == { + "torch.uint8": { + "transformer.h.0.mlp.fc.linear.weight", + "transformer.h.1.mlp.proj.linear.weight", + "transformer.h.1.attn.attn.linear.weight", + "transformer.h.0.attn.proj.linear.weight", + "lm_head.linear.weight", + "transformer.h.1.attn.proj.linear.weight", + "transformer.h.0.mlp.proj.linear.weight", + "transformer.h.0.attn.attn.linear.weight", + "transformer.h.1.mlp.fc.linear.weight", + }, + "torch.float16": { + "transformer.h.1.attn.attn.adapter_bias", + "transformer.h.1.mlp.proj.adapter_bias", + "transformer.h.0.attn.attn.adapter_bias", + "transformer.h.0.norm_1.bias", + "transformer.h.0.attn.attn.linear.bias", + "transformer.h.1.attn.adapter_wte.weight", + "transformer.ln_f.weight", + "transformer.h.0.mlp.fc.linear.bias", + "transformer.h.0.mlp.proj.linear.bias", + "transformer.h.1.mlp.fc.linear.bias", + "transformer.h.0.attn.proj.adapter_scale", + "transformer.h.0.attn.attn.adapter_scale", + "transformer.h.1.norm_2.bias", + "transformer.h.1.attn.proj.adapter_scale", + "transformer.h.0.norm_2.bias", + "transformer.h.0.mlp.fc.adapter_scale", + "transformer.h.0.attn.proj.linear.bias", + "transformer.h.1.attn.proj.linear.bias", + "transformer.h.1.norm_1.bias", + "transformer.h.0.norm_1.weight", + "transformer.h.1.attn.proj.adapter_bias", + "transformer.h.0.mlp.proj.adapter_scale", + "transformer.h.0.mlp.proj.adapter_bias", + "transformer.h.1.mlp.fc.adapter_bias", + "transformer.h.1.mlp.proj.adapter_scale", + "transformer.h.1.attn.gating_factor", + "transformer.h.1.norm_1.weight", + "transformer.ln_f.bias", + "transformer.h.0.mlp.fc.adapter_bias", + "lm_head.adapter_scale", + "lm_head.adapter_bias", + "transformer.h.1.norm_2.weight", + "transformer.h.0.attn.adapter_wte.weight", + "transformer.h.1.attn.attn.adapter_scale", + "transformer.h.1.mlp.fc.adapter_scale", + "transformer.h.1.attn.attn.linear.bias", + "transformer.wte.weight", + "transformer.h.0.norm_2.weight", + "transformer.h.1.mlp.proj.linear.bias", + "transformer.h.0.attn.gating_factor", + "transformer.h.0.attn.proj.adapter_bias", + }, + } + + assert {p.name for p in tmp_path.glob("*.pth")} == {"lit_model_adapter_finetuned.pth"} + state_dict = torch.load(tmp_path / "lit_model_adapter_finetuned.pth") + assert len(state_dict) == 1 + dtype_to_name = {"torch.float16": set()} + for name, layer in state_dict["model"].items(): + dtype_to_name[str(layer.dtype)].add(name) + assert dtype_to_name == { + "torch.float16": { + "transformer.h.1.attn.adapter_wte.weight", + "transformer.h.1.attn.proj.adapter_bias", + "transformer.h.1.mlp.fc.adapter_scale", + "lm_head.adapter_bias", + "transformer.h.0.mlp.proj.adapter_scale", + "transformer.ln_f.bias", + "lm_head.adapter_scale", + "transformer.h.1.norm_2.weight", + "transformer.h.0.attn.attn.adapter_scale", + "transformer.h.0.mlp.proj.adapter_bias", + "transformer.h.0.attn.gating_factor", + "transformer.h.1.norm_1.bias", + "transformer.h.1.mlp.fc.adapter_bias", + "transformer.h.1.mlp.proj.adapter_scale", + "transformer.h.0.mlp.fc.adapter_scale", + "transformer.h.1.attn.attn.adapter_bias", + "transformer.h.0.norm_2.weight", + "transformer.h.1.norm_2.bias", + "transformer.h.0.norm_1.weight", + "transformer.h.0.attn.proj.adapter_scale", + "transformer.h.1.mlp.proj.adapter_bias", + "transformer.h.0.attn.attn.adapter_bias", + "transformer.h.0.attn.adapter_wte.weight", + "transformer.ln_f.weight", + "transformer.h.1.attn.gating_factor", + "transformer.h.0.mlp.fc.adapter_bias", + "transformer.h.1.attn.proj.adapter_scale", + "transformer.h.0.attn.proj.adapter_bias", + "transformer.h.0.norm_1.bias", + "transformer.h.0.norm_2.bias", + "transformer.h.1.norm_1.weight", + "transformer.h.1.attn.attn.adapter_scale", + } + } + + logs = stdout.getvalue() + assert "of trainable parameters: 552" in logs + assert "of non trainable parameters: 1,808" in logs diff --git a/tests/test_lora.py b/tests/test_lora.py index 1d9364169b..da9967a52b 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -11,6 +11,7 @@ import torch from conftest import RunIf from lightning import Fabric +from lightning.fabric.wrappers import _FabricOptimizer # support running without installing as a package wd = Path(__file__).parent.parent.resolve() @@ -362,7 +363,7 @@ def test_lora_qkv_linear_weights_merged_status(rank, enable_lora, expected_merge @RunIf(min_cuda_gpus=1) # platform dependent cuda issue: libbitsandbytes_cpu.so: undefined symbol: cquantize_blockwise_fp16_nf4 @pytest.mark.xfail(raises=AttributeError, strict=False) -def test_lora_merge_with_quantize(): +def test_lora_merge_with_bitsandbytes(): from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision if not _BITSANDBYTES_AVAILABLE: @@ -548,3 +549,121 @@ def test_against_hf_mixtral(): ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) + + +@RunIf(min_cuda_gpus=1) +# platform dependent cuda issue: libbitsandbytes_cpu.so: undefined symbol: cquantize_blockwise_fp16_nf4 +@pytest.mark.xfail(raises=AttributeError, strict=False) +def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir): + from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision + + if not _BITSANDBYTES_AVAILABLE: + pytest.skip("BNB not available") + + from bitsandbytes.optim import PagedAdamW + + import finetune.lora as module + + data = [] + torch.save(data, tmp_path / "train.pt") + torch.save(data, tmp_path / "test.pt") + + from lit_gpt.config import name_to_config + + model_config = dict( + block_size=128, + n_layer=2, + n_embd=8, + n_head=4, + padded_vocab_size=8, + bias=True, + r=8, + alpha=8, + dropout=0.1, + to_query=True, + to_value=True, + to_projection=True, + ) + monkeypatch.setitem(name_to_config, "tmp", model_config) + + monkeypatch.setattr(module, "load_checkpoint", Mock()) + train_mock = Mock() + monkeypatch.setattr(module, "train", train_mock) + + stdout = StringIO() + with redirect_stdout(stdout): + module.setup( + data_dir=tmp_path, + checkpoint_dir=fake_checkpoint_dir, + out_dir=tmp_path, + precision="16-true", + quantize="bnb.nf4-dq", + ) + + args, kwargs = train_mock.call_args + fabric, model, optimizer, *_ = args + assert isinstance(fabric.strategy.precision, BitsandbytesPrecision) + assert isinstance(optimizer, _FabricOptimizer) + assert isinstance(optimizer._optimizer, PagedAdamW) + + dtype_to_name = {"torch.uint8": set(), "torch.float16": set()} + for name, layer in model.named_parameters(): + name = name[len("_forward_module.") :] + dtype_to_name[str(layer.dtype)].add(name) + assert dtype_to_name == { + "torch.uint8": { + "transformer.h.0.attn.attn.linear.weight", + "transformer.h.0.attn.proj.linear.weight", + "transformer.h.0.mlp.fc.linear.weight", + "transformer.h.1.mlp.proj.linear.weight", + "transformer.h.0.mlp.proj.linear.weight", + "transformer.h.1.attn.attn.linear.weight", + "lm_head.linear.weight", + "transformer.h.1.attn.proj.linear.weight", + "transformer.h.1.mlp.fc.linear.weight", + }, + "torch.float16": { + "transformer.h.0.attn.attn.lora_B", + "transformer.h.0.norm_2.weight", + "transformer.wte.weight", + "transformer.h.1.mlp.fc.linear.bias", + "transformer.ln_f.bias", + "transformer.h.1.attn.attn.lora_B", + "transformer.h.1.attn.proj.linear.bias", + "transformer.h.1.norm_1.weight", + "transformer.h.1.attn.attn.linear.bias", + "transformer.h.1.attn.attn.lora_A", + "transformer.h.1.norm_1.bias", + "transformer.h.1.norm_2.bias", + "transformer.h.0.attn.proj.linear.bias", + "transformer.h.0.norm_1.bias", + "transformer.h.0.mlp.proj.linear.bias", + "transformer.h.0.mlp.fc.linear.bias", + "transformer.h.0.norm_2.bias", + "transformer.ln_f.weight", + "transformer.h.0.attn.attn.lora_A", + "transformer.h.1.norm_2.weight", + "transformer.h.1.mlp.proj.linear.bias", + "transformer.h.0.norm_1.weight", + "transformer.h.0.attn.attn.linear.bias", + }, + } + + assert {p.name for p in tmp_path.glob("*.pth")} == {"lit_model_lora_finetuned.pth"} + state_dict = torch.load(tmp_path / "lit_model_lora_finetuned.pth") + assert len(state_dict) == 1 + dtype_to_name = {"torch.float16": set()} + for name, layer in state_dict["model"].items(): + dtype_to_name[str(layer.dtype)].add(name) + assert dtype_to_name == { + "torch.float16": { + "transformer.h.1.attn.attn.lora_A", + "transformer.h.0.attn.attn.lora_A", + "transformer.h.0.attn.attn.lora_B", + "transformer.h.1.attn.attn.lora_B", + } + } + + logs = stdout.getvalue() + assert "of trainable parameters: 512" in logs + assert "of non trainable parameters: 1,888" in logs diff --git a/tutorials/finetune_adapter.md b/tutorials/finetune_adapter.md index 6b3e120cf4..916e25af8b 100644 --- a/tutorials/finetune_adapter.md +++ b/tutorials/finetune_adapter.md @@ -69,6 +69,22 @@ python finetune/adapter.py --out_dir out/adapter/my-model-finetuned --precision Note that `mps` as the accelerator will be picked up automatically by Fabric when running on a modern Mac. +### Quantization + +Optionally, finetuning using quantization can be enabled via the `--quantize` flag, for example using the 4-bit NormalFloat data type: + +```bash +python finetune/adapter.py --quantize "bnb.nf4" +``` + +or using adapter_v2 with double-quantization: + +```bash +python finetune/adapter_v2.py --quantize "bnb.nf4-dq" +``` + +For additional benchmarks and resource requirements, please see the [Resource Tables](resource-tables.md). + ## Test the model You can test the finetuned model with your own instructions by running: diff --git a/tutorials/resource-tables.md b/tutorials/resource-tables.md index c9d55ce6b8..e9cde2a109 100644 --- a/tutorials/resource-tables.md +++ b/tutorials/resource-tables.md @@ -6,10 +6,7 @@ - OS: Ubuntu 22.04.3 LTS (x86_64) - Nvidia driver version: 525.125.06 - Relevant libraries - - CMake 3.26.4 - - Libc glibc-2.35 - PyTorch 2.1.0+cu121 - - Lightning 2.1.0.rc0 - Bitsandbytes 0.41.1 This document provides an overview and examples of hardware requirements when running models in Lit-GPT. @@ -39,35 +36,63 @@ Note that the number of tokens in the training set does not affect the supported The following experiments were conducted on 1xA100 with a minibatch size of 128 using the `finetune/lora.py` script. -| Size | Model | Quantization | Microbatch size | Trainable parameters | Max GPU RAM | Time 1k iterations | Time 50k iter (extrapolated) | -|-------|----------------|--------------|-----------------|----------------------|-------------|--------------------|------------------------------| -| 1.3 B | phi-1.5 | None | 1 | 1,572,864 | 4.82 GB | 1.62 min | 80.91 min | -| 1.3 B | phi-1.5 | bnb.nf4 | 1 | 1,572,864 | 3.78 GB | 1.77 min | 88.36 min | -| 1.3 B | phi-1.5 | bnb.nf4-dq | 1 | 1,572,864 | 3.72 GB | 1.87 min | 93.39 min | -| 1.3 B | phi-1.5 | None | 2 | 1,572,864 | 6.76 GB | 1.65 min | 82.44 min | -| 1.3 B | phi-1.5 | None | 4 | 1,572,864 | 10.68 GB | 1.70 min | 84.79 min | -| | | | | | | | | -| 3 B | StableLM Alpha | None | 1 | 2,097,152 | 9.69 GB | 1.24 min | 62.23 min | -| 3 B | StableLM Alpha | bnb.nf4 | 1 | 2,097,152 | 6.35 GB | 1.82 min | 91.22 min | -| 3 B | StableLM Alpha | bnb.nf4-dq | 1 | 2,097,152 | 6.19 GB | 1.87 min | 93.58 min | -| 3 B | StableLM Alpha | None | 2 | 2,097,152 | 12.10 GB | 1.33 min | 66.68 min | -| 3 B | StableLM Alpha | None | 4 | 2,097,152 | 16.92 GB | 1.50 min | 74.89 min | -| | | | | | | | | -| 7 B | Llama 2 | None | 1 | 4,194,304 | 21.30 GB | 2.36 min | 118.03 min | -| 7 B | Llama 2 | bnb.nf4 | 1 | 4,194,304 | 14.14 GB | 3.68 min | 183.88 min | -| 7 B | Llama 2 | bnb.nf4-dq | 1 | 4,194,304 | 13.84 GB | 3.83 min | 191.66 min | -| 7 B | Llama 2 | None | 2 | 4,194,304 | 29.07 GB | 2.52 min | 125.97 min | -| 7 B | Llama 2 | None | 4 | 4,194,304 | OOM | - | - | -| | | | | | | | | -| 13 B | Llama 2 | None | 1 | 6,553,600 | 38.12 GB | 3.19 min | 159.43 min | -| 13 B | Llama 2 | bnb.nf4 | 1 | 6,553,600 | 23.14 GB | 6.38 min | 319.03 min | -| 13 B | Llama 2 | bnb.nf4-dq | 1 | 6,553,600 | 22.55 GB | 6.55 min | 327.32 min | -| 13 B | Llama 2 | None | 2 | 6,553,600 | OOM | - | - | -| 13 B | Llama 2 | None | 4 | 6,553,600 | OOM | - | - | -| | | | | | | | | -| 40 B | Falcon | None | 1 | 12,042,240 | OOM | - | - | -| 40 B | Falcon | bnb.nf4 | 1 | 12,042,240 | OOM | - | - | -| 40 B | Falcon | bnb.nf4-dq | 1 | 12,042,240 | OOM | - | - | +| Size | Model | Quantization | Microbatch size | Trainable parameters | Max GPU RAM | Time 1k iterations | +|-------|----------------|--------------|-----------------|----------------------|-------------|--------------------| +| 1.3 B | phi-1.5 | None | 1 | 1,572,864 | 4.82 GB | 1.62 min | +| 1.3 B | phi-1.5 | bnb.nf4 | 1 | 1,572,864 | 3.78 GB | 1.77 min | +| 1.3 B | phi-1.5 | bnb.nf4-dq | 1 | 1,572,864 | 3.72 GB | 1.87 min | +| 1.3 B | phi-1.5 | None | 2 | 1,572,864 | 6.76 GB | 1.65 min | +| 1.3 B | phi-1.5 | None | 4 | 1,572,864 | 10.68 GB | 1.70 min | +| | | | | | | | +| 3 B | StableLM Alpha | None | 1 | 2,097,152 | 9.69 GB | 1.24 min | +| 3 B | StableLM Alpha | bnb.nf4 | 1 | 2,097,152 | 6.35 GB | 1.82 min | +| 3 B | StableLM Alpha | bnb.nf4-dq | 1 | 2,097,152 | 6.19 GB | 1.87 min | +| 3 B | StableLM Alpha | None | 2 | 2,097,152 | 12.10 GB | 1.33 min | +| 3 B | StableLM Alpha | None | 4 | 2,097,152 | 16.92 GB | 1.50 min | +| | | | | | | | +| 7 B | Llama 2 | None | 1 | 4,194,304 | 21.30 GB | 2.36 min | +| 7 B | Llama 2 | bnb.nf4 | 1 | 4,194,304 | 14.14 GB | 3.68 min | +| 7 B | Llama 2 | bnb.nf4-dq | 1 | 4,194,304 | 13.84 GB | 3.83 min | +| 7 B | Llama 2 | None | 2 | 4,194,304 | 29.07 GB | 2.52 min | +| 7 B | Llama 2 | None | 4 | 4,194,304 | OOM | - | +| | | | | | | | +| 13 B | Llama 2 | None | 1 | 6,553,600 | 38.12 GB | 3.19 min | +| 13 B | Llama 2 | bnb.nf4 | 1 | 6,553,600 | 23.14 GB | 6.38 min | +| 13 B | Llama 2 | bnb.nf4-dq | 1 | 6,553,600 | 22.55 GB | 6.55 min | +| 13 B | Llama 2 | None | 2 | 6,553,600 | OOM | - | +| 13 B | Llama 2 | None | 4 | 6,553,600 | OOM | - | +| | | | | | | | +| 40 B | Falcon | None | 1 | 12,042,240 | OOM | - | +| 40 B | Falcon | bnb.nf4 | 1 | 12,042,240 | OOM | - | +| 40 B | Falcon | bnb.nf4-dq | 1 | 12,042,240 | OOM | - | + +  + +## Finetuning with Adapter on 1 GPU + +The following experiments were conducted on 1xA100 with a minibatch size of 128 using the `finetune/adapter.py` script. + +| Size | Model | Quantization | Microbatch size | Trainable parameters | Max GPU RAM | Time 1k iterations | +|------|----------------|--------------|-----------------|----------------------|-------------|--------------------| +| 3 B | StableLM Alpha | None | 1 | 573,888 | 9.10 GB | 0.74 min | +| 3 B | StableLM Alpha | bnb.nf4 | 1 | 573,888 | 5.65 GB | 1.38 min | +| 3 B | StableLM Alpha | bnb.nf4-dq | 1 | 573,888 | 5.48 GB | 1.46 min | +| | | | | | | | +| 7 B | Llama 2 | None | 1 | 1,229,760 | 19.98 GB | 1.50 min | +| 7 B | Llama 2 | bnb.nf4 | 1 | 1,229,760 | 12.68 GB | 2.93 min | +| 7 B | Llama 2 | bnb.nf4-dq | 1 | 1,229,760 | 12.38 GB | 3.00 min | + +The same config, but using the `finetune/adapter_v2.py` script. + +| Size | Model | Quantization | Microbatch size | Trainable parameters | Max GPU RAM | Time 1k iterations | +|------|----------------|--------------|-----------------|----------------------|-------------|--------------------| +| 3 B | StableLM Alpha | None | 1 | 2,125,248 | 10.71 GB | 0.87 min | +| 3 B | StableLM Alpha | bnb.nf4 | 1 | 2,125,248 | 7.41 GB | 1.59 min | +| 3 B | StableLM Alpha | bnb.nf4-dq | 1 | 2,125,248 | 7.25 GB | 1.62 min | +| | | | | | | | +| 7 B | Llama 2 | None | 1 | 4,279,744 | 25.51 GB | 1.81 min | +| 7 B | Llama 2 | bnb.nf4 | 1 | 4,279,744 | 18.30 GB | 3.23 min | +| 7 B | Llama 2 | bnb.nf4-dq | 1 | 4,279,744 | 17.98 GB | 3.32 min |   @@ -75,28 +100,28 @@ The following experiments were conducted on 1xA100 with a minibatch size of 128 The following experiments were conducted on multiple A100 GPUs with a minibatch size of 128 using the `finetune/lora.py` script. -| Size | Model | Quantization | Microbatch size | Trainable parameters | GPU | Max GPU RAM | Time 1k iterations | Time 50k iter (extrapolated) | -|-------|----------------|--------------|-----------------|----------------------|----------|-------------|--------------------|------------------------------| -| 1.3 B | phi-1.5 | None | 1 | 1,572,864 | 2 x A100 | 4.86 GB | 3.81 min | 190.47 min | -| 1.3 B | phi-1.5 | bnb.nf4 | 1 | 1,572,864 | 2 x A100 | N/A | - | - | -| 1.3 B | phi-1.5 | bnb.nf4-dq | 1 | 1,572,864 | 2 x A100 | N/A | - | - | -| 1.3 B | phi-1.5 | None | 2 | 1,572,864 | 2 x A100 | 5.05 GB | 3.63 min | 181.31 min | -| 1.3 B | phi-1.5 | None | 4 | 1,572,864 | 2 x A100 | 5.88 GB | 3.64 min | 181.76 min | -| | | | | | | | | | -| 3 B | StableLM Alpha | None | 1 | 2,097,152 | 2 x A100 | 12.75 GB | 2.92 min | 145.96 min | -| 3 B | StableLM Alpha | None | 2 | 2,097,152 | 2 x A100 | 12.94 GB | 3.06 min | 153.10 min | -| 3 B | StableLM Alpha | None | 4 | 2,097,152 | 2 x A100 | 13.45 GB | 3.86 min | 192.99 min | -| | | | | | | | - | - | -| 7 B | Llama 2 | None | 1 | 4,194,304 | 2 x A100 | 22.18 GB | 5.93 min | 296.62 min | -| 7 B | Llama 2 | None | 2 | 4,194,304 | 2 x A100 | 22.47 GB | 6.48 min | 324.03 min | -| 7 B | Llama 2 | None | 4 | 4,194,304 | 2 x A100 | 23.39 GB | 8.66 min | 432.82 min | -| | | | | | | | | | -| 13 B | Llama 2 | None | 1 | 6,553,600 | 2 x A100 | OOM | - | - | -| 13 B | Llama 2 | bnb.nf4 | 1 | 6,553,600 | 2 x A100 | N/A | - | - | -| 13 B | Llama 2 | bnb.nf4-dq | 1 | 6,553,600 | 2 x A100 | N/A | - | - | -| | | | | | | | | | -| 13 B | Llama 2 | None | 1 | 6,553,600 | 4 x A100 | 35.57 GB | 10.25 min | 512.5 min | -| 40 B | Falcon | None | 1 | 12,042,240 | 4 x A100 | OOM | - | - | +| Size | Model | Quantization | Microbatch size | Trainable parameters | GPU | Max GPU RAM | Time 1k iterations | +|-------|----------------|--------------|-----------------|----------------------|----------|-------------|--------------------| +| 1.3 B | phi-1.5 | None | 1 | 1,572,864 | 2 x A100 | 4.86 GB | 3.81 min | +| 1.3 B | phi-1.5 | bnb.nf4 | 1 | 1,572,864 | 2 x A100 | N/A | - | +| 1.3 B | phi-1.5 | bnb.nf4-dq | 1 | 1,572,864 | 2 x A100 | N/A | - | +| 1.3 B | phi-1.5 | None | 2 | 1,572,864 | 2 x A100 | 5.05 GB | 3.63 min | +| 1.3 B | phi-1.5 | None | 4 | 1,572,864 | 2 x A100 | 5.88 GB | 3.64 min | +| | | | | | | | | +| 3 B | StableLM Alpha | None | 1 | 2,097,152 | 2 x A100 | 12.75 GB | 2.92 min | +| 3 B | StableLM Alpha | None | 2 | 2,097,152 | 2 x A100 | 12.94 GB | 3.06 min | +| 3 B | StableLM Alpha | None | 4 | 2,097,152 | 2 x A100 | 13.45 GB | 3.86 min | +| | | | | | | | - | +| 7 B | Llama 2 | None | 1 | 4,194,304 | 2 x A100 | 22.18 GB | 5.93 min | +| 7 B | Llama 2 | None | 2 | 4,194,304 | 2 x A100 | 22.47 GB | 6.48 min | +| 7 B | Llama 2 | None | 4 | 4,194,304 | 2 x A100 | 23.39 GB | 8.66 min | +| | | | | | | | | +| 13 B | Llama 2 | None | 1 | 6,553,600 | 2 x A100 | OOM | - | +| 13 B | Llama 2 | bnb.nf4 | 1 | 6,553,600 | 2 x A100 | N/A | - | +| 13 B | Llama 2 | bnb.nf4-dq | 1 | 6,553,600 | 2 x A100 | N/A | - | +| | | | | | | | | +| 13 B | Llama 2 | None | 1 | 6,553,600 | 4 x A100 | 35.57 GB | 10.25 min | +| 40 B | Falcon | None | 1 | 12,042,240 | 4 x A100 | OOM | - |