Move pretrained dtype
carmocca committed Apr 4, 2024
1 parent 4267658 commit 7d37425
Showing 1 changed file with 3 additions and 0 deletions.
litgpt/scripts/merge_lora.py (3 additions, 0 deletions)
@@ -49,6 +49,9 @@ def merge_lora(
     lora_path = checkpoint_dir / "lit_model.pth.lora"
     pretrained_checkpoint = torch.load(str(pretrained_checkpoint_dir / "lit_model.pth"), mmap=True)
     lora_checkpoint = torch.load(str(lora_path), mmap=True)
+    # since LoRA finetuning only saves the LoRA weights, we treat the lora weights dtype as the expected dtype
+    lora_dtype = next(iter(lora_checkpoint.values())).dtype
+    pretrained_checkpoint = {k: v.to(dtype=lora_dtype) for k, v in pretrained_checkpoint.items()}

     # Merge LoRA weights into the base model
     pretrained_checkpoint.update(lora_checkpoint.get("model", lora_checkpoint))

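For context, here is a minimal, self-contained sketch of the merge flow this commit touches: load the base and LoRA checkpoints, cast the base weights to the LoRA dtype, then overlay the LoRA tensors before saving. It is only an illustration of the pattern shown in the diff; merge_lora_sketch and the output path argument are hypothetical names, not the actual litgpt API.

# Minimal sketch, not the actual litgpt implementation.
from pathlib import Path

import torch


def merge_lora_sketch(pretrained_checkpoint_dir: Path, lora_path: Path, out_path: Path) -> None:
    # mmap=True maps the checkpoint files instead of reading them fully into memory (PyTorch >= 2.1)
    pretrained_checkpoint = torch.load(str(pretrained_checkpoint_dir / "lit_model.pth"), mmap=True)
    lora_checkpoint = torch.load(str(lora_path), mmap=True)

    # LoRA finetuning only saves the LoRA weights, so their dtype is treated
    # as the expected dtype for the merged model
    lora_dtype = next(iter(lora_checkpoint.values())).dtype
    pretrained_checkpoint = {k: v.to(dtype=lora_dtype) for k, v in pretrained_checkpoint.items()}

    # Overwrite base entries with the LoRA entries; some checkpoints nest the
    # state dict under a "model" key
    pretrained_checkpoint.update(lora_checkpoint.get("model", lora_checkpoint))
    torch.save(pretrained_checkpoint, str(out_path))

Without the cast, keys not overwritten by the LoRA checkpoint would keep the pretrained dtype, leaving the merged state dict with mixed dtypes.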