diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py
index 38f60127f0..d142385011 100644
--- a/litgpt/finetune/lora.py
+++ b/litgpt/finetune/lora.py
@@ -161,7 +161,7 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDat
         fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
 
     # Save the final LoRA checkpoint at the end of training
-    save_path = out_dir / "final" / "lit_model.pth"
+    save_path = out_dir / "final" / "lit_model.pth.lora"
     save_path.parent.mkdir(parents=True, exist_ok=True)
     save_lora_checkpoint(fabric, model, save_path)
     if fabric.global_rank == 0:
@@ -262,8 +262,9 @@ def fit(
             metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
             fabric.log_dict(metrics, step=iter_num)
             fabric.barrier()
+
         if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0:
-            checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth"
+            checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth.lora"
             checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
             save_lora_checkpoint(fabric, model, checkpoint_file)
             if fabric.global_rank == 0:
diff --git a/litgpt/scripts/merge_lora.py b/litgpt/scripts/merge_lora.py
index 45627cad93..3490ecd6a8 100644
--- a/litgpt/scripts/merge_lora.py
+++ b/litgpt/scripts/merge_lora.py
@@ -20,8 +20,8 @@ def merge_lora(
 ) -> None:
     """Merges the LoRA weights with the base model. See `litgpt/finetune/lora.py`.
 
-    Merging happens in-place in the checkpoint directory that is given as input. It also saves
-    a backup file `lit_model.pth.lora` of the trained LoRA weights in case you still need it later.
+    Creates a new `lit_model.pth` file by merging the LoRA weights (`lit_model.pth.lora`)
+    with the original checkpoint weights.
 
     Args:
         checkpoint_dir: Path to the checkpoint directory with trained LoRA weights, which is the output of
@@ -33,10 +33,10 @@ def merge_lora(
         precision: Optional precision setting to instantiate the model weights in. By default, this will
            automatically be inferred from the metadata in the given `checkpoint_dir` directory.
""" - check_valid_checkpoint_dir(checkpoint_dir) + check_valid_checkpoint_dir(checkpoint_dir, lora=True) if pretrained_checkpoint_dir is not None: check_valid_checkpoint_dir(pretrained_checkpoint_dir) - if (checkpoint_dir / "lit_model.pth.lora").is_file(): + if (checkpoint_dir / "lit_model.pth").is_file(): print("LoRA weights have already been merged in this checkpoint.") return @@ -49,7 +49,7 @@ def merge_lora( with fabric.init_module(empty_init=True): model = GPT(config) - lora_path = checkpoint_dir / "lit_model.pth" + lora_path = checkpoint_dir / "lit_model.pth.lora" pretrained_checkpoint = lazy_load(pretrained_checkpoint_dir / "lit_model.pth") lora_checkpoint = lazy_load(lora_path) @@ -60,15 +60,10 @@ def merge_lora( # Remove LoRA parameters and the LoRA linear substring state_dict = {k.replace("linear.", ""): v for k, v in model.state_dict().items() if not lora_filter(k, v)} - save_path = checkpoint_dir / "lit_model.pth.merged" + save_path = checkpoint_dir / "lit_model.pth" torch.save(state_dict, save_path) - # Make a backup of the LoRA weights (they are only a few MBs) - os.rename(checkpoint_dir / "lit_model.pth", checkpoint_dir / "lit_model.pth.lora") - os.rename(checkpoint_dir / "lit_model.pth.merged", checkpoint_dir / "lit_model.pth") - fabric.print(f"Saved merged weights to {str(checkpoint_dir / 'lit_model.pth')!r}") - fabric.print(f"A backup of the old LoRA weights is in {str(checkpoint_dir / 'lit_model.pth.lora')!r}") def load_lora_metadata(checkpoint_dir: Path) -> Tuple[Dict[str, Any], Path, Optional[str]]: diff --git a/litgpt/utils.py b/litgpt/utils.py index c74f7f02f1..500c8beaa0 100644 --- a/litgpt/utils.py +++ b/litgpt/utils.py @@ -47,9 +47,10 @@ def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> i return total -def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None: +def check_valid_checkpoint_dir(checkpoint_dir: Path, lora: bool = False) -> None: + model_filename = "lit_model.pth.lora" if lora else "lit_model.pth" files = { - "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(), + model_filename: (checkpoint_dir / model_filename).is_file(), "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(), "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or (checkpoint_dir / "tokenizer.model").is_file(), diff --git a/tests/test_lora.py b/tests/test_lora.py index c9cf3fc79e..cae836bd07 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -219,7 +219,7 @@ def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path): assert all((out_dir / p).is_dir() for p in checkpoint_dirs) for checkpoint_dir in checkpoint_dirs: assert {p.name for p in (out_dir / checkpoint_dir).iterdir()} == { - "lit_model.pth", + "lit_model.pth.lora", "lit_config.json", "tokenizer_config.json", "tokenizer.json", diff --git a/tests/test_merge_lora.py b/tests/test_merge_lora.py index fee5d02b27..6855b55599 100644 --- a/tests/test_merge_lora.py +++ b/tests/test_merge_lora.py @@ -14,7 +14,6 @@ import yaml -@RunIf(skip_windows=True) # PermissionError in os.rename on Windows @mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}) def test_merge_lora(tmp_path, fake_checkpoint_dir): from litgpt.lora import GPT as LoRAGPT @@ -26,6 +25,7 @@ def test_merge_lora(tmp_path, fake_checkpoint_dir): lora_checkpoint_dir = tmp_path / "lora" shutil.copytree(fake_checkpoint_dir, pretrained_checkpoint_dir) shutil.copytree(fake_checkpoint_dir, lora_checkpoint_dir) + (lora_checkpoint_dir / "lit_model.pth").unlink() # should not 
     shutil.rmtree(tmp_path / "checkpoints")
 
     # Create a fake pretrained checkpoint
@@ -42,7 +42,7 @@ def test_merge_lora(tmp_path, fake_checkpoint_dir):
     lora_model = LoRAGPT.from_name("pythia-14m", **config, **lora_kwargs)
     state_dict = {k: v for k, v in lora_model.state_dict().items() if lora_filter(k, v)}
     assert len(state_dict) == 6
-    torch.save(state_dict, lora_checkpoint_dir / "lit_model.pth")
+    torch.save(state_dict, lora_checkpoint_dir / "lit_model.pth.lora")
     hparams = dict(checkpoint_dir=str(pretrained_checkpoint_dir), **lora_kwargs)
     with open(lora_checkpoint_dir / "hyperparameters.yaml", "w") as file:
         yaml.dump(hparams, file)
diff --git a/tutorials/finetune_lora.md b/tutorials/finetune_lora.md
index ad7b0121a8..3529d2ba19 100644
--- a/tutorials/finetune_lora.md
+++ b/tutorials/finetune_lora.md
@@ -126,12 +126,9 @@ You can easily train on your own instruction dataset saved in JSON format.
 
 ## Merging LoRA Weights
 
-By default, the LoRA weights are kept separate from the checkpoint file to save storage space.
-However, you can optionally merge the LoRA weights with the original model checkpoint to create
-a new file to optimize inference speeds. (This will improve inference performance
-because the weights don't have to be added during runtime.)
+Finetuning a model with LoRA generates a `lit_model.pth.lora` file. This file exclusively contains the LoRA weights and is much smaller than the original model checkpoint, which conserves storage space. If desired, you can merge these LoRA weights directly into the original model's checkpoint, creating a full `lit_model.pth` checkpoint. The advantage of this merging process is that it streamlines inference, as it eliminates the need to dynamically incorporate the LoRA weights during runtime, which can improve inference speed.
 
-Let's assume we finetuned a model using LoRA as follows:
+For example, after finetuning a model using LoRA with the following command:
 
 ```bash
 python litgpt/finetune/lora.py \
@@ -140,15 +137,14 @@ python litgpt/finetune/lora.py \
   --out_dir "out/lora/stablelm-base-alpha-3b/"
 ```
 
-Then, we can merge the LoRA weights with the checkpoint model using the `merge_lora.py` script as shown below.
-Simply pass in the checkpoint directory which is the result of the finetuning script:
+This command produces a `lit_model.pth.lora` file in the specified output directory, containing only the LoRA weights. To merge these LoRA weights with the original model checkpoint, you can use the `merge_lora.py` script as follows:
 
 ```bash
 python scripts/merge_lora.py \
   --checkpoint_dir "out/lora/stablelm-base-alpha-3b/final"
 ```
 
-After merging, we can use the `litgpt/generate/base.py` or `litgpt/chat/base.py` file for inference using the new checkpoint file.
+Executing this script creates a full `lit_model.pth` checkpoint that can be used with the `generate/base.py` or `chat/base.py` scripts for inference:
 
 ```bash
 python litgpt/generate/base.py \