
Commit

duh
carmocca committed Apr 4, 2024
1 parent 7d37425 commit fffcce7
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions litgpt/scripts/merge_lora.py
@@ -49,13 +49,13 @@ def merge_lora(
     lora_path = checkpoint_dir / "lit_model.pth.lora"
     pretrained_checkpoint = torch.load(str(pretrained_checkpoint_dir / "lit_model.pth"), mmap=True)
     lora_checkpoint = torch.load(str(lora_path), mmap=True)
-    # since LoRA finetuning only saves the LoRA weights, we treat the lora weights dtype as the expected dtype
-    lora_dtype = next(iter(lora_checkpoint.values())).dtype
-    pretrained_checkpoint = {k: v.to(dtype=lora_dtype) for k, v in pretrained_checkpoint.items()}
 
     # Merge LoRA weights into the base model
     pretrained_checkpoint.update(lora_checkpoint.get("model", lora_checkpoint))
     model.load_state_dict(pretrained_checkpoint, assign=True)
+    # since LoRA finetuning only saves the LoRA weights, we treat the lora weights dtype as the expected dtype
+    lora_dtype = next(iter(lora_checkpoint.values())).dtype
+    model.to(dtype=lora_dtype)
     merge_lora_weights(model)
 
     # Remove LoRA parameters and the LoRA linear substring
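For context, the change replaces a per-tensor cast of the pretrained state dict with a single `model.to(dtype=...)` call after loading. A minimal sketch of the resulting flow, not the litgpt implementation: the `nn.Linear` stands in for the real GPT model, and the one-entry `lora_checkpoint` is fabricated purely for illustration.

import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # stand-in for the real GPT model

# stand-ins for the torch.load(..., mmap=True) results
pretrained_checkpoint = {k: v.detach().clone() for k, v in model.state_dict().items()}
lora_checkpoint = {"weight": torch.zeros(4, 4, dtype=torch.bfloat16)}  # fabricated "LoRA" entry

# LoRA finetuning only saves the LoRA weights, so their dtype is the target dtype
lora_dtype = next(iter(lora_checkpoint.values())).dtype

# merge and load; assign=True swaps the checkpoint tensors into the module
# instead of copying them into the existing fp32 parameters, so a
# mixed-dtype state dict loads without error
pretrained_checkpoint.update(lora_checkpoint)
model.load_state_dict(pretrained_checkpoint, assign=True)

# one cast afterwards brings every parameter to the expected dtype
model.to(dtype=lora_dtype)
assert all(p.dtype == lora_dtype for p in model.parameters())

Because `assign=True` keeps the checkpoint tensors rather than copying them into pre-allocated parameters, casting the module once at the end normalizes every parameter, which avoids materializing a converted copy of the whole pretrained dict up front.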
