Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge the final LoRA finetuned checkpoint #1081

Merged
merged 12 commits into from
Mar 13, 2024
6 changes: 6 additions & 0 deletions litgpt/chat/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from litgpt.generate.base import next_token
from litgpt import GPT, Config, PromptStyle, Tokenizer
from litgpt.prompts import load_prompt_style, has_prompt_style
from litgpt.scripts.merge_lora import merge_lora
from litgpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint


Expand Down Expand Up @@ -136,6 +137,11 @@ def main(

checkpoint_path = checkpoint_dir / "lit_model.pth"

# Merge if this is a raw LoRA checkpoint
if (checkpoint_path / "lit_model.pth.lora").is_file() and not checkpoint_path.is_file():
print("Merging LoRA weights with the base model. This won't take long.")
merge_lora(checkpoint_path)
carmocca marked this conversation as resolved.
Show resolved Hide resolved

fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
with fabric.init_module(empty_init=True):
model = GPT(config)
Expand Down
3 changes: 3 additions & 0 deletions litgpt/finetune/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from litgpt.generate.base import generate
from litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable
from litgpt.prompts import save_prompt_style
from litgpt.scripts.merge_lora import merge_lora
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
CLI,
Expand Down Expand Up @@ -192,6 +193,8 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDat
copy_config_files(checkpoint_dir, save_path.parent)
save_hyperparameters(setup, save_path.parent)
save_prompt_style(data.prompt_style, save_path.parent)
fabric.barrier()
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
merge_lora(checkpoint_dir=save_path.parent)
awaelchli marked this conversation as resolved.
Show resolved Hide resolved


def fit(
Expand Down
2 changes: 1 addition & 1 deletion litgpt/scripts/merge_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def merge_lora(
lora_params, pretrained_checkpoint_dir, lora_precision = load_lora_metadata(checkpoint_dir)
precision = precision if precision is not None else lora_precision

fabric = L.Fabric(devices=1, precision=precision)
fabric = L.Fabric(devices=1, precision=precision, accelerator="cpu")
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
config = Config.from_file(checkpoint_dir / "model_config.yaml", **lora_params)

with fabric.init_module(empty_init=True):
Expand Down
2 changes: 2 additions & 0 deletions tests/test_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)
monkeypatch.setitem(name_to_config, "tmp", model_config)
monkeypatch.setattr(module, "load_checkpoint", Mock())
monkeypatch.setattr(module, "merge_lora", Mock())

tokenizer_mock = Mock()
tokenizer_mock.return_value = tokenizer_mock
Expand Down Expand Up @@ -622,6 +623,7 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_pa
monkeypatch.setattr(module, "Tokenizer", tokenizer_mock)

monkeypatch.setattr(module, "load_checkpoint", Mock())
monkeypatch.setattr(module, "merge_lora", Mock())
train_mock = Mock()
monkeypatch.setattr(module, "fit", train_mock)

Expand Down
Loading