Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add back meta-device assign=True loading in merge_lora #1250

Merged
merged 8 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ checkpoints
out
wandb
events.out.tfevents*

# test artifacts from tests/test_readme.py
tests/custom_finetuning_dataset.json
tests/custom_texts
13 changes: 10 additions & 3 deletions litgpt/scripts/merge_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,23 @@ def merge_lora(
fabric = L.Fabric(devices=1, precision=precision, accelerator="cpu")
config = Config.from_file(checkpoint_dir / "model_config.yaml", **lora_params)

with fabric.init_module():
with fabric.init_module(), torch.device("meta"):
model = GPT(config)
# we don't care about these to perform merging
model.cos = None
model.sin = None

lora_path = checkpoint_dir / "lit_model.pth.lora"
pretrained_checkpoint = torch.load(str(pretrained_checkpoint_dir / "lit_model.pth"), mmap=True)
lora_checkpoint = torch.load(str(lora_path), mmap=True)
lora_checkpoint = lora_checkpoint.get("model", lora_checkpoint)

# Merge LoRA weights into the base model
pretrained_checkpoint.update(lora_checkpoint.get("model", lora_checkpoint))
model.load_state_dict(pretrained_checkpoint)
pretrained_checkpoint.update(lora_checkpoint)
model.load_state_dict(pretrained_checkpoint, assign=True)
# since LoRA finetuning only saves the LoRA weights, we treat the lora weights dtype as the expected dtype
lora_dtype = next(iter(lora_checkpoint.values())).dtype
model.to(dtype=lora_dtype, device="cpu")
merge_lora_weights(model)

# Remove LoRA parameters and the LoRA linear substring
Expand Down