
adding DDP/FSDP transform after JITting does not work #94

Open
carmocca opened this issue Mar 27, 2024 · 8 comments
Labels
bug (Something isn't working), distributed, help wanted (Extra attention is needed)

Comments

@carmocca
Contributor

carmocca commented Mar 27, 2024

🐛 Bug

The snippet below looks hacky, but it's how I'm approaching support for having the user control the thunder.jit call outside of Fabric: Lightning-AI/litgpt#1204

The objective is that fsdp|ddp can be applied after the thunder.jit call.

It works with FSDP, but not with DDP where it fails with:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/carlos/lightning-thunder/kk.py", line 21, in <module>
[rank1]:     out = tmodel(x)
[rank1]:   File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/__init__.py", line 194, in forward
[rank1]:     res = self._forward_fn(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/__init__.py", line 629, in fn_
[rank1]:     cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/__init__.py", line 262, in cache_info_wrapper
[rank1]:     res = fn(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/__init__.py", line 571, in get_computation_and_inputs
[rank1]:     computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/executors/torch_autograd.py", line 283, in split_forward_backward
[rank1]:     bw_trace = optimize_allreduce_in_ddp_backward(bw_trace, compile_data)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/transforms/ddp.py", line 198, in optimize_allreduce_in_ddp_backward
[rank1]:     updated_bwd_trace = visitor_transform(
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/transforms.py", line 368, in visitor_transform
[rank1]:     visit_type = visit(bsym)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/transforms/ddp.py", line 133, in __call__
[rank1]:     self.gradient_buckets.tell(grads_of_bsym[0], self.process_group)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/bucketing.py", line 150, in tell
[rank1]:     self._maybe_allreduce(bucket, group)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/bucketing.py", line 138, in _maybe_allreduce
[rank1]:     self.bucket_to_future[bucket] = dist_prims.all_reduce(
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/symbol.py", line 246, in __call__
[rank1]:     result = self.meta(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/langctxs.py", line 124, in _fn
[rank1]:     result = fn(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/prims.py", line 87, in all_reduce_meta
[rank1]:     utils.check_type(group, torch.distributed.ProcessGroup)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/baseutils.py", line 107, in check_type
[rank1]:     check(
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/baseutils.py", line 103, in check
[rank1]:     raise exception_type(s())
[rank1]: ValueError: None had an unexpected type <class 'NoneType'>. Supported types are <class 'torch.distributed.distributed_c10d.ProcessGroup'>

To Reproduce

import os
import thunder
import torch
import torch.distributed as torch_dist

world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
global_rank = int(os.environ.get("RANK", 0))
if world_size > 1:
    torch_dist.init_process_group(backend="nccl")
    pg = torch_dist.distributed_c10d._get_default_group()
device = torch.device("cuda", local_rank)
torch.cuda.set_device(device)

model = torch.nn.Linear(5, 10, bias=False, device=device)
x = torch.randn(2, 5, device=device)

tmodel = thunder.jit(model)
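# hack: apply ddp after jitting by replacing the function stored on CompileData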
tmodel._lc_cd.fn = thunder.distributed.ddp(tmodel._lc_cd.fn)

out = tmodel(x)

if local_rank == 0:
    print(thunder.last_backward_traces(tmodel)[-1].python())

torchrun --nproc-per-node 2 bug.py

cc @carmocca @awaelchli @crcrpar @kshitij12345 since you fixed a similar issue in #23

carmocca added the bug and help wanted labels on Mar 27, 2024
@carmocca
Contributor Author

I can work around this by setting

tmodel._lc_cd.process_group_for_ddp = tmodel._lc_cd.fn.process_group_for_ddp

since thunder gets this information at jit() time:

# We set the process_group_for_ddp attribute on the module when
# thunder.distributed.ddp(module) is called.
self.process_group_for_ddp = getattr(self.fn, "process_group_for_ddp", None)

So my question is: could we delay accessing this attribute until the function runs?
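
Putting the two lines together, a minimal sketch of the full workaround on top of the repro (attribute names taken from the snippets in this thread; untested):

tmodel = thunder.jit(model)
tmodel._lc_cd.fn = thunder.distributed.ddp(tmodel._lc_cd.fn)
# copy the process group that thunder cached at jit() time so the DDP
# backward transform can find it
tmodel._lc_cd.process_group_for_ddp = tmodel._lc_cd.fn.process_group_for_ddp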

@t-vi t-vi changed the title Replacing CompileData.fn fails with DDP DDP/FSDP after JIT does not work Mar 28, 2024
@t-vi t-vi changed the title DDP/FSDP after JIT does not work adding DDP/FSDP transform after JITting does not work Mar 28, 2024
@t-vi
Collaborator

t-vi commented Mar 28, 2024

TBH, this is a very clear "don't do this, changing the fn is completely unsupported!".

That said, we can talk about distributed-after-jit. The obstacles are:

  • Currently the ddp transformation is applied during the JITing (i.e. while the interpreter runs). This is fundamentally incompatible with what you're trying to do.
  • The syncs inserted can do funny things with tensor shapes, so applying this after the prologue is generated (i.e. after "jit") will require us to transform the prologue as well - this is on our roadmap for other transforms, but not entirely trivial.

I'll chat with you to understand the need better.

@mruberry
Collaborator

mruberry commented Apr 1, 2024

triage review:

let's look at our current transforms: when do they have to be applied, and what do they mean when ordered after each other (are all orders supported)? For example, ddp, jit, grad? jit, grad, ddp?

do we need to support transforms that change the original module, or maybe produce a new module?
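
For concreteness, the two ddp/jit orderings under discussion, written with the calls from the repro in the top post (a sketch; it reuses `model` and assumes a process group has been initialized as in the repro):

# supported today: annotate the module with ddp first, then jit
tmodel_a = thunder.jit(thunder.distributed.ddp(model))

# what this issue asks about: jit first, then apply ddp to the jitted module
# (not the supported order today; the top post instead works around it by patching _lc_cd.fn)
tmodel_b = thunder.distributed.ddp(thunder.jit(model))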

@carmocca
Contributor Author

carmocca commented Apr 3, 2024

Currently the ddp transformation is applied during the JITing (i.e. while the interpreter runs). This is fundamentally incompatible with what you're trying to do.

I would appreciate some pointers or examples here that show this, because in my test the trace does look correct: it contains the appropriate collectives.

I'm probably misunderstanding how the interpreter works. How can the prologues be generated at jit time if we don't have any input tensors whose shapes to check? I thought this would only happen on the first call.

@IvanYashchuk
Collaborator

FSDP and DDP calls are not trace transforms, they are parameter annotators of the original to-be-jitted PyTorch module.

  • We should raise an error when a ThunderModule is passed to the FSDP or DDP calls, suggesting the correct usage (see the sketch after this list)
  • Why do you expect ddp(jit(model)) to work and is it more important to support than jit(ddp(model))?
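
A minimal sketch of that guard, assuming the check would sit at the top of thunder.distributed.ddp/fsdp and that ThunderModule can be imported from thunder.core.module (both assumptions, not the actual implementation):

import torch
from thunder.core.module import ThunderModule  # import path assumed

def _reject_jitted_module(model: torch.nn.Module) -> None:
    # hypothetical helper, not part of thunder today
    if isinstance(model, ThunderModule):
        raise TypeError(
            "ddp()/fsdp() expect the original torch.nn.Module; "
            "apply them before thunder.jit(model)."
        )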

@carmocca
Contributor Author

We still need to support jit(ddp(model)), as this is basically what happens whenever you jit a function and not the model.

What I'm advocating for is something like jit(ddp(undo_jit(jit(model)))), where undo_jit is currently the hack that I describe in the top post.

Allowing this is convenient because the user can control the innermost jit(model) call while the framework (Fabric, Trainer) controls the transforms applied to the model and how they interact with each other if there is more than one.

@IvanYashchuk
Collaborator

I know nothing about Lightning. Do you want to allow users to do jit(model) and then, inside Lightning, apply either the DDP or FSDP call to the given model? FSDP is working now, right? You need something that unwraps the jit call. Have you tried using __wrapped__? thunder.jit uses functools.wraps here:
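
For reference, a minimal generic illustration of the __wrapped__ mechanism that functools.wraps provides (plain Python, not thunder-specific):

import functools

def jit_like(fn):
    @functools.wraps(fn)  # records the original callable in wrapper.__wrapped__
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

def f(x):
    return 2 * x

g = jit_like(f)
assert g.__wrapped__ is f  # recovers the unwrapped function, i.e. an "undo_jit"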

@kshitij12345
Collaborator

kshitij12345 commented Apr 16, 2024

One thing I was wondering (probably tangential): why is process_group_for_ddp an attribute of CompileData?

# We set the process_group_for_ddp attribute on the module when
# thunder.distributed.ddp(module) is called.
self.process_group_for_ddp = getattr(self.fn, "process_group_for_ddp", None)

I think it would make sense to make it a property. Because if a scenario comes up where we have to update CompileData.fn, we might miss updating these corresponding attributes. (This change could potentially also fix the issue.)

diff --git a/thunder/common.py b/thunder/common.py
index 85775ff..24cabcb 100644
--- a/thunder/common.py
+++ b/thunder/common.py
@@ -218,10 +218,6 @@ class CompileData:
 
         self.is_module = isinstance(self.fn, torch.nn.Module)
 
-        # We set the process_group_for_ddp attribute on the module when
-        # thunder.distributed.ddp(module) is called.
-        self.process_group_for_ddp = getattr(self.fn, "process_group_for_ddp", None)
-
         #
         # Possibly processes the function
         #
@@ -232,6 +228,12 @@ class CompileData:
 
         assert disable_preprocessing, "please use thunder.compile if you need preprocessing"
 
+    @property
+    def process_group_for_ddp(self):
+        # We set the process_group_for_ddp attribute on the module when
+        # thunder.distributed.ddp(module) is called.
+        return getattr(self.fn, "process_group_for_ddp", None)
+
