From 7d374256373a38ddf56e04d1bbf10c93703b4c0d Mon Sep 17 00:00:00 2001
From: Carlos Mocholí
Date: Thu, 4 Apr 2024 18:58:02 +0200
Subject: [PATCH] Move pretrained dtype

---
 litgpt/scripts/merge_lora.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/litgpt/scripts/merge_lora.py b/litgpt/scripts/merge_lora.py
index 2bedfa743e..b589fba685 100644
--- a/litgpt/scripts/merge_lora.py
+++ b/litgpt/scripts/merge_lora.py
@@ -49,6 +49,9 @@ def merge_lora(
     lora_path = checkpoint_dir / "lit_model.pth.lora"
     pretrained_checkpoint = torch.load(str(pretrained_checkpoint_dir / "lit_model.pth"), mmap=True)
     lora_checkpoint = torch.load(str(lora_path), mmap=True)
+    # since LoRA finetuning only saves the LoRA weights, we treat the lora weights dtype as the expected dtype
+    lora_dtype = next(iter(lora_checkpoint.values())).dtype
+    pretrained_checkpoint = {k: v.to(dtype=lora_dtype) for k, v in pretrained_checkpoint.items()}
     # Merge LoRA weights into the base model
     pretrained_checkpoint.update(lora_checkpoint.get("model", lora_checkpoint))
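
Not part of the patch: a minimal, self-contained sketch of what the three added lines do, using toy tensors in place of the real checkpoints. `base` and `lora` are hypothetical stand-ins for the state dicts loaded in merge_lora.py, and the chosen dtypes are only illustrative.

```python
import torch

# Hypothetical stand-ins for the loaded checkpoints: the base model was saved
# in float32, while the LoRA finetuning run saved its weights in bfloat16.
base = {"transformer.wte.weight": torch.zeros(4, 4, dtype=torch.float32)}
lora = {"transformer.wte.lora_A": torch.zeros(2, 4, dtype=torch.bfloat16)}

# Treat the LoRA weights' dtype as the expected dtype and cast the base
# weights to match, mirroring the lines added in this patch.
lora_dtype = next(iter(lora.values())).dtype
base = {k: v.to(dtype=lora_dtype) for k, v in base.items()}

# After the cast, merging the two state dicts yields tensors of a single dtype.
base.update(lora)
assert all(t.dtype == torch.bfloat16 for t in base.values())
```

Casting the base weights up front keeps the merged state dict in one dtype, so the LoRA weights' dtype (the only dtype the finetuning run saved) determines the precision of the merged checkpoint.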