From b6a8326a577ff139f1286365ec54c1517bcec5c7 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Thu, 14 Mar 2024 00:25:55 +0100
Subject: [PATCH] Merge the final LoRA finetuned checkpoint (#1081)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Sebastian Raschka
Co-authored-by: Carlos Mocholí
---
 litgpt/chat/base.py          |  6 ++++++
 litgpt/finetune/lora.py      |  2 ++
 litgpt/scripts/merge_lora.py |  2 +-
 tests/test_lora.py           |  2 ++
 tutorials/finetune_lora.md   | 42 +++++++++++-------------------------
 5 files changed, 24 insertions(+), 30 deletions(-)

diff --git a/litgpt/chat/base.py b/litgpt/chat/base.py
index 9149a769a3..84023d3a26 100644
--- a/litgpt/chat/base.py
+++ b/litgpt/chat/base.py
@@ -12,6 +12,7 @@
 from litgpt.generate.base import next_token
 from litgpt import GPT, Config, PromptStyle, Tokenizer
 from litgpt.prompts import load_prompt_style, has_prompt_style
+from litgpt.scripts.merge_lora import merge_lora
 from litgpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint


@@ -136,6 +137,11 @@ def main(

     checkpoint_path = checkpoint_dir / "lit_model.pth"

+    # Merge if this is a raw LoRA checkpoint
+    if (checkpoint_dir / "lit_model.pth.lora").is_file() and not checkpoint_path.is_file():
+        print("Merging LoRA weights with the base model. This won't take long and is a one-time-only thing.")
+        merge_lora(checkpoint_dir)
+
     fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
     with fabric.init_module(empty_init=True):
         model = GPT(config)
diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py
index d8cd1063b2..26c3dce3da 100644
--- a/litgpt/finetune/lora.py
+++ b/litgpt/finetune/lora.py
@@ -20,6 +20,7 @@
 from litgpt.generate.base import generate
 from litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable
 from litgpt.prompts import save_prompt_style
+from litgpt.scripts.merge_lora import merge_lora
 from litgpt.tokenizer import Tokenizer
 from litgpt.utils import (
     CLI,
@@ -192,6 +193,7 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDat
     copy_config_files(checkpoint_dir, save_path.parent)
     save_hyperparameters(setup, save_path.parent)
     save_prompt_style(data.prompt_style, save_path.parent)
+    merge_lora(checkpoint_dir=save_path.parent)


 def fit(
diff --git a/litgpt/scripts/merge_lora.py b/litgpt/scripts/merge_lora.py
index f1d5068688..03f22f8544 100644
--- a/litgpt/scripts/merge_lora.py
+++ b/litgpt/scripts/merge_lora.py
@@ -42,7 +42,7 @@ def merge_lora(
     lora_params, pretrained_checkpoint_dir, lora_precision = load_lora_metadata(checkpoint_dir)
     precision = precision if precision is not None else lora_precision

-    fabric = L.Fabric(devices=1, precision=precision)
+    fabric = L.Fabric(devices=1, precision=precision, accelerator="cpu")
     config = Config.from_file(checkpoint_dir / "model_config.yaml", **lora_params)

     with fabric.init_module(empty_init=True):
diff --git a/tests/test_lora.py b/tests/test_lora.py
index 8f59a4b36c..cc22868477 100644
--- a/tests/test_lora.py
+++ b/tests/test_lora.py
@@ -190,6 +190,7 @@ def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
     model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)
     monkeypatch.setitem(name_to_config, "tmp", model_config)
     monkeypatch.setattr(module, "load_checkpoint", Mock())
+    monkeypatch.setattr(module, "merge_lora", Mock())

     tokenizer_mock = Mock()
     tokenizer_mock.return_value = tokenizer_mock
@@ -622,6 +623,7 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_pa

     monkeypatch.setattr(module, "Tokenizer", tokenizer_mock)
     monkeypatch.setattr(module, "load_checkpoint", Mock())
+    monkeypatch.setattr(module, "merge_lora", Mock())

     train_mock = Mock()
     monkeypatch.setattr(module, "fit", train_mock)
diff --git a/tutorials/finetune_lora.md b/tutorials/finetune_lora.md
index da81c32fc4..19c2201f3c 100644
--- a/tutorials/finetune_lora.md
+++ b/tutorials/finetune_lora.md
@@ -73,7 +73,8 @@ For additional benchmarks and resource requirements, please see the [Resource Ta
 You can test the finetuned model with your own instructions by running:

 ```bash
-litgpt generate lora \
+litgpt generate base \
+  --checkpoint_dir "out/lora/final" \
   --prompt "Recommend a movie to watch on the weekend."
 ```

@@ -119,38 +120,21 @@ You can easily train on your own instruction dataset saved in JSON format.

 &nbsp;

-## Merging LoRA Weights
-
-Finetuning a model with LoRA generates a `lit_model.pth.lora` file. This file exclusively contains the LoRA weights, which has is much smaller than the original model checkpoint to conserve storage space. If desired, there is the option to merge these LoRA weights directly into the original model's checkpoint, which creates a full `lit_model.pth` checkpoint. The advantage of this merging process is to streamline inference operations, as it eliminates the need to dynamically incorporate the LoRA weights during runtime, which can improve inference speed.
+## Merging LoRA Weights (Optional)

-For example, after finetuning a model using LoRA with the following command:
+Finetuning a model with LoRA generates a `lit_model.pth.lora` file.
+This file exclusively contains the LoRA weights, which are much smaller than the original model checkpoint to conserve storage space.

-```bash
-litgpt finetune lora \
-  --checkpoint_dir "checkpoints/stabilityai/stablelm-base-alpha-3b/" \
-  --train_data_dir data/mydata --val_data_dir data/mydata/ \
-  --out_dir "out/lora/stablelm-base-alpha-3b/"
-```
-
-This code will produce a `lit_model.pth.lora` file in the specified output directory, containing only the LoRA weights. To merge these LoRA weights with the original model checkpoint, you can use the `merge_lora.py` script as follows:
-
-```bash
-litgpt merge_lora \
-  --checkpoint_dir "out/lora/stablelm-base-alpha-3b/final"
-```
-
-Executing this script results in the creation of a full `lit_model.pth` checkpoint that can be used with the `generate/base.py` or `chat/base.py` scripts for inference:
+> [!NOTE]
+> LitGPT will automatically merge the checkpoint for you if you use it in any of the inference commands, such as `litgpt generate` or `litgpt chat`.
+> Manual merging is only necessary if you want to use the checkpoint outside LitGPT.

-```bash
-litgpt generate base \
-  --checkpoint_dir "out/lora/stablelm-base-alpha-3b/final"
-```
+If desired, there is the option to merge these LoRA weights manually into the original model's checkpoint, which creates a full `lit_model.pth` checkpoint.
+The advantage of this merging process is to streamline inference operations, as it eliminates the need to dynamically incorporate the LoRA weights during runtime, which can improve inference speed.
-Similarly, you can evaluate the model using the `eval/lm_eval_harness.py` script (see the [evaluation](evaluation.md) tutorial for more information):
+For example, after finetuning produced a checkpoint folder `out/lora/step-002000`, merge it as follows:

 ```bash
-python eval/lm_eval_harness.py \
-  --checkpoint_dir "out/lora/stablelm-base-alpha-3b/final" \
-  --precision "bf16-true" \
-  --save_filepath "results.json"
+litgpt merge_lora --checkpoint_dir "out/lora/step-002000"
 ```
+The command above creates a full `lit_model.pth` checkpoint file.