diff --git a/litgpt/scripts/merge_lora.py b/litgpt/scripts/merge_lora.py
index 2bedfa743e..1e1120d214 100644
--- a/litgpt/scripts/merge_lora.py
+++ b/litgpt/scripts/merge_lora.py
@@ -43,7 +43,7 @@ def merge_lora(
     fabric = L.Fabric(devices=1, precision=precision, accelerator="cpu")
     config = Config.from_file(checkpoint_dir / "model_config.yaml", **lora_params)
 
-    with fabric.init_module(), torch.device("meta"):
+    with fabric.init_module():
         model = GPT(config)
 
     lora_path = checkpoint_dir / "lit_model.pth.lora"
@@ -52,7 +52,7 @@ def merge_lora(
 
     # Merge LoRA weights into the base model
     pretrained_checkpoint.update(lora_checkpoint.get("model", lora_checkpoint))
-    model.load_state_dict(pretrained_checkpoint, assign=True)
+    model.load_state_dict(pretrained_checkpoint)
     merge_lora_weights(model)
 
     # Remove LoRA parameters and the LoRA linear substring
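
For context, a minimal sketch of the two loading patterns this diff switches between, using a generic torch.nn.Module stand-in (build_model and the tiny Linear layer are placeholders, not the litgpt GPT/Config API):

import torch
import torch.nn as nn

# Hypothetical tiny model standing in for GPT(config).
def build_model() -> nn.Module:
    return nn.Linear(4, 4)

state_dict = build_model().state_dict()  # stands in for the merged checkpoint

# Pattern removed by the diff: construct the module on the meta device
# (parameters have no storage), then let load_state_dict(assign=True)
# attach the checkpoint tensors directly. Requires PyTorch >= 2.1.
with torch.device("meta"):
    meta_model = build_model()
meta_model.load_state_dict(state_dict, assign=True)

# Pattern the diff switches to: allocate parameters normally,
# then copy the checkpoint values into them in-place.
plain_model = build_model()
plain_model.load_state_dict(state_dict)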