adding DDP/FSDP transform after JITting does not work #94
Comments
I can work around this by setting tmodel._lc_cd.process_group_for_ddp = tmodel._lc_cd.fn.process_group_for_ddp, since process_group_for_ddp is read eagerly from CompileData.fn when CompileData is constructed (thunder/common.py, lines 224 to 226 in 94c9494), which is why swapping the function after jitting fails with DDP. So my question is: could we delay accessing this attribute until the function runs?
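For context, a rough sketch of the shape of that workaround, assuming the hack swaps CompileData.fn for a ddp-annotated module after jitting (the actual litgpt snippet is not reproduced in this thread, so the names and call order here are guesses, not the author's exact code):

import thunder
import thunder.distributed

# Hedged sketch, not the actual litgpt code: jit the plain module first ...
tmodel = thunder.jit(model)
# ... then annotate the wrapped module for DDP afterwards.
tmodel._lc_cd.fn = thunder.distributed.ddp(tmodel._lc_cd.fn)
# CompileData read process_group_for_ddp from fn at construction time, so it
# is still None here; the workaround copies it over manually:
tmodel._lc_cd.process_group_for_ddp = tmodel._lc_cd.fn.process_group_for_ddp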
TBH, this is a very clear "don't do this, changing the fn is completely unsupported!". That said, we can talk about distributed-after-jit. The obstacles are: […]
I'll chat with you to understand the need better.
Triage review: let's look at our current transforms, when they have to be applied, and what they mean when ordered after each other (are all orders supported?). For example, ddp then jit then grad? jit then grad then ddp? Do we need to support transforms that change the original module, or maybe produce a new module?
I would appreciate some pointers or examples here that show this, because in my test the trace does look correct: it contains the appropriate collectives. I'm probably misunderstanding how the interpreter works. How can the prologues be generated at […]?
FSDP and DDP calls are not trace transforms; they are parameter annotators of the original to-be-jitted PyTorch module.
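To make that concrete, here is a simplified stand-in for what "parameter annotator" means, inferred from the common.py comment quoted later in this thread; the real thunder.distributed.ddp does more (e.g. it also arranges gradient synchronization), so treat this as an assumption-laden sketch rather than the actual implementation:

import torch
import torch.distributed as dist

def ddp_annotate(module: torch.nn.Module) -> torch.nn.Module:
    # Tag the original, to-be-jitted module with its process group (assumes
    # torch.distributed has been initialized). CompileData later reads this
    # attribute; no trace is transformed at this point.
    module.process_group_for_ddp = dist.group.WORLD
    return module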
We still need to support […]. What I'm advocating for is something like […], where […]. Allowing this is convenient because then the user can control the innermost […].
I know nothing about Lightning. Do you want to allow users to do […]? (See thunder/__init__.py, line 642 in 6c64fb9.)
One thing (probably tangential) I was wondering: why is process_group_for_ddp read eagerly in CompileData.__init__ (thunder/common.py, lines 221 to 223 in 6c64fb9)?
I think it would make sense to make it a property, because if a scenario comes up where we have to update fn, the attribute would then be read lazily from the new function. Something like:
diff --git a/thunder/common.py b/thunder/common.py
index 85775ff..24cabcb 100644
--- a/thunder/common.py
+++ b/thunder/common.py
@@ -218,10 +218,6 @@ class CompileData:
         self.is_module = isinstance(self.fn, torch.nn.Module)

-        # We set the process_group_for_ddp attribute on the module when
-        # thunder.distributed.ddp(module) is called.
-        self.process_group_for_ddp = getattr(self.fn, "process_group_for_ddp", None)
-
         #
         # Possibly processes the function
         #
@@ -232,6 +228,12 @@ class CompileData:
         assert disable_preprocessing, "please use thunder.compile if you need preprocessing"

+    @property
+    def process_group_for_ddp(self):
+        # We set the process_group_for_ddp attribute on the module when
+        # thunder.distributed.ddp(module) is called.
+        return getattr(self.fn, "process_group_for_ddp", None)
+
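A self-contained toy (not thunder's actual CompileData, just the access pattern the property would enable) to illustrate why the lazy read matters:

class ToyCompileData:
    def __init__(self, fn):
        self.fn = fn

    @property
    def process_group_for_ddp(self):
        # Read from fn at access time, so a later ddp(fn)-style annotation is picked up.
        return getattr(self.fn, "process_group_for_ddp", None)

class ToyModule:
    pass

module = ToyModule()
cd = ToyCompileData(module)
assert cd.process_group_for_ddp is None              # not annotated yet
module.process_group_for_ddp = "fake process group"  # what ddp(module) would set
assert cd.process_group_for_ddp == "fake process group"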
🐛 Bug
The snippet below looks hacky, but it's how I'm approaching support for having the user control the thunder.jit call outside of Fabric: Lightning-AI/litgpt#1204. The objective is that fsdp|ddp can be applied after the thunder.jit call. It works with FSDP, but not with DDP, where it fails with: […]
To Reproduce
torchrun --nproc-per-node 2 bug.py
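The original bug.py is not included above; a minimal script in the spirit of the description (thunder.jit first, then ddp) might look like the following. The model, sizes, and backend are assumptions, and whether ddp() should accept the jitted module directly is exactly what this issue questions:

import torch
import torch.distributed as dist
import thunder
import thunder.distributed

def main():
    # Assumes 2 CUDA devices and launch via torchrun as shown above.
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device("cuda", rank)
    torch.cuda.set_device(device)

    model = torch.nn.Linear(8, 8).to(device)
    tmodel = thunder.jit(model)               # jit first ...
    tmodel = thunder.distributed.ddp(tmodel)  # ... then try to apply DDP

    out = tmodel(torch.randn(4, 8, device=device))
    out.sum().backward()

if __name__ == "__main__":
    main()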
cc @carmocca @awaelchli @crcrpar @kshitij12345 since you fixed a similar issue in #23