diff --git a/litgpt/scripts/merge_lora.py b/litgpt/scripts/merge_lora.py
index 2bedfa743e..b589fba685 100644
--- a/litgpt/scripts/merge_lora.py
+++ b/litgpt/scripts/merge_lora.py
@@ -49,6 +49,9 @@ def merge_lora(
     lora_path = checkpoint_dir / "lit_model.pth.lora"
     pretrained_checkpoint = torch.load(str(pretrained_checkpoint_dir / "lit_model.pth"), mmap=True)
     lora_checkpoint = torch.load(str(lora_path), mmap=True)
+    # since LoRA finetuning only saves the LoRA weights, we treat the lora weights dtype as the expected dtype
+    lora_dtype = next(iter(lora_checkpoint.values())).dtype
+    pretrained_checkpoint = {k: v.to(dtype=lora_dtype) for k, v in pretrained_checkpoint.items()}

     # Merge LoRA weights into the base model
     pretrained_checkpoint.update(lora_checkpoint.get("model", lora_checkpoint))
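
A hedged, self-contained illustration of why the added cast matters, using toy stand-in tensors (the key names and shapes below are illustrative, not taken from litgpt): without the cast, a fp32 base checkpoint merged with bf16 LoRA weights yields a state dict with mixed dtypes; with it, every tensor ends up in the LoRA dtype.

import torch

# Toy stand-ins for the two checkpoints: base weights in fp32, LoRA weights in bf16.
pretrained_checkpoint = {"transformer.wte.weight": torch.zeros(4, 4, dtype=torch.float32)}
lora_checkpoint = {"transformer.wte.lora_A": torch.zeros(2, 4, dtype=torch.bfloat16)}

# Treat the LoRA weights' dtype as the target dtype and cast the base weights to it,
# mirroring the lines added in merge_lora() above.
lora_dtype = next(iter(lora_checkpoint.values())).dtype
pretrained_checkpoint = {k: v.to(dtype=lora_dtype) for k, v in pretrained_checkpoint.items()}

# Merge the LoRA weights into the base state dict, as in the diff.
pretrained_checkpoint.update(lora_checkpoint.get("model", lora_checkpoint))

# Every tensor in the merged dict now shares the LoRA dtype.
assert all(v.dtype == lora_dtype for v in pretrained_checkpoint.values())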