Fixes
carmocca committed May 6, 2024
1 parent 2ba3e49 commit 3b17eb1
Showing 2 changed files with 10 additions and 2 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -16,3 +16,7 @@ checkpoints
 out
 wandb
 events.out.tfevents*
+
+# test artifacts from tests/test_readme.py
+tests/custom_finetuning_dataset.json
+tests/custom_texts
8 changes: 6 additions & 2 deletions litgpt/scripts/merge_lora.py
@@ -45,17 +45,21 @@ def merge_lora(
 
     with fabric.init_module(), torch.device("meta"):
         model = GPT(config)
+        # we don't care about these to perform merging
+        model.cos = None
+        model.sin = None
 
     lora_path = checkpoint_dir / "lit_model.pth.lora"
     pretrained_checkpoint = torch.load(str(pretrained_checkpoint_dir / "lit_model.pth"), mmap=True)
     lora_checkpoint = torch.load(str(lora_path), mmap=True)
+    lora_checkpoint = lora_checkpoint.get("model", lora_checkpoint)
 
     # Merge LoRA weights into the base model
-    pretrained_checkpoint.update(lora_checkpoint.get("model", lora_checkpoint))
+    pretrained_checkpoint.update(lora_checkpoint)
     model.load_state_dict(pretrained_checkpoint, assign=True)
     # since LoRA finetuning only saves the LoRA weights, we treat the lora weights dtype as the expected dtype
     lora_dtype = next(iter(lora_checkpoint.values())).dtype
-    model.to(dtype=lora_dtype)
+    model.to(dtype=lora_dtype, device="cpu")
     merge_lora_weights(model)
 
     # Remove LoRA parameters and the LoRA linear substring
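For context, here is a standalone sketch, in plain PyTorch, of the checkpoint-handling pattern this commit adjusts. It is not litgpt's actual code: the tiny state dicts, the linear.lora_A / linear.lora_B key names, and the alpha / r values are illustrative assumptions, and in the real script the low-rank merge itself is performed by merge_lora_weights(model).

import torch


def merge_lora_state_dicts(pretrained_checkpoint: dict, lora_checkpoint: dict) -> dict:
    # Some checkpoints nest the weights under a "model" key; unwrap it once up
    # front so the update below and the dtype probe see the same flat dict
    # (mirrors lora_checkpoint.get("model", lora_checkpoint) in the diff).
    lora_checkpoint = lora_checkpoint.get("model", lora_checkpoint)

    # Overlay the fine-tuned LoRA tensors on top of the base weights.
    merged = dict(pretrained_checkpoint)
    merged.update(lora_checkpoint)

    # LoRA fine-tuning only saves the LoRA weights, so their dtype is taken as
    # the target dtype; keep everything on the CPU for the merge.
    lora_dtype = next(iter(lora_checkpoint.values())).dtype
    return {k: v.to(dtype=lora_dtype, device="cpu") for k, v in merged.items()}


def apply_lora_delta(weight, lora_A, lora_B, alpha=16, r=2):
    # Standard LoRA merge formula: W' = W + (alpha / r) * (B @ A). In litgpt
    # this step happens inside merge_lora_weights(model); it is shown here only
    # to illustrate what "merging" means. alpha and r are assumed hyperparameters.
    return weight + (alpha / r) * (lora_B @ lora_A).to(weight.dtype)


if __name__ == "__main__":
    base = {"linear.weight": torch.randn(8, 8, dtype=torch.float32)}
    # A hypothetical LoRA checkpoint saved with a nested "model" key and bfloat16 factors.
    lora = {"model": {"linear.lora_A": torch.randn(2, 8, dtype=torch.bfloat16),
                      "linear.lora_B": torch.randn(8, 2, dtype=torch.bfloat16)}}
    merged = merge_lora_state_dicts(base, lora)
    print({k: v.dtype for k, v in merged.items()})  # all tensors in torch.bfloat16, on CPU
    merged_weight = apply_lora_delta(merged["linear.weight"],
                                     merged["linear.lora_A"], merged["linear.lora_B"])
    print(merged_weight.shape, merged_weight.dtype)  # torch.Size([8, 8]) torch.bfloat16

As the diff shows, the actual script now unwraps the optional "model" wrapper once, before both the update(...) call and the dtype lookup, and moves the merged model explicitly to the CPU in that dtype.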
