Skip to content

Commit

Permalink
Workaround
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Mar 27, 2024
1 parent ef963f4 commit 7d8d596
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
9 changes: 7 additions & 2 deletions extensions/thunder/strategies/thunder_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,13 @@ def setup_module(self, module: Module) -> Module:
"You already called `thunder.jit()` and generated an execution trace. It's too late to apply the"
" DDP transform. Remove the `forward` call before `fabric.setup()`"
)
# modify the reference
cd.fn = thunder.distributed.ddp(cd.fn, **self._ddp_kwargs)
assert cd.is_module # sanity check
ddp_module = thunder.distributed.ddp(cd.fn, **self._ddp_kwargs)
# update the compile data state
cd.fn = ddp_module
assert hasattr(cd, "_processed_function") # sanity check
cd._processed_function = ddp_module
cd.process_group_for_ddp = ddp_module.process_group_for_ddp
return module
else:
module = thunder.distributed.ddp(module, **self._ddp_kwargs)
Expand Down
9 changes: 7 additions & 2 deletions extensions/thunder/strategies/thunder_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,14 +167,19 @@ def setup_module(self, module: Module) -> Module:
"You already called `thunder.jit()` and generated an execution trace. It's too late to apply the"
" FSDP transform. Remove the `forward` call before `fabric.setup()`"
)
# modify the reference
cd.fn = thunder.distributed.fsdp(
assert cd.is_module # sanity check
fsdp_module = thunder.distributed.fsdp(
cd.fn,
device=self.root_device,
sharding_strategy=self.sharding_strategy,
bucketing_strategy=self.bucketing_strategy,
**self._fsdp_kwargs,
)
# update the compile data state
cd.fn = fsdp_module
assert hasattr(cd, "_processed_function") # sanity check
cd._processed_function = fsdp_module
cd.process_group_for_ddp = fsdp_module.process_group_for_ddp
return module
else:
module = thunder.distributed.fsdp(
Expand Down

0 comments on commit 7d8d596

Please sign in to comment.