Fix dtype mismatch when merging LoRA checkpoints #1246
Conversation
Thanks for the fix! I am just thinking about scenarios/ramifications ... When a user runs litgpt finetune lora ... --precision bf16-true, isn't the intention, most of the time, to have the whole model in bf16, not just the LoRA weights? I think so, right? Because usually they care most about memory efficiency when they use that setting. Your change would basically accomplish that, right? The alternative would be a separate precision argument for the model and the LoRA weights to make it explicit, something like the hypothetical sketch below. In short, I think the change is totally good and fine as is, and there is no need to further complicate it. (PS: I think this would also address the issue I had when using GaLore on top of LoRA!)
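A minimal sketch of what that separate-precision idea could look like (the parameter names and dtypes are hypothetical, not actual litgpt options):

```python
import torch
from torch import nn


class LoRALinear(nn.Module):
    """Toy illustration: base weights and LoRA weights held in different precisions."""

    def __init__(self, in_features: int, out_features: int, r: int = 8,
                 base_dtype: torch.dtype = torch.float16,
                 lora_dtype: torch.dtype = torch.bfloat16) -> None:
        super().__init__()
        # Frozen pretrained weight in the "model" precision
        self.weight = nn.Parameter(
            torch.empty(out_features, in_features, dtype=base_dtype), requires_grad=False
        )
        # Trainable LoRA matrices in their own precision
        self.lora_A = nn.Parameter(torch.zeros(r, in_features, dtype=lora_dtype))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r, dtype=lora_dtype))
```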
I think if the model was finetuned in precision X, we want the merged checkpoint to be saved in precision X, because that is what we trained. That's what we were doing prior to #1246, where the precision we used for training is set in Fabric and then we load the checkpoint into that model for merging. The meta-device + assign should only be used if the checkpoint already contains the tensors with the desired dtype and device, which was not the case here.
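A rough sketch of that loading path, with a toy `build_model` standing in for the real model constructor:

```python
import torch
from torch import nn
import lightning as L


def build_model() -> nn.Module:
    return nn.Linear(8, 8)  # stand-in for the real model constructor


# A float16 "pretrained" checkpoint, standing in for the file on disk.
checkpoint = {k: v.to(torch.float16) for k, v in build_model().state_dict().items()}

# Set the training precision in Fabric, materialize the model under it, and let
# load_state_dict copy/cast the checkpoint tensors into that dtype before merging.
fabric = L.Fabric(precision="bf16-true")
with fabric.init_module():
    model = build_model()
model.load_state_dict(checkpoint)   # tensors are copied and cast to bfloat16
print(model.weight.dtype)           # torch.bfloat16

# By contrast, loading onto the meta device with assign=True skips the copy/cast,
# so it only gives the right result when the checkpoint already holds tensors with
# the desired dtype and device.
```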
Why don't we implement dtype promotion instead? Assigning has the advantage of decreased memory usage, which a user had problems with. The NotImplementedError raised by LoRA merging could be replaced; there's nothing blocking that.
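A sketch of what dtype promotion during the merge could look like (a hypothetical helper, not an existing litgpt function):

```python
import torch


def merge_with_promotion(weight: torch.Tensor, lora_delta: torch.Tensor) -> torch.Tensor:
    # Compute the merge in the promoted dtype (e.g. float16 + bfloat16 -> float32),
    # then cast the result back to the pretrained weight's dtype.
    merged_dtype = torch.promote_types(weight.dtype, lora_delta.dtype)
    merged = weight.to(merged_dtype) + lora_delta.to(merged_dtype)
    return merged.to(weight.dtype)
```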
Regardless, #1189 broke what previously worked. In that sense, I vote for correctness > performance in terms of priorities.
Since we are doing an in-place add, we can rely on PyTorch's internal accumulation in fp32 whilst keeping the original pretrained weights' dtype. We just need to make sure that the pretrained weights' dtype is the expected one.
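A rough sketch of that in-place merge, using illustrative names rather than the exact litgpt code:

```python
import torch


def merge_lora_inplace(weight: torch.Tensor, lora_A: torch.Tensor,
                       lora_B: torch.Tensor, scaling: float) -> None:
    # The matmul and the in-place add accumulate in higher precision internally,
    # but the result is written back in `weight`'s dtype, so `weight` keeps the
    # pretrained dtype it was loaded with. That dtype therefore has to be the one
    # we expect to save.
    weight += (lora_B @ lora_A) * scaling
```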
@Andrei-Aksionov Is there any specific reason why you added this?
Opened #1248 with my proposal. It would be useful if you could try it out on what you are running.
No, there was no specific reason for it. |
This partially reverts #1189, which added `assign=True`. The problem is that the base model checkpoints come in a different precision than what we use for finetuning. So if the LoRA weights are stored in bfloat16, but the pretrained weights are stored in float16, we get errors because with `assign=True` we don't copy/cast the tensors to the same dtype:
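For reference, a standalone illustration of the copy-vs-assign difference (a toy module, not the litgpt model):

```python
import torch
from torch import nn

model = nn.Linear(4, 4).to(torch.bfloat16)                   # finetuning precision
checkpoint = {"weight": torch.zeros(4, 4, dtype=torch.float16),
              "bias": torch.zeros(4, dtype=torch.float16)}   # pretrained precision

model.load_state_dict(checkpoint)                # copies and casts to bfloat16
print(model.weight.dtype)                        # torch.bfloat16

model.load_state_dict(checkpoint, assign=True)   # keeps the checkpoint tensors as-is
print(model.weight.dtype)                        # torch.float16 -> dtype mismatch later
```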