Merge the final LoRA finetuned checkpoint (#1081)
Co-authored-by: Sebastian Raschka <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
3 people committed Mar 15, 2024
1 parent 3350b2a commit 4faebe4
Showing 5 changed files with 24 additions and 30 deletions.
6 changes: 6 additions & 0 deletions litgpt/chat/base.py
@@ -12,6 +12,7 @@
  from litgpt.generate.base import next_token
  from litgpt import GPT, Config, PromptStyle, Tokenizer
  from litgpt.prompts import load_prompt_style, has_prompt_style
+ from litgpt.scripts.merge_lora import merge_lora
  from litgpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint


@@ -136,6 +137,11 @@ def main(

  checkpoint_path = checkpoint_dir / "lit_model.pth"

+ # Merge if this is a raw LoRA checkpoint (adapter file present, merged checkpoint absent)
+ if (checkpoint_dir / "lit_model.pth.lora").is_file() and not checkpoint_path.is_file():
+     print("Merging LoRA weights with the base model. This won't take long and is a one-time-only thing.")
+     merge_lora(checkpoint_dir)
+
  fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
  with fabric.init_module(empty_init=True):
      model = GPT(config)
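
The merge-on-load condition in this hunk can be sketched in isolation. This is a minimal illustration and not code from the commit: the helper name `needs_lora_merge` is hypothetical, and the sketch assumes both file paths are built from the checkpoint directory itself.

```python
import tempfile
from pathlib import Path


def needs_lora_merge(checkpoint_dir: Path) -> bool:
    """Hypothetical helper: True when only the LoRA adapter file exists.

    Assumes the adapter is saved as `lit_model.pth.lora` and the merged
    checkpoint as `lit_model.pth`, both directly inside `checkpoint_dir`.
    """
    lora_path = checkpoint_dir / "lit_model.pth.lora"
    merged_path = checkpoint_dir / "lit_model.pth"
    return lora_path.is_file() and not merged_path.is_file()


# Walk through the three states a finetuning output directory can be in.
with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp)
    fresh = needs_lora_merge(ckpt)          # neither file exists yet
    (ckpt / "lit_model.pth.lora").touch()
    adapter_only = needs_lora_merge(ckpt)   # raw LoRA checkpoint
    (ckpt / "lit_model.pth").touch()
    after_merge = needs_lora_merge(ckpt)    # merged checkpoint already present
```

Because the check is false once `lit_model.pth` exists, a merge triggered this way would run at most once per checkpoint directory.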
2 changes: 2 additions & 0 deletions litgpt/finetune/lora.py
@@ -20,6 +20,7 @@
  from litgpt.generate.base import generate
  from litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable
  from litgpt.prompts import save_prompt_style
+ from litgpt.scripts.merge_lora import merge_lora
  from litgpt.tokenizer import Tokenizer
  from litgpt.utils import (
      CLI,
@@ -192,6 +193,7 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDat
  copy_config_files(checkpoint_dir, save_path.parent)
  save_hyperparameters(setup, save_path.parent)
  save_prompt_style(data.prompt_style, save_path.parent)
+ merge_lora(checkpoint_dir=save_path.parent)


  def fit(
2 changes: 1 addition & 1 deletion litgpt/scripts/merge_lora.py
@@ -42,7 +42,7 @@ def merge_lora(
  lora_params, pretrained_checkpoint_dir, lora_precision = load_lora_metadata(checkpoint_dir)
  precision = precision if precision is not None else lora_precision

- fabric = L.Fabric(devices=1, precision=precision)
+ fabric = L.Fabric(devices=1, precision=precision, accelerator="cpu")
  config = Config.from_file(checkpoint_dir / "model_config.yaml", **lora_params)

  with fabric.init_module(empty_init=True):
2 changes: 2 additions & 0 deletions tests/test_lora.py
@@ -190,6 +190,7 @@ def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
  model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)
  monkeypatch.setitem(name_to_config, "tmp", model_config)
  monkeypatch.setattr(module, "load_checkpoint", Mock())
+ monkeypatch.setattr(module, "merge_lora", Mock())

  tokenizer_mock = Mock()
  tokenizer_mock.return_value = tokenizer_mock
@@ -622,6 +623,7 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_pa
  monkeypatch.setattr(module, "Tokenizer", tokenizer_mock)

  monkeypatch.setattr(module, "load_checkpoint", Mock())
+ monkeypatch.setattr(module, "merge_lora", Mock())
  train_mock = Mock()
  monkeypatch.setattr(module, "fit", train_mock)

42 changes: 13 additions & 29 deletions tutorials/finetune_lora.md
@@ -73,7 +73,8 @@ For additional benchmarks and resource requirements, please see the [Resource Ta
  You can test the finetuned model with your own instructions by running:

  ```bash
- litgpt generate lora \
+ litgpt generate base \
+   --checkpoint_dir "out/lora/final" \
    --prompt "Recommend a movie to watch on the weekend."
  ```

@@ -119,38 +120,21 @@ You can easily train on your own instruction dataset saved in JSON format.
  &nbsp;
- ## Merging LoRA Weights
- Finetuning a model with LoRA generates a `lit_model.pth.lora` file. This file exclusively contains the LoRA weights, which has is much smaller than the original model checkpoint to conserve storage space. If desired, there is the option to merge these LoRA weights directly into the original model's checkpoint, which creates a full `lit_model.pth` checkpoint. The advantage of this merging process is to streamline inference operations, as it eliminates the need to dynamically incorporate the LoRA weights during runtime, which can improve inference speed.
+ ## Merging LoRA Weights (Optional)
- For example, after finetuning a model using LoRA with the following command:
+ Finetuning a model with LoRA generates a `lit_model.pth.lora` file.
+ This file exclusively contains the LoRA weights, which are much smaller than the original model checkpoint to conserve storage space.
- ```bash
- litgpt finetune lora \
-   --checkpoint_dir "checkpoints/stabilityai/stablelm-base-alpha-3b/" \
-   --train_data_dir data/mydata --val_data_dir data/mydata/ \
-   --out_dir "out/lora/stablelm-base-alpha-3b/"
- ```

- This code will produce a `lit_model.pth.lora` file in the specified output directory, containing only the LoRA weights. To merge these LoRA weights with the original model checkpoint, you can use the `merge_lora.py` script as follows:

- ```bash
- litgpt merge_lora \
-   --checkpoint_dir "out/lora/stablelm-base-alpha-3b/final"
- ```

- Executing this script results in the creation of a full `lit_model.pth` checkpoint that can be used with the `generate/base.py` or `chat/base.py` scripts for inference:
+ > [!NOTE]
+ > LitGPT will automatically merge the checkpoint for you if you use it in any of the inference commands, such as `litgpt generate` or `litgpt chat`.
+ > Manual merging is only necessary if you want to use the checkpoint outside LitGPT.
- ```bash
- litgpt generate base \
-   --checkpoint_dir "out/lora/stablelm-base-alpha-3b/final"
- ```
+ If desired, there is the option to merge these LoRA weights manually into the original model's checkpoint, which creates a full `lit_model.pth` checkpoint.
+ The advantage of this merging process is to streamline inference operations, as it eliminates the need to dynamically incorporate the LoRA weights during runtime, which can improve inference speed.
- Similarly, you can evaluate the model using the `eval/lm_eval_harness.py` script (see the [evaluation](evaluation.md) tutorial for more information):
+ For example, after finetuning produced a checkpoint folder `out/lora/step-002000`, merge it as follows:
  ```bash
- python eval/lm_eval_harness.py \
-   --checkpoint_dir "out/lora/stablelm-base-alpha-3b/final" \
-   --precision "bf16-true" \
-   --save_filepath "results.json"
+ litgpt merge_lora --checkpoint_dir "out/lora/step-002000"
  ```
+ The command above creates a full `lit_model.pth` checkpoint file.
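
To make the tutorial's point about merging concrete, here is a small sketch of the arithmetic behind folding LoRA weights into a base weight matrix. It uses the standard LoRA formulation, W_merged = W + (alpha / r) · B A, and is not LitGPT's implementation. The function names, the pure-Python matrix helpers, and the toy values are all assumptions chosen for illustration.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product of a (m x k) and b (k x n)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def matvec(w: Matrix, x: List[float]) -> List[float]:
    """Apply weight matrix w to input vector x."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


def merge_lora_weight(w: Matrix, lora_a: Matrix, lora_b: Matrix, alpha: float, r: int) -> Matrix:
    """Return W + (alpha / r) * (B @ A), the standard LoRA merge."""
    scale = alpha / r
    delta = matmul(lora_b, lora_a)  # (out x r) @ (r x in) -> (out x in)
    return [
        [w_ij + scale * d_ij for w_ij, d_ij in zip(w_row, d_row)]
        for w_row, d_row in zip(w, delta)
    ]


# Toy layer: 3 inputs, 2 outputs, rank r = 1.
w = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]]
lora_a = [[1.0, 2.0, 3.0]]   # shape (r, in)
lora_b = [[0.5], [-1.0]]     # shape (out, r)
alpha, r = 2.0, 1
x = [1.0, 1.0, 1.0]

# Unmerged inference: base path plus the scaled low-rank path, computed every call.
low_rank = matvec(lora_b, matvec(lora_a, x))
unmerged_out = [b + (alpha / r) * d for b, d in zip(matvec(w, x), low_rank)]

# Merged inference: one matrix, one product.
merged = merge_lora_weight(w, lora_a, lora_b, alpha, r)
merged_out = matvec(merged, x)
```

Both paths give the same output for the same input. The merged path needs only a single matrix product per layer, which is the inference-time saving the tutorial text refers to.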
