Merge branch 'main' into minimal-python-example
rasbt authored May 13, 2024
2 parents bcf03b8 + 62a491c commit 1b87dbc
Showing 5 changed files with 56 additions and 15 deletions.
2 changes: 0 additions & 2 deletions extensions/thunder/README.md
@@ -592,8 +592,6 @@ python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, t
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, unsloth, torchcompile_cat, nvfuser, torch]' --devices 1
```

Gradient accumulation is disabled in the FSDP setting because Thunder does not support skipping the backward synchronization yet.

`--compiler torch` (`torch.compile` without `thunder`) is not included because it does not support compiling the `_FabricModule` due to this issue: https://github.com/pytorch/pytorch/issues/112787#issuecomment-1986827601

The CUDA devices are all NVIDIA A100-SXM4-40GB.
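For context on the note above, the combination in question is `torch.compile` applied to the `_FabricModule` wrapper that `fabric.setup()` returns. The sketch below is illustrative only; the model and the Fabric arguments are assumptions, not taken from the benchmark setup:

```python
import torch
import lightning as L

fabric = L.Fabric(accelerator="cpu", devices=1)
fabric.launch()

model = torch.nn.Linear(8, 8)
model = fabric.setup(model)      # returns a `_FabricModule` wrapper
compiled = torch.compile(model)  # the combination the linked issue is about
```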
42 changes: 34 additions & 8 deletions extensions/thunder/strategies/thunder_ddp.py
@@ -76,7 +76,7 @@ def __init__(
        self._num_nodes = 1
        self._process_group_backend: Optional[str] = process_group_backend
        self._timeout: Optional[timedelta] = timeout
        self._backward_sync_control = _DDPBackwardSyncControl()
        self._backward_sync_control = _ThunderDataParalellBackwardSyncControl()
        self._ddp_kwargs = kwargs

    @property
@@ -194,32 +194,58 @@ def _set_world_ranks(self) -> None:
        rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank


class _DDPBackwardSyncControl(_BackwardSyncControl):
class _ThunderDataParalellBackwardSyncControl(_BackwardSyncControl):
    def __init__(self):
        self._enabled = False

    @override
    def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
        if not getattr(module, "use_ddp", False):
"""
In Thunder, we cannot use ``module.no_sync()`` because reduction happens at the end of the context manager.
It assumes that the user will reuse it across all gradient accumulation iterations:
.. code-block:: python
with model.no_sync():
for _ in range(len(gradient_accumulation_iters)):
fwd()
bwd() # uses no-sync-backward trace
fwd()
bwd() # uses regular-backward trace
However, Fabric is designed to the context manager every iteration:
.. code-block:: python
for i in range(iters):
is_accumulating = (i + 1) % gradient_accumulation_iters != 0
ctx = model.no_sync() if is_accumulating else nullcontext()
with ctx:
fwd()
bwd()
So we need to be smart about when to sync grads based on the ``enabled`` value.
More info in https://github.com/Lightning-AI/lit-thunder-LEGACY/issues/2085
"""
        if not getattr(module, "use_ddp", False) and not getattr(module, "use_fsdp", False):
            raise TypeError(
                "Blocking backward sync is only possible if the module passed to"
                f" `{self.__class__.__name__}.no_backward_sync` is applied DDP."
                f" `{self.__class__.__name__}.no_backward_sync` is applied DDP or FSDP."
                f" Got: {module.__class__.__name__}."
            )

        # see https://github.com/Lightning-AI/lightning-thunder/issues/2085
        # for why we cannot just return `module.no_sync()`
        from thunder.distributed import skip_data_parallel_grad_sync

        previous, self._enabled = self._enabled, enabled
        if enabled:
            return skip_data_parallel_grad_sync()
        if not enabled and previous:
            return _AllReduceGradsContextManager(module)
            return _SyncGradsContextManager(module)
        return nullcontext()


class _AllReduceGradsContextManager:
class _SyncGradsContextManager:
    def __init__(self, module: Module) -> None:
        self._module = module

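To make the new `enabled`/`previous` bookkeeping concrete, here is a hedged sketch of the Fabric-side gradient-accumulation loop the docstring describes. The strategy import mirrors the tests below; the toy model, optimizer, and iteration counts are illustrative assumptions, not part of this commit:

```python
import torch
import lightning as L

from extensions.thunder.strategies.thunder_ddp import ThunderDDPStrategy

fabric = L.Fabric(devices=2, accelerator="cuda", strategy=ThunderDDPStrategy())
fabric.launch()

model = torch.nn.Linear(1, 1, bias=False, device=fabric.device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

gradient_accumulation_iters = 2  # assumed value for illustration
for i in range(1, 7):
    is_accumulating = i % gradient_accumulation_iters != 0
    # While accumulating, the control returns `skip_data_parallel_grad_sync()`;
    # on the first synchronizing call afterwards it returns the context manager
    # that all-reduces the accumulated gradients.
    with fabric.no_backward_sync(model, enabled=is_accumulating):
        x = torch.randn(1, 1, device=fabric.device)
        fabric.backward(model(x).sum())
    if not is_accumulating:
        optimizer.step()
        optimizer.zero_grad()
```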
2 changes: 2 additions & 0 deletions extensions/thunder/strategies/thunder_fsdp.py
@@ -31,6 +31,7 @@
from torch.nn import Module
from torch.optim import Optimizer
from typing_extensions import override
from extensions.thunder.strategies.thunder_ddp import _ThunderDataParalellBackwardSyncControl

if TYPE_CHECKING:
    from thunder import Executor
@@ -122,6 +123,7 @@ def __init__(
        self.jit = jit
        self.executors = executors
        self._state_dict_type = state_dict_type
        self._backward_sync_control = _ThunderDataParalellBackwardSyncControl()
        self._fsdp_kwargs = kwargs

    @property
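The one-line addition wires the same backward-sync control into the FSDP strategy, so `Fabric.no_backward_sync` can now delegate to it for Thunder-FSDP modules as well. A minimal, assumed usage sketch (not from this commit):

```python
import torch
import lightning as L

from extensions.thunder.strategies.thunder_fsdp import ThunderFSDPStrategy

fabric = L.Fabric(devices=2, accelerator="cuda", strategy=ThunderFSDPStrategy())
fabric.launch()

# With `_backward_sync_control` set, `fabric.no_backward_sync` can delegate to
# `_ThunderDataParalellBackwardSyncControl` during gradient accumulation.
model = fabric.setup(torch.nn.Linear(1, fabric.world_size, bias=False, device=fabric.device))
with fabric.no_backward_sync(model, enabled=True):
    fabric.backward(model(torch.randn(1, 1, device=fabric.device)).sum())
```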
2 changes: 2 additions & 0 deletions tests/conftest.py
@@ -51,6 +51,8 @@ def restore_default_dtype():

@pytest.fixture(autouse=True)
def destroy_process_group():
    yield

    import torch.distributed

    if torch.distributed.is_available() and torch.distributed.is_initialized():
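The hunk above is truncated by the diff view; the teardown call itself sits outside the shown lines. For readers unfamiliar with the pattern, a typical autouse teardown fixture of this shape looks roughly like the following (a sketch, not the file's verbatim content):

```python
import pytest
import torch.distributed


@pytest.fixture(autouse=True)
def destroy_process_group():
    yield  # run the test first, then clean up any initialized process group
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
```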
23 changes: 18 additions & 5 deletions tests/test_thunder_ddp.py
@@ -11,6 +11,7 @@
sys.path.append(str(wd))

from extensions.thunder.strategies.thunder_ddp import ThunderDDPStrategy
from extensions.thunder.strategies.thunder_fsdp import ThunderFSDPStrategy


@RunIf(thunder=True)
@@ -20,15 +21,22 @@ def test_thunder_strategy_input_parsing():


@RunIf(min_cuda_gpus=2, thunder=True, standalone=True)
@pytest.mark.parametrize("strategy", ["ddp", "thunder_ddp"])
def test_no_backward_sync(strategy):
    if strategy == "thunder_ddp":
@pytest.mark.parametrize("choice", ["ddp", "thunder_ddp", "fsdp", "thunder_fsdp"])
def test_no_backward_sync(choice):
    if choice == "thunder_ddp":
        strategy = ThunderDDPStrategy()
    elif choice == "thunder_fsdp":
        strategy = ThunderFSDPStrategy()
    else:
        strategy = choice

    fabric = Fabric(devices=2, accelerator="cuda", strategy=strategy)
    fabric.launch()

    model = torch.nn.Linear(1, 1, bias=False, device=fabric.device)
    # account for sharding in the case of FSDP
    out_features = 1 if "ddp" in choice else fabric.world_size

    model = torch.nn.Linear(1, out_features, bias=False, device=fabric.device)
    x = torch.randn(1, 1, device=fabric.device)
    model = fabric.setup(model)

@@ -38,7 +46,7 @@

        with fabric.no_backward_sync(model, enabled):
            y = model(x)
            y.backward()
            fabric.backward(y.sum())
        if not enabled:
            # Math for the first 3 iters
            #
@@ -51,8 +59,13 @@
            # ((1*1+2*1) + (1*2+2*2)) / 2 + (3*1 + 3*2) / 2 = 9
            #   ^^^^^^^     ^^^^^^^   ^^^    ^^^   ^^^  ^^^
            #   rank0       rank1  allreduce1 rank0 rank1 allreduce2
            assert model.weight.grad.shape.numel() == 1, model.weight.grad.shape
            assert model.weight.grad.item() == (9.0 if i == 3 else 22.5)
            assert not hasattr(model.weight, "_thunder_fsdp_unsharded_grad")
            model.weight.grad = None
        elif choice == "thunder_fsdp":
            assert model.weight._thunder_fsdp_unsharded_grad.shape == (2, 1)
            assert model.weight.grad is None


@RunIf(min_cuda_gpus=2, thunder=True, standalone=True)
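The expected gradient asserted for `i == 3` follows directly from the comment in the test; the snippet below just replays that arithmetic (illustrative only, with made-up variable names, not part of the test file):

```python
# Replays the arithmetic labeled in the test comment: the rank0/rank1 terms of
# "allreduce1" and "allreduce2" are each averaged over the two ranks, and the
# two averages sum to the expected gradient of 9.0 at i == 3.
allreduce1 = ((1 * 1 + 2 * 1) + (1 * 2 + 2 * 2)) / 2  # rank0 + rank1, averaged
allreduce2 = (3 * 1 + 3 * 2) / 2                      # rank0 + rank1, averaged
assert allreduce1 + allreduce2 == 9.0
```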
