From 426765890382636654dc97c51109254c0505d3f4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Thu, 4 Apr 2024 18:40:28 +0200
Subject: [PATCH] Revert "Fix dtype mismatch when merging LoRA checkpoints (#1246)"

This reverts commit d8dc97e4160307a674a6f4bd3ccdacbaa90b4c90.
---
 litgpt/scripts/merge_lora.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/litgpt/scripts/merge_lora.py b/litgpt/scripts/merge_lora.py
index 1e1120d214..2bedfa743e 100644
--- a/litgpt/scripts/merge_lora.py
+++ b/litgpt/scripts/merge_lora.py
@@ -43,7 +43,7 @@ def merge_lora(
     fabric = L.Fabric(devices=1, precision=precision, accelerator="cpu")
     config = Config.from_file(checkpoint_dir / "model_config.yaml", **lora_params)
 
-    with fabric.init_module():
+    with fabric.init_module(), torch.device("meta"):
         model = GPT(config)
 
     lora_path = checkpoint_dir / "lit_model.pth.lora"
@@ -52,7 +52,7 @@
 
     # Merge LoRA weights into the base model
     pretrained_checkpoint.update(lora_checkpoint.get("model", lora_checkpoint))
-    model.load_state_dict(pretrained_checkpoint)
+    model.load_state_dict(pretrained_checkpoint, assign=True)
     merge_lora_weights(model)
 
     # Remove LoRA parameters and the LoRA linear substring
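
Note (not part of the patch itself): the revert brings back the pattern of instantiating the model on the meta device and then loading weights with `load_state_dict(..., assign=True)`, so the checkpoint tensors are adopted directly rather than copied into freshly allocated parameters. A minimal standalone sketch of that PyTorch pattern is below; `nn.Linear` and the hand-built state dict are hypothetical stand-ins for `GPT(config)` and the merged checkpoint, not code from litgpt.

```python
# Minimal sketch of meta-device init + assign=True loading (assumes PyTorch >= 2.1).
import torch
import torch.nn as nn

# Build the module on the meta device: parameter shapes exist, but no real
# memory is allocated yet. nn.Linear stands in for GPT(config) here.
with torch.device("meta"):
    model = nn.Linear(4, 4, bias=True)

# A made-up checkpoint whose tensors carry their own dtype (e.g. bfloat16).
state_dict = {
    "weight": torch.randn(4, 4, dtype=torch.bfloat16),
    "bias": torch.zeros(4, dtype=torch.bfloat16),
}

# assign=True swaps the meta parameters for the checkpoint tensors outright,
# so the loaded weights keep the checkpoint's dtype and no second copy is made.
model.load_state_dict(state_dict, assign=True)
print(model.weight.dtype)  # torch.bfloat16
```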