
Fix dtype mismatch when merging LoRA checkpoints #1246

Merged
merged 1 commit into main from bugfix/merge-lora-dtype on Apr 4, 2024

Conversation

awaelchli
Contributor

This partially reverts #1189, which added assign=True. The problem is that the base model checkpoints come in a different precision than the one we use for finetuning. So if the LoRA weights are stored in bfloat16 but the pretrained weights are stored in float16, we get errors, because with assign=True we don't copy/cast the tensors to a common dtype:

⚡ ~ litgpt merge_lora --checkpoint_dir out/qlora-llama2-7b/final 
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/bin/litgpt", line 8, in <module>
    sys.exit(main())
  File "/teamspace/studios/this_studio/llm-finetune/litgpt/litgpt/__main__.py", line 129, in main
    fn(**kwargs)
  File "/teamspace/studios/this_studio/llm-finetune/litgpt/litgpt/scripts/merge_lora.py", line 56, in merge_lora
    merge_lora_weights(model)
  File "/teamspace/studios/this_studio/llm-finetune/litgpt/litgpt/lora.py", line 746, in merge_lora_weights
    module.merge()
  File "/teamspace/studios/this_studio/llm-finetune/litgpt/litgpt/lora.py", line 400, in merge
    super().merge()
  File "/teamspace/studios/this_studio/llm-finetune/litgpt/litgpt/lora.py", line 164, in merge
    raise NotImplementedError(
NotImplementedError: Cannot merge the pretrained weights of type torch.float16 and LoRA weights of type torch.bfloat16
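
For context, a minimal standalone sketch (illustrative only, not litgpt code) of why assign=True preserves the checkpoint dtype instead of casting it to the module's dtype:

```python
import torch
import torch.nn as nn

# Hypothetical example: a bfloat16 module (the finetuning precision) and a
# float16 checkpoint (the pretrained weights), as in the error above.
layer = nn.Linear(4, 4).to(torch.bfloat16)
checkpoint = {"weight": torch.zeros(4, 4, dtype=torch.float16),
              "bias": torch.zeros(4, dtype=torch.float16)}

# Default (copy) semantics: the checkpoint tensors are copied into the
# existing parameters and cast to the module's dtype.
layer.load_state_dict(checkpoint)
print(layer.weight.dtype)  # torch.bfloat16

# assign=True: the checkpoint tensors are assigned as-is, so the base weights
# end up in float16 next to bfloat16 LoRA weights -> the merge fails.
layer.load_state_dict(checkpoint, assign=True)
print(layer.weight.dtype)  # torch.float16
```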

@rasbt
Collaborator

rasbt commented Apr 4, 2024

Thanks for the fix! I am just thinking about scenarios/ramifications ...

I think when a user runs

litgpt finetune lora ... --precision bf16-true

the intention, most of the time, is to have the whole model in bf16, not just the LoRA weights, right? Because usually they care most about memory efficiency when they use that setting. Your change would basically accomplish that, right?

The alternative is to have separate precision arguments for the model and the LoRA weights and be explicit, e.g. --precision bf16-true casts the model and lora.precision bf16-true casts the LoRA weights (but then of course it's more work for the user to get them to match, because otherwise merging would not work, which would be frustrating).

In short, I think the change is totally good and fine as is, and there is no need to further complicate it.

(PS: I think this would also address the issue I had when using GaLore on top of LoRA!)

@awaelchli
Contributor Author

awaelchli commented Apr 4, 2024

I think if the model was finetuned in precision X, we want the merged checkpoint to be saved in precision X, because that is what we trained. That's what we were doing prior to #1189, where the precision we used for training is set in Fabric and then we load the checkpoint into that model for merging.

The meta-device + assign should only be used if the checkpoint already contains the tensors with the desired dtype and device, which was not the case here.
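
To make the comparison concrete, here is a rough sketch (assumed shapes and helper, not the actual litgpt loading code) of the two strategies:

```python
import torch
import torch.nn as nn

def build_model() -> nn.Module:
    # Stand-in for instantiating the LLM (hypothetical, for illustration only).
    return nn.Linear(8, 8)

# A float16 "pretrained" checkpoint.
state_dict = {"weight": torch.randn(8, 8, dtype=torch.float16),
              "bias": torch.randn(8, dtype=torch.float16)}

# (a) Previous behaviour: materialize the model in the training precision and
# copy the checkpoint into it; the tensors are cast to that precision on load.
model = build_model().to(torch.bfloat16)
model.load_state_dict(state_dict)  # float16 -> bfloat16 copy + cast

# (b) Meta-device + assign: cheap to set up, but the parameters keep whatever
# dtype/device the checkpoint tensors already have, so it is only safe when
# the checkpoint already matches the desired dtype and device.
with torch.device("meta"):
    model = build_model()
model.load_state_dict(state_dict, assign=True)  # parameters are now float16
```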

@carmocca
Contributor

carmocca commented Apr 4, 2024

Why don't we implement dtype promotion instead? Assigning has the advantage of decreased memory usage, which a user had problems with.

The NotImplementedError raised during LoRA merging could be replaced; there's nothing blocking that.

@awaelchli
Contributor Author

awaelchli commented Apr 4, 2024

Regardless, #1189 broke what previously worked. In that sense, I vote for correctness > performance in terms of priorities.
If you have an idea for how to do the conversion, can you explain it concretely or open a PR?

@carmocca
Contributor

carmocca commented Apr 4, 2024

Since we are doing an in-place add, we can rely on PyTorch's internal accumulation in fp32 while keeping the original pretrained weights' dtype. We just need to make sure that the pretrained weights' dtype is the expected one.
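
A minimal illustration of that behaviour (illustrative tensors, not the actual lora.py merge code):

```python
import torch

# A float16 "pretrained" weight and a bfloat16 LoRA delta, as in the traceback.
pretrained = torch.full((2, 2), 1.0, dtype=torch.float16)
lora_delta = torch.full((2, 2), 0.5, dtype=torch.bfloat16)

# The in-place add promotes the operands internally (accumulating in float32)
# and writes the result back in the dtype of the tensor being modified, so the
# merged weight stays float16.
pretrained.add_(lora_delta)
print(pretrained.dtype)  # torch.float16
print(pretrained)        # values of 1.5 in float16
```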

@carmocca
Contributor

carmocca commented Apr 4, 2024

@Andrei-Aksionov Is there any specific reason why you added this if statement instead of letting it do type promotion under the hood? https://github.com/Lightning-AI/litgpt/pull/771/files#diff-9278d308e05826a4593d5fb4ddec59041a56291a309f856f4322c834c6fd1555R145-R147

@carmocca
Contributor

carmocca commented Apr 4, 2024

Opened #1248 with my proposal. It would be useful if you could try it out on what you are running.

carmocca merged commit d8dc97e into main on Apr 4, 2024 (8 checks passed)
carmocca deleted the bugfix/merge-lora-dtype branch on April 4, 2024 at 16:27
carmocca added a commit that referenced this pull request on Apr 4, 2024
@Andrei-Aksionov
Collaborator

No, there was no specific reason for it.
I just don't like implicit stuff in general. I decided that it's better to have such an error for the cases that aren't expected, and then later decide what to do about them if they occur.
