
Loading checkpoint before fabric.setup(model) gets abnormal loss when using fabric.init_module() #20490

Open
kobenaxie opened this issue Dec 11, 2024 · 4 comments
Labels
bug (Something isn't working), ver: 2.4.x

Comments

@kobenaxie

Bug description

When I initialize the model with fabric.init_module(True) and load the checkpoint after model = fabric.setup(model), the training loss is normal:

with fabric.init_module(empty_init=(fabric.world_size > 1)):
    model = GPT(config)
model = fabric.setup(model)
load_checkpoint(fabric, model, checkpoint_path)

step = 1 | loss train: 0.8448048233985901
step = 2 | loss train: 1.3229767084121704
step = 3 | loss train: 1.2647839784622192
step = 4 | loss train: 1.287076711654663
step = 5 | loss train: 1.0357563495635986

but when I load the checkpoint before model = fabric.setup(model), the loss is much larger:

with fabric.init_module(empty_init=(fabric.world_size > 1)):
    model = GPT(config)
load_checkpoint(fabric, model, checkpoint_path)
model = fabric.setup(model)

step = 1 | loss train: 12.027938842773438
step = 2 | loss train: 12.051375389099121
step = 3 | loss train: 12.112957954406738
step = 4 | loss train: 12.08558177947998
step = 5 | loss train: 12.089488983154297

Another observation: if I do not use fabric.init_module(), I get a normal loss even when loading the checkpoint before fabric.setup(model):

# with fabric.init_module(empty_init=(fabric.world_size > 1)):
if True:
    model = GPT(config)
load_checkpoint(fabric, model, checkpoint_path)
model = fabric.setup(model)

step = 1 | loss train: 0.8447667956352234
step = 2 | loss train: 1.3229438066482544
step = 3 | loss train: 1.2663335800170898
step = 4 | loss train: 1.2902932167053223
step = 5 | loss train: 1.035811185836792

So what is the correct way to load HF models converted with litgpt.scripts.convert_hf_checkpoint?

What version are you seeing the problem on?

v2.4

How to reproduce the bug

from pathlib import Path

import torch
import lightning as L
from lightning.fabric.strategies import FSDPStrategy

from litgpt.args import TrainArgs
from litgpt.config import Config
from litgpt.model import GPT, Block
from litgpt.data import Alpaca2k
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
    chunked_cross_entropy,
    load_checkpoint,
    num_parameters,
    get_default_supported_precision,
)


def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
    # linear warmup followed by cosine annealing
    scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)
    scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps))
    return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])


def main(
    checkpoint_dir: Path,
    devices: int = 8,
    num_nodes: int = 1,
    precision: str = "bf16-true",
    seed: int = 1337,
) -> None:
    torch.set_float32_matmul_precision("high")

    train_args = TrainArgs(
        save_interval = 1000,
        log_interval = 1,
        global_batch_size = 64,
        micro_batch_size = 4,
        lr_warmup_steps = 1000,
        epochs = 10,
        max_steps = 10000,
    )

    strategy = FSDPStrategy(
        auto_wrap_policy={Block},
        activation_checkpointing_policy={Block},
        state_dict_type="full",
        limit_all_gathers=True,
        cpu_offload=False,
    )
    
    fabric = L.Fabric(
        accelerator="cuda",
        devices=devices,
        num_nodes=num_nodes,
        strategy=strategy,
        precision=precision,
    )
    fabric.launch()
    fabric.seed_everything(seed)  # same seed for every process to init model (FSDP)
    
    dataset = Alpaca2k()
    tokenizer = Tokenizer(str(checkpoint_dir))
    dataset.connect(tokenizer, batch_size=train_args.micro_batch_size, max_seq_length=512)
    with fabric.rank_zero_first():
        dataset.prepare_data()
    dataset.setup()
    dataloader = dataset.train_dataloader()
    dataloader = fabric.setup_dataloaders(dataloader)

    checkpoint_path = str(checkpoint_dir / "lit_model.pth")
    config = Config.from_file(checkpoint_dir / "model_config.yaml")
    with fabric.init_module(empty_init=(fabric.world_size > 1)):
        model = GPT(config)
    fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
    # load_checkpoint(fabric, model, checkpoint_path)
    model = fabric.setup(model)
    load_checkpoint(fabric, model, checkpoint_path)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0002, weight_decay=0.0, betas=(0.9, 0.95))
    optimizer = fabric.setup_optimizers(optimizer)
    scheduler = get_lr_scheduler(optimizer, warmup_steps=train_args.lr_warmup_steps, max_steps=train_args.max_steps)

    model.train()
    for epoch in range(train_args.epochs):
        for step, batch in enumerate(dataloader, 1):
            input, target = batch["input_ids"], batch["labels"]
            logits = model(input)
            loss = chunked_cross_entropy(logits[..., :-1, :], target[..., 1:])
            fabric.backward(loss)

            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            fabric.print(f"{step = } | loss train: {loss.detach().item()}")


if __name__ == "__main__":
    checkpoint_dir = Path("./Qwen2.5-1.5B/")

    main(checkpoint_dir)

Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.4.0):
#- PyTorch Version (e.g., 2.4.1):
#- Python version (e.g., 3.10):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:12.1
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source): pip

More info

No response

@kobenaxie added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Dec 11, 2024
@lantiga removed the needs triage (Waiting to be triaged by maintainers) label on Dec 11, 2024
@lantiga
Collaborator

lantiga commented Dec 11, 2024

Thank you @kobenaxie for the issue

We'll take a close look shortly, both here and Lightning-AI/litgpt#1868

@Andrei-Aksionov
Contributor

Hello @kobenaxie

This pattern of creating the model on the meta device and loading the weights only after fabric.setup addresses the issue of high memory consumption when loading weights into a model.

In the traditional approach, two sets of weights exist simultaneously:

  1. Randomly initialized weights when the model is created.
  2. Pretrained weights that need to be loaded into the model.
model = GPT(...) # model with random weights
weights = torch.load(...) # pretrained weights
model.load_state_dict(weights)
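
As a rough back-of-the-envelope (my own numbers, not from the thread), for a model of roughly 1.5B parameters in bf16 this means keeping ~6 GB of weights alive instead of ~3 GB:

n_params = 1.5e9          # roughly Qwen2.5-1.5B
bytes_per_param = 2       # bf16
one_copy_gb = n_params * bytes_per_param / 1e9
print(f"{one_copy_gb:.1f} GB per copy, ~{2 * one_copy_gb:.1f} GB with both copies alive")
# -> 3.0 GB per copy, ~6.0 GB with both copies alive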

To mitigate this, the process is split into multiple steps:

  1. Model creation on a meta device:
    Using fabric.init_module, the model is created on a meta device. On this device, memory usage is minimal because the weight matrices remain "empty" until explicitly materialized (refer to the meta device documentation; see also the short sketch after this list).

  2. Target device setup:
    The fabric.setup(model) call specifies the target device (e.g., GPU) where the model will be placed.

  3. Loading pretrained weights:
    Finally, load_checkpoint(fabric, model, checkpoint_path) loads the pretrained weights into the model, materializing it on the target device with minimal memory overhead.
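
A minimal, self-contained sketch of step 1 (my own example, not from litgpt; the layer shape is made up to roughly match the Qwen2.5-1.5B embedding): under the meta device, parameters have shapes and dtypes but no storage, so even a very large layer costs essentially no memory until it is materialized.

import torch

with torch.device("meta"):
    layer = torch.nn.Linear(1536, 151_936, bias=False)

print(layer.weight.device)  # meta
print(layer.weight.shape)   # torch.Size([151936, 1536])
# memory the weight *would* need once materialized in fp32 (~0.93 GB); nothing is allocated yet
print(layer.weight.nelement() * layer.weight.element_size() / 1e9)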


but when I load the checkpoint before model = fabric.setup(model), the loss is much larger

This happens because the model ends up materialized with random weights: load_checkpoint was called before fabric.setup, while the model was still on the meta device.
The load_checkpoint function uses lazy_load, which cannot perform the materialization.

So when you run fabric.init_module (placing the model on the meta device) and then call load_checkpoint, nothing really happens: the model stays on the meta device. When it is later materialized on the target device, the weight values are completely random.

When you commented out fabric.init_module, the model was created on the CPU with random weights; load_checkpoint then loaded the pretrained weights into it, and fabric.setup moved the model to the target device.
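
For contrast, a tiny illustration (my own, with made-up shapes) of that CPU path: a module created normally has real storage, so load_state_dict actually overwrites its values.

import torch

layer = torch.nn.Linear(8, 8)                                     # created on CPU, random weights
pretrained = {"weight": torch.ones(8, 8), "bias": torch.zeros(8)}
layer.load_state_dict(pretrained)                                 # the copy is effective here
print(torch.equal(layer.weight, torch.ones(8, 8)))                # True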

The loss value provides a hint.
With a vocabulary size of approximately 151k (for Qwen2.5-1.5B) and randomly initialized weights, the model assigns roughly uniform probability to every token, so the expected cross-entropy loss is about -ln(1/151643) ≈ 11.9, i.e. around 12:

import torch
import torch.nn.functional as F

batch_size = 32
vocab_size = 151_643

# random logits scored against arbitrary target token ids: with ~151k classes,
# the cross-entropy is expected to be about ln(vocab_size) ≈ 11.9
logits = torch.randn(batch_size, vocab_size)
targets = torch.randint(0, 2, (batch_size, vocab_size)).float()

loss_ce = F.cross_entropy(logits, targets.argmax(dim=1))
print(loss_ce)

>> tensor(12.1440)

@lantiga
Collaborator

lantiga commented Dec 11, 2024

Amazing explanation @Andrei-Aksionov, this one is great for the docs!

@kobenaxie
Author

Hi @lantiga @Andrei-Aksionov, thank you for the amazing explanation. I have another question: what about wrapping GPT like this?

class Model(torch.nn.Module):
    def __init__(
        self,
        gpt: torch.nn.Module,
        module2: torch.nn.Module,
    ) -> None:
        super().__init__()
        self.gpt = gpt
        self.module2 = module2

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        x = self.gpt(input_ids)
        y = self.module2(input_ids)
        ...

And then creating the model:

with fabric.init_module(True):
    model = Model(GPT(config), torch.nn.Linear())
model = fabric.setup(model)

If I want to load the pretrained checkpoint for GPT only, what should I do? I tried load_checkpoint(fabric, model.gpt, gpt_checkpoint_path), but it failed with this error:

[rank4]:     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
[rank4]: RuntimeError: Error(s) in loading state_dict for GPT:
[rank4]:        size mismatch for lm_head.weight: copying a param with shape torch.Size([151936, 1536]) from checkpoint, the shape in current model is torch.Size([0]).
[rank4]:        size mismatch for transformer.wte.weight: copying a param with shape torch.Size([151936, 1536]) from checkpoint, the shape in current model is torch.Size([58343616]).
[rank4]:        size mismatch for transformer.ln_f.weight: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([0]).
