Loading checkpoint before fabric.setup(model) gets abnormal loss when using fabric.init_module() #20490
Comments
Thank you @kobenaxie for the issue. We'll take a close look shortly, both here and in Lightning-AI/litgpt#1868.
Hello @kobenaxie. This code addresses the issue of high memory consumption when loading weights into a model. In the traditional approach, two sets of weights exist in memory at the same time:

```python
model = GPT(...)                # model with random weights
weights = torch.load(...)       # pretrained weights, loaded alongside
model.load_state_dict(weights)  # only now are the random weights replaced
```

To mitigate this, the loading process is split into multiple steps, so that only a single full copy of the weights needs to be materialized.
This happens because the model is materialized with random weights, so training effectively starts from an untrained model. The loss value itself provides the hint:

```python
import torch
import torch.nn.functional as F

batch_size = 32
vocab_size = 151_643

# Random logits simulate the output of an untrained model.
logits = torch.randn(batch_size, vocab_size)
targets = torch.randint(0, 2, (batch_size, vocab_size)).float()

loss_ce = F.cross_entropy(logits, targets.argmax(dim=1))
print(loss_ce)
# >> tensor(12.1440)
```
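As a quick sanity check on that number: the cross-entropy of an uninformative (uniform) predictor over `V` classes is about `ln(V)`, and for this vocabulary size that lands very close to the value printed above.

```python
import math

vocab_size = 151_643
# Cross-entropy of a uniform predictor over V classes is ln(V).
expected = math.log(vocab_size)
print(expected)  # ~11.93, close to the observed ~12.14
```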
Amazing explanation @Andrei-Aksionov, this one is great for the docs!
Hi @lantiga @Andrei-Aksionov, thank you for your amazing explanation. I have another question: when wrapping

And then create the model

If I want to load a pretrained checkpoint of
Bug description
Initializing the model with `fabric.init_module(True)` and loading the checkpoint after `model = fabric.setup(model)` gives a normal training loss, but loading the checkpoint before `model = fabric.setup(model)` gives a much larger loss. Another observation: without `fabric.init_module()`, I get a normal loss even when loading the checkpoint before `fabric.setup(model)`. So what is the correct way to load HF models converted by `litgpt.scripts.convert_hf_checkpoint`?

What version are you seeing the problem on?
v2.4
How to reproduce the bug
Error messages and logs
Environment
More info
No response