From 3b17eb11aff62b953e2ad4da27b1c9f68fed8581 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Mon, 6 May 2024 18:46:56 +0200
Subject: [PATCH] Fixes

---
 .gitignore                   | 4 ++++
 litgpt/scripts/merge_lora.py | 8 ++++++--
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/.gitignore b/.gitignore
index a2e84c57ad..d16dd90cdd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,3 +16,7 @@ checkpoints
 out
 wandb
 events.out.tfevents*
+
+# test artifacts from tests/test_readme.py
+tests/custom_finetuning_dataset.json
+tests/custom_texts
\ No newline at end of file
diff --git a/litgpt/scripts/merge_lora.py b/litgpt/scripts/merge_lora.py
index 41843cb6c1..845acbc405 100644
--- a/litgpt/scripts/merge_lora.py
+++ b/litgpt/scripts/merge_lora.py
@@ -45,17 +45,21 @@ def merge_lora(

     with fabric.init_module(), torch.device("meta"):
         model = GPT(config)
+        # we don't care about these to perform merging
+        model.cos = None
+        model.sin = None

     lora_path = checkpoint_dir / "lit_model.pth.lora"
     pretrained_checkpoint = torch.load(str(pretrained_checkpoint_dir / "lit_model.pth"), mmap=True)
     lora_checkpoint = torch.load(str(lora_path), mmap=True)
+    lora_checkpoint = lora_checkpoint.get("model", lora_checkpoint)

     # Merge LoRA weights into the base model
-    pretrained_checkpoint.update(lora_checkpoint.get("model", lora_checkpoint))
+    pretrained_checkpoint.update(lora_checkpoint)
     model.load_state_dict(pretrained_checkpoint, assign=True)
     # since LoRA finetuning only saves the LoRA weights, we treat the lora weights dtype as the expected dtype
     lora_dtype = next(iter(lora_checkpoint.values())).dtype
-    model.to(dtype=lora_dtype)
+    model.to(dtype=lora_dtype, device="cpu")
     merge_lora_weights(model)

     # Remove LoRA parameters and the LoRA linear substring
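
For context, a minimal usage sketch (not part of the patch) of how the patched merge_lora entry point might be invoked after a LoRA finetuning run. The directory paths are placeholders, and the keyword arguments assume the signature suggested by the checkpoint_dir and pretrained_checkpoint_dir variables used in the hunk above.

    from pathlib import Path

    from litgpt.scripts.merge_lora import merge_lora

    # Merge the saved LoRA weights into the base model weights.
    # Both paths below are hypothetical examples, not values from the patch.
    merge_lora(
        checkpoint_dir=Path("out/finetune/lora/final"),              # directory containing lit_model.pth.lora
        pretrained_checkpoint_dir=Path("checkpoints/base-model"),    # directory containing the base lit_model.pth
    )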