
Smarter thunder.jit decisions #1204

Merged
carmocca merged 5 commits into main from carmocca/customizable-jit on Mar 27, 2024

Conversation

@carmocca (Contributor) commented on Mar 27, 2024

Adds support for:

1. Have the user call thunder.jit but still use the strategy:

fabric = Fabric(strategy=ThunderFSDPStrategy())
model = MyModel()
model = thunder.jit(model)
model = fabric.setup(model)  # this is now smart enough to know that model was already jitted
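
A minimal sketch of how the setup call could detect this (the helper name is hypothetical; the check relies on the same internal _lc_cd attribute that the PoC below pokes at):

def is_thunder_jitted(module) -> bool:
    # Assumption: thunder.jit stores its compile data on the wrapper's
    # `_lc_cd` attribute (see the PoC below), so its presence marks a
    # module that was already jitted.
    return hasattr(module, "_lc_cd")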

PoC:

import os
import thunder
import torch
import torch.distributed as torch_dist

world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
global_rank = int(os.environ.get("RANK", 0))
if world_size > 1:
    torch_dist.init_process_group(backend="nccl")
    pg = torch_dist.distributed_c10d._get_default_group()
device = torch.device("cuda", local_rank)
torch.cuda.set_device(device)

model = torch.nn.Linear(5, 10, bias=False, device=device)
x = torch.randn(2, 5, device=device)

def fwd_loss(m, x):
    return m(x).sum()

model = thunder.jit(model)
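# Apply the FSDP transform to the already-jitted module by swapping the
# wrapped function on thunder's internal `_lc_cd` object: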
model._lc_cd.fn = thunder.distributed.fsdp(model._lc_cd.fn)

out = fwd_loss(model, x)

print(out)
if local_rank == 0:
    print("FN", thunder.last_traces(model)[-1].python())

2. Have the user compile an arbitrary function that includes the model:

def fwd_loss(m, x):
    return m(x).sum()

fabric = Fabric(strategy=ThunderFSDPStrategy(jit=False))
model = MyModel()
fwd_loss = thunder.jit(fwd_loss)
model = fabric.setup(model)
fwd_loss(model, ...)

Thunder doesn't support jitting twice here, so the user needs to disable the strategy's jit call, since Fabric doesn't know anything about fwd_loss.
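
A minimal sketch of what the jit flag could gate inside the strategy (illustrative only: the method name and the default value are assumptions, not the strategy's actual implementation):

class ThunderFSDPStrategy:
    def __init__(self, jit: bool = True):
        self.jit = jit  # jit=False means "the user jits manually"

    def setup_module(self, module):
        # Always apply the FSDP transform; only jit when requested, so a
        # user-jitted function like fwd_loss above isn't jitted twice.
        module = thunder.distributed.fsdp(module)
        if self.jit:
            module = thunder.jit(module)
        return module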

PoC:

import os
import thunder
import torch
import torch.distributed as torch_dist

world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
global_rank = int(os.environ.get("RANK", 0))
if world_size > 1:
    torch_dist.init_process_group(backend="nccl")
    pg = torch_dist.distributed_c10d._get_default_group()
device = torch.device("cuda", local_rank)
torch.cuda.set_device(device)

model = torch.nn.Linear(5, 10, bias=False, device=device)
x = torch.randn(2, 5, device=device)

def fwd_loss(m, x):
    return m(x).sum()

fwd_loss = thunder.jit(fwd_loss)
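# Here FSDP is applied to the eager module; the jit boundary is the whole
# fwd_loss function instead of the module itself: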
model = thunder.distributed.fsdp(model)

out = fwd_loss(model, x)

print(out)
if local_rank == 0:
    print("FN", thunder.last_traces(fwd_loss)[-1].python())

carmocca self-assigned this on Mar 27, 2024
@carmocca (Contributor, Author) commented:

DDP is blocked until Lightning-AI/lightning-thunder#94 is resolved.

carmocca marked this pull request as ready for review on March 27, 2024 at 23:26
carmocca requested a review from lantiga as a code owner on March 27, 2024 at 23:26
carmocca merged commit a67dd5c into main on March 27, 2024
8 checks passed
carmocca deleted the carmocca/customizable-jit branch on March 27, 2024 at 23:42
rasbt pushed a commit that referenced this pull request on April 3, 2024