diff --git a/.gitignore b/.gitignore
index a2e84c57ad..d16dd90cdd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,3 +16,7 @@ checkpoints
 out
 wandb
 events.out.tfevents*
+
+# test artifacts from tests/test_readme.py
+tests/custom_finetuning_dataset.json
+tests/custom_texts
\ No newline at end of file
diff --git a/litgpt/scripts/merge_lora.py b/litgpt/scripts/merge_lora.py
index aff59daef4..845acbc405 100644
--- a/litgpt/scripts/merge_lora.py
+++ b/litgpt/scripts/merge_lora.py
@@ -43,16 +43,23 @@ def merge_lora(
     fabric = L.Fabric(devices=1, precision=precision, accelerator="cpu")
     config = Config.from_file(checkpoint_dir / "model_config.yaml", **lora_params)
 
-    with fabric.init_module():
+    with fabric.init_module(), torch.device("meta"):
         model = GPT(config)
+        # we don't care about these to perform merging
+        model.cos = None
+        model.sin = None
 
     lora_path = checkpoint_dir / "lit_model.pth.lora"
     pretrained_checkpoint = torch.load(str(pretrained_checkpoint_dir / "lit_model.pth"), mmap=True)
     lora_checkpoint = torch.load(str(lora_path), mmap=True)
+    lora_checkpoint = lora_checkpoint.get("model", lora_checkpoint)
 
     # Merge LoRA weights into the base model
-    pretrained_checkpoint.update(lora_checkpoint.get("model", lora_checkpoint))
-    model.load_state_dict(pretrained_checkpoint)
+    pretrained_checkpoint.update(lora_checkpoint)
+    model.load_state_dict(pretrained_checkpoint, assign=True)
+    # since LoRA finetuning only saves the LoRA weights, we treat the LoRA weights' dtype as the expected dtype
+    lora_dtype = next(iter(lora_checkpoint.values())).dtype
+    model.to(dtype=lora_dtype, device="cpu")
     merge_lora_weights(model)
 
     # Remove LoRA parameters and the LoRA linear substring
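
For context, here is a minimal sketch (not part of the patch) of the loading pattern the `merge_lora.py` change relies on: parameters initialized under the `"meta"` device take no memory, and `load_state_dict(..., assign=True)` adopts the (possibly memory-mapped) checkpoint tensors directly instead of copying them into freshly allocated weights. The `nn.Linear` module and the randomly generated state dict below are placeholder stand-ins for `GPT(config)` and the real checkpoints; only the loading pattern matches the patch, and PyTorch >= 2.1 is assumed for `assign=True`.

```python
# Minimal sketch, not the actual merge_lora.py code.
import torch
import torch.nn as nn

# Hypothetical stand-in for GPT(config); any nn.Module behaves the same way.
# Under torch.device("meta") the parameters are shape-only and use no memory.
with torch.device("meta"):
    model = nn.Linear(4096, 4096, bias=False)

# Stand-in for the mmap-loaded pretrained checkpoint updated with LoRA weights.
state_dict = {"weight": torch.randn(4096, 4096, dtype=torch.bfloat16)}

# assign=True re-uses the checkpoint tensors as the module's parameters, so the
# meta-initialized weights are never materialized separately. The parameters
# take on the checkpoint tensors' dtype/device, which is why the patch follows
# up with an explicit model.to(dtype=lora_dtype, device="cpu").
model.load_state_dict(state_dict, assign=True)
print(model.weight.device, model.weight.dtype)  # cpu torch.bfloat16
```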