From 7d8d5969255d0fce27966204c364ef5dcdfb8e58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 28 Mar 2024 00:24:38 +0100 Subject: [PATCH] Workaround --- extensions/thunder/strategies/thunder_ddp.py | 9 +++++++-- extensions/thunder/strategies/thunder_fsdp.py | 9 +++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/extensions/thunder/strategies/thunder_ddp.py b/extensions/thunder/strategies/thunder_ddp.py index d6b382918b..4efbe27c60 100644 --- a/extensions/thunder/strategies/thunder_ddp.py +++ b/extensions/thunder/strategies/thunder_ddp.py @@ -136,8 +136,13 @@ def setup_module(self, module: Module) -> Module: "You already called `thunder.jit()` and generated an execution trace. It's too late to apply the" " DDP transform. Remove the `forward` call before `fabric.setup()`" ) - # modify the reference - cd.fn = thunder.distributed.ddp(cd.fn, **self._ddp_kwargs) + assert cd.is_module # sanity check + ddp_module = thunder.distributed.ddp(cd.fn, **self._ddp_kwargs) + # update the compile data state + cd.fn = ddp_module + assert hasattr(cd, "_processed_function") # sanity check + cd._processed_function = ddp_module + cd.process_group_for_ddp = ddp_module.process_group_for_ddp return module else: module = thunder.distributed.ddp(module, **self._ddp_kwargs) diff --git a/extensions/thunder/strategies/thunder_fsdp.py b/extensions/thunder/strategies/thunder_fsdp.py index a9ab6f8d88..d4e60c0085 100644 --- a/extensions/thunder/strategies/thunder_fsdp.py +++ b/extensions/thunder/strategies/thunder_fsdp.py @@ -167,14 +167,19 @@ def setup_module(self, module: Module) -> Module: "You already called `thunder.jit()` and generated an execution trace. It's too late to apply the" " FSDP transform. Remove the `forward` call before `fabric.setup()`" ) - # modify the reference - cd.fn = thunder.distributed.fsdp( + assert cd.is_module # sanity check + fsdp_module = thunder.distributed.fsdp( cd.fn, device=self.root_device, sharding_strategy=self.sharding_strategy, bucketing_strategy=self.bucketing_strategy, **self._fsdp_kwargs, ) + # update the compile data state + cd.fn = fsdp_module + assert hasattr(cd, "_processed_function") # sanity check + cd._processed_function = fsdp_module + cd.process_group_for_ddp = fsdp_module.process_group_for_ddp return module else: module = thunder.distributed.fsdp(