Commit

change lora weights to lit_model.pth.lora (#1053)
Co-authored-by: awaelchli <[email protected]>
rasbt and awaelchli committed Mar 18, 2024
1 parent f78adc6 commit 4c43c09
Showing 6 changed files with 19 additions and 26 deletions.
5 changes: 3 additions & 2 deletions litgpt/finetune/lora.py
@@ -161,7 +161,7 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDat
  fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

  # Save the final LoRA checkpoint at the end of training
- save_path = out_dir / "final" / "lit_model.pth"
+ save_path = out_dir / "final" / "lit_model.pth.lora"
  save_path.parent.mkdir(parents=True, exist_ok=True)
  save_lora_checkpoint(fabric, model, save_path)
  if fabric.global_rank == 0:
@@ -262,8 +262,9 @@ def fit(
  metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
  fabric.log_dict(metrics, step=iter_num)
  fabric.barrier()
+
  if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0:
- checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth"
+ checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth.lora"
  checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
  save_lora_checkpoint(fabric, model, checkpoint_file)
  if fabric.global_rank == 0:
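For context on what these `lit_model.pth.lora` checkpoints contain: only the LoRA parameters are saved, filtered out of the full model state so the file stays small. A minimal sketch of that idea, assuming (as the `lora_filter` usage elsewhere in this diff suggests) that LoRA parameter names contain the substring "lora_"; the helper below is illustrative and not part of this commit:

```python
import torch

def save_lora_only(model: torch.nn.Module, path: str) -> None:
    # Keep only the LoRA parameters (their names contain "lora_"),
    # dropping the frozen base weights to keep the checkpoint small.
    lora_state = {k: v for k, v in model.state_dict().items() if "lora_" in k}
    torch.save(lora_state, path)
```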
17 changes: 6 additions & 11 deletions litgpt/scripts/merge_lora.py
@@ -20,8 +20,8 @@ def merge_lora(
  ) -> None:
  """Merges the LoRA weights with the base model. See `litgpt/finetune/lora.py`.
- Merging happens in-place in the checkpoint directory that is given as input. It also saves
- a backup file `lit_model.pth.lora` of the trained LoRA weights in case you still need it later.
+ Creates a new `lit_model.pth` file by merging the LoRA weights (`lit_model.pth.lora`)
+ with the original checkpoint weights.
  Args:
  checkpoint_dir: Path to the checkpoint directory with trained LoRA weights, which is the output of
@@ -33,10 +33,10 @@ def merge_lora(
  precision: Optional precision setting to instantiate the model weights in. By default, this will
  automatically be inferred from the metadata in the given `checkpoint_dir` directory.
  """
- check_valid_checkpoint_dir(checkpoint_dir)
+ check_valid_checkpoint_dir(checkpoint_dir, lora=True)
  if pretrained_checkpoint_dir is not None:
  check_valid_checkpoint_dir(pretrained_checkpoint_dir)
- if (checkpoint_dir / "lit_model.pth.lora").is_file():
+ if (checkpoint_dir / "lit_model.pth").is_file():
  print("LoRA weights have already been merged in this checkpoint.")
  return

@@ -49,7 +49,7 @@ def merge_lora(
  with fabric.init_module(empty_init=True):
  model = GPT(config)

- lora_path = checkpoint_dir / "lit_model.pth"
+ lora_path = checkpoint_dir / "lit_model.pth.lora"
  pretrained_checkpoint = lazy_load(pretrained_checkpoint_dir / "lit_model.pth")
  lora_checkpoint = lazy_load(lora_path)

@@ -60,15 +60,10 @@

  # Remove LoRA parameters and the LoRA linear substring
  state_dict = {k.replace("linear.", ""): v for k, v in model.state_dict().items() if not lora_filter(k, v)}
- save_path = checkpoint_dir / "lit_model.pth.merged"
+ save_path = checkpoint_dir / "lit_model.pth"
  torch.save(state_dict, save_path)

- # Make a backup of the LoRA weights (they are only a few MBs)
- os.rename(checkpoint_dir / "lit_model.pth", checkpoint_dir / "lit_model.pth.lora")
- os.rename(checkpoint_dir / "lit_model.pth.merged", checkpoint_dir / "lit_model.pth")
-
  fabric.print(f"Saved merged weights to {str(checkpoint_dir / 'lit_model.pth')!r}")
- fabric.print(f"A backup of the old LoRA weights is in {str(checkpoint_dir / 'lit_model.pth.lora')!r}")


  def load_lora_metadata(checkpoint_dir: Path) -> Tuple[Dict[str, Any], Path, Optional[str]]:
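For readers skimming the diff, the merge that `merge_lora.py` performs amounts to folding each LoRA update into its base weight matrix before the LoRA-free `state_dict` is written out as `lit_model.pth`. A minimal sketch of that arithmetic, assuming the standard LoRA formulation; the function is illustrative and not part of this commit:

```python
import torch

def merge_lora_weight(
    base_weight: torch.Tensor,  # frozen base weight, shape (out_features, in_features)
    lora_a: torch.Tensor,       # LoRA "A" matrix, shape (r, in_features)
    lora_b: torch.Tensor,       # LoRA "B" matrix, shape (out_features, r)
    alpha: float,               # LoRA scaling hyperparameter
    r: int,                     # LoRA rank
) -> torch.Tensor:
    # W_merged = W + (alpha / r) * B @ A, so inference no longer needs an extra LoRA matmul.
    return base_weight + (alpha / r) * (lora_b @ lora_a)
```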
5 changes: 3 additions & 2 deletions litgpt/utils.py
@@ -47,9 +47,10 @@ def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> i
  return total


- def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None:
+ def check_valid_checkpoint_dir(checkpoint_dir: Path, lora: bool = False) -> None:
+ model_filename = "lit_model.pth.lora" if lora else "lit_model.pth"
  files = {
- "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(),
+ model_filename: (checkpoint_dir / model_filename).is_file(),
  "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(),
  "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file()
  or (checkpoint_dir / "tokenizer.model").is_file(),
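As a usage note, the new `lora` flag only switches which model file the validation looks for. A short sketch of how a caller might use it after this change (the paths are illustrative, borrowed from the tutorial below):

```python
from pathlib import Path

from litgpt.utils import check_valid_checkpoint_dir

# Before merging: the finetuning output is expected to contain `lit_model.pth.lora`.
check_valid_checkpoint_dir(Path("out/lora/stablelm-base-alpha-3b/final"), lora=True)

# After merging: the default behavior expects the regular `lit_model.pth` again.
check_valid_checkpoint_dir(Path("out/lora/stablelm-base-alpha-3b/final"))
```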
2 changes: 1 addition & 1 deletion tests/test_lora.py
@@ -219,7 +219,7 @@ def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
  assert all((out_dir / p).is_dir() for p in checkpoint_dirs)
  for checkpoint_dir in checkpoint_dirs:
  assert {p.name for p in (out_dir / checkpoint_dir).iterdir()} == {
- "lit_model.pth",
+ "lit_model.pth.lora",
  "lit_config.json",
  "tokenizer_config.json",
  "tokenizer.json",
4 changes: 2 additions & 2 deletions tests/test_merge_lora.py
@@ -14,7 +14,6 @@
  import yaml


- @RunIf(skip_windows=True)  # PermissionError in os.rename on Windows
  @mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"})
  def test_merge_lora(tmp_path, fake_checkpoint_dir):
  from litgpt.lora import GPT as LoRAGPT
@@ -26,6 +25,7 @@ def test_merge_lora(tmp_path, fake_checkpoint_dir):
  lora_checkpoint_dir = tmp_path / "lora"
  shutil.copytree(fake_checkpoint_dir, pretrained_checkpoint_dir)
  shutil.copytree(fake_checkpoint_dir, lora_checkpoint_dir)
+ (lora_checkpoint_dir / "lit_model.pth").unlink()  # should not already exist
  shutil.rmtree(tmp_path / "checkpoints")

  # Create a fake pretrained checkpoint
@@ -42,7 +42,7 @@ def test_merge_lora(tmp_path, fake_checkpoint_dir):
  lora_model = LoRAGPT.from_name("pythia-14m", **config, **lora_kwargs)
  state_dict = {k: v for k, v in lora_model.state_dict().items() if lora_filter(k, v)}
  assert len(state_dict) == 6
- torch.save(state_dict, lora_checkpoint_dir / "lit_model.pth")
+ torch.save(state_dict, lora_checkpoint_dir / "lit_model.pth.lora")
  hparams = dict(checkpoint_dir=str(pretrained_checkpoint_dir), **lora_kwargs)
  with open(lora_checkpoint_dir / "hyperparameters.yaml", "w") as file:
  yaml.dump(hparams, file)
12 changes: 4 additions & 8 deletions tutorials/finetune_lora.md
@@ -126,12 +126,9 @@ You can easily train on your own instruction dataset saved in JSON format.
  ## Merging LoRA Weights
- By default, the LoRA weights are kept separate from the checkpoint file to save storage space.
- However, you can optionally merge the LoRA weights with the original model checkpoint to create
- a new file to optimize inference speeds. (This will improve inference performance
- because the weights don't have to be added during runtime.)
+ Finetuning a model with LoRA generates a `lit_model.pth.lora` file. This file exclusively contains the LoRA weights, which is much smaller than the original model checkpoint to conserve storage space. If desired, there is the option to merge these LoRA weights directly into the original model's checkpoint, which creates a full `lit_model.pth` checkpoint. The advantage of this merging process is to streamline inference operations, as it eliminates the need to dynamically incorporate the LoRA weights during runtime, which can improve inference speed.
- Let's assume we finetuned a model using LoRA as follows:
+ For example, after finetuning a model using LoRA with the following command:
  ```bash
  python litgpt/finetune/lora.py \
@@ -140,15 +137,14 @@ python litgpt/finetune/lora.py \
  --out_dir "out/lora/stablelm-base-alpha-3b/"
  ```

- Then, we can merge the LoRA weights with the checkpoint model using the `merge_lora.py` script as shown below.
- Simply pass in the checkpoint directory which is the result of the finetuning script:
+ This code will produce a `lit_model.pth.lora` file in the specified output directory, containing only the LoRA weights. To merge these LoRA weights with the original model checkpoint, you can use the `merge_lora.py` script as follows:

  ```bash
  python scripts/merge_lora.py \
  --checkpoint_dir "out/lora/stablelm-base-alpha-3b/final"
  ```

- After merging, we can use the `litgpt/generate/base.py` or `litgpt/chat/base.py` file for inference using the new checkpoint file.
+ Executing this script results in the creation of a full `lit_model.pth` checkpoint that can be used with the `generate/base.py` or `chat/base.py` scripts for inference:

  ```bash
  python litgpt/generate/base.py \
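As a quick sanity check after running `merge_lora.py`, one might confirm that the merged file loads into the plain (non-LoRA) model class and contains no leftover LoRA parameters. A hedged sketch, assuming the `litgpt.model.GPT` import path and the tutorial's output directory:

```python
import torch

from litgpt.model import GPT

# Load the merged weights produced by merge_lora.py.
state_dict = torch.load("out/lora/stablelm-base-alpha-3b/final/lit_model.pth", map_location="cpu")

# No LoRA-specific parameters should remain after merging.
assert not any("lora_" in name for name in state_dict)

# The merged checkpoint should load cleanly into the base (non-LoRA) model.
model = GPT.from_name("stablelm-base-alpha-3b")
model.load_state_dict(state_dict)
```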
