Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge the final LoRA finetuned checkpoint #1081

Merged
merged 12 commits into from
Mar 13, 2024
Merged
6 changes: 6 additions & 0 deletions litgpt/chat/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from litgpt.generate.base import next_token
from litgpt import GPT, Config, PromptStyle, Tokenizer
from litgpt.prompts import load_prompt_style, has_prompt_style
from litgpt.scripts.merge_lora import merge_lora
from litgpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint


Expand Down Expand Up @@ -136,6 +137,11 @@ def main(

checkpoint_path = checkpoint_dir / "lit_model.pth"

# Merge if this is a raw LoRA checkpoint
if (checkpoint_path / "lit_model.pth.lora").is_file() and not checkpoint_path.is_file():
print("Merging LoRA weights with the base model. This won't take long and is a one-time-only thing.")
merge_lora(checkpoint_path)

fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
with fabric.init_module(empty_init=True):
model = GPT(config)
Expand Down
2 changes: 2 additions & 0 deletions litgpt/finetune/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from litgpt.generate.base import generate
from litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable
from litgpt.prompts import save_prompt_style
from litgpt.scripts.merge_lora import merge_lora
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
CLI,
Expand Down Expand Up @@ -192,6 +193,7 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDat
copy_config_files(checkpoint_dir, save_path.parent)
save_hyperparameters(setup, save_path.parent)
save_prompt_style(data.prompt_style, save_path.parent)
merge_lora(checkpoint_dir=save_path.parent)


def fit(
Expand Down
2 changes: 1 addition & 1 deletion litgpt/scripts/merge_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def merge_lora(
lora_params, pretrained_checkpoint_dir, lora_precision = load_lora_metadata(checkpoint_dir)
precision = precision if precision is not None else lora_precision

fabric = L.Fabric(devices=1, precision=precision)
fabric = L.Fabric(devices=1, precision=precision, accelerator="cpu")
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
config = Config.from_file(checkpoint_dir / "model_config.yaml", **lora_params)

with fabric.init_module(empty_init=True):
Expand Down
2 changes: 2 additions & 0 deletions tests/test_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)
monkeypatch.setitem(name_to_config, "tmp", model_config)
monkeypatch.setattr(module, "load_checkpoint", Mock())
monkeypatch.setattr(module, "merge_lora", Mock())

tokenizer_mock = Mock()
tokenizer_mock.return_value = tokenizer_mock
Expand Down Expand Up @@ -622,6 +623,7 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_pa
monkeypatch.setattr(module, "Tokenizer", tokenizer_mock)

monkeypatch.setattr(module, "load_checkpoint", Mock())
monkeypatch.setattr(module, "merge_lora", Mock())
train_mock = Mock()
monkeypatch.setattr(module, "fit", train_mock)

Expand Down
42 changes: 13 additions & 29 deletions tutorials/finetune_lora.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ For additional benchmarks and resource requirements, please see the [Resource Ta
You can test the finetuned model with your own instructions by running:

```bash
litgpt generate lora \
litgpt generate base \
--checkpoint_dir "out/lora/final" \
--prompt "Recommend a movie to watch on the weekend."
```

Expand Down Expand Up @@ -119,38 +120,21 @@ You can easily train on your own instruction dataset saved in JSON format.

 

## Merging LoRA Weights

Finetuning a model with LoRA generates a `lit_model.pth.lora` file. This file exclusively contains the LoRA weights, which has is much smaller than the original model checkpoint to conserve storage space. If desired, there is the option to merge these LoRA weights directly into the original model's checkpoint, which creates a full `lit_model.pth` checkpoint. The advantage of this merging process is to streamline inference operations, as it eliminates the need to dynamically incorporate the LoRA weights during runtime, which can improve inference speed.
## Merging LoRA Weights (Optional)

For example, after finetuning a model using LoRA with the following command:
Finetuning a model with LoRA generates a `lit_model.pth.lora` file.
This file exclusively contains the LoRA weights, which are much smaller than the original model checkpoint to conserve storage space.

```bash
litgpt finetune lora \
--checkpoint_dir "checkpoints/stabilityai/stablelm-base-alpha-3b/" \
--train_data_dir data/mydata --val_data_dir data/mydata/ \
--out_dir "out/lora/stablelm-base-alpha-3b/"
```

This code will produce a `lit_model.pth.lora` file in the specified output directory, containing only the LoRA weights. To merge these LoRA weights with the original model checkpoint, you can use the `merge_lora.py` script as follows:

```bash
litgpt merge_lora \
--checkpoint_dir "out/lora/stablelm-base-alpha-3b/final"
```

Executing this script results in the creation of a full `lit_model.pth` checkpoint that can be used with the `generate/base.py` or `chat/base.py` scripts for inference:
> [!NOTE]
carmocca marked this conversation as resolved.
Show resolved Hide resolved
> LitGPT will automatically merge the checkpoint for you if you use it in any of the inference commands, such as `litgpt generate` or `litgpt chat`.
> Manual merging is only necessary if you want to use the checkpoint outside LitGPT.

```bash
litgpt generate base \
--checkpoint_dir "out/lora/stablelm-base-alpha-3b/final"
```
If desired, there is the option to merge these LoRA weights manually into the original model's checkpoint, which creates a full `lit_model.pth` checkpoint.
The advantage of this merging process is to streamline inference operations, as it eliminates the need to dynamically incorporate the LoRA weights during runtime, which can improve inference speed.

Similarly, you can evaluate the model using the `eval/lm_eval_harness.py` script (see the [evaluation](evaluation.md) tutorial for more information):
For example, after finetuning produced a checkpoint folder `out/lora/step-002000`, merge it as follows:

```bash
python eval/lm_eval_harness.py \
--checkpoint_dir "out/lora/stablelm-base-alpha-3b/final" \
--precision "bf16-true" \
--save_filepath "results.json"
litgpt merge_lora --checkpoint_dir "out/lora/step-002000"
```
The command above creates a full `lit_model.pth` checkpoint file.
Loading