
[FSDP2] Used CommDebugMode in grad acc test (pytorch#126067)
+9/-27 lines -- very nice :)

Pull Request resolved: pytorch#126067
Approved by: https://github.com/wanchaol
awgu authored and pytorchmergebot committed May 14, 2024
1 parent 20aa7cc commit 2e4d011
Showing 1 changed file with 9 additions and 27 deletions.
test/distributed/_composable/fsdp/test_fully_shard_training.py
@@ -19,6 +19,7 @@
register_fsdp_forward_method,
)
from torch.distributed._tensor import DTensor, init_device_mesh
from torch.distributed._tensor.debug.comm_mode import CommDebugMode
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
_CHECKPOINT_PREFIX,
apply_activation_checkpointing,
@@ -42,7 +43,6 @@
FSDPTestMultiThread,
MLP,
patch_all_gather,
patch_all_reduce,
patch_reduce_scatter,
test_compiled_fsdp,
)
@@ -59,6 +59,8 @@
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir

c10d_ops = torch.ops.c10d


class TestFullyShardForwardInputs(FSDPTestMultiThread):
@property
@@ -716,32 +718,9 @@ def _test_gradient_accumulation(
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)

# TODO: Migrate to `CommDebugMode` once it supports c10d collectives.
orig_all_gather = dist.all_gather_into_tensor
orig_reduce_scatter = dist.reduce_scatter_tensor
orig_all_reduce = dist.all_reduce
all_gather_count, reduce_scatter_count, all_reduce_count = 0, 0, 0

def all_gather_with_count(*args, **kwargs):
nonlocal all_gather_count
all_gather_count += 1
return orig_all_gather(*args, **kwargs)

def reduce_scatter_with_count(*args, **kwargs):
nonlocal reduce_scatter_count
reduce_scatter_count += 1
return orig_reduce_scatter(*args, **kwargs)

def all_reduce_with_count(*args, **kwargs):
nonlocal all_reduce_count
all_reduce_count += 1
return orig_all_reduce(*args, **kwargs)

torch.manual_seed(1) # same on all ranks
for iter_idx in range(5):
with patch_all_gather(all_gather_with_count), patch_reduce_scatter(
reduce_scatter_with_count
), patch_all_reduce(all_reduce_with_count):
with CommDebugMode() as comm_mode:
for microbatch_idx in range(num_microbatches):
is_last_microbatch = microbatch_idx == num_microbatches - 1
if mode == "all":
@@ -775,6 +754,11 @@ def all_reduce_with_count(*args, **kwargs):
dist.all_reduce(losses[1]) # partial -> replicated
self.assertEqual(losses[0], losses[1])

comm_counts = comm_mode.get_comm_counts()
all_gather_count = comm_counts[c10d_ops._allgather_base_]
reduce_scatter_count = comm_counts[c10d_ops._reduce_scatter_base_]
all_reduce_count = comm_counts[c10d_ops.allreduce_]

# Expect one reduce-scatter per MLP plus one for the root's linear
# on the last microbatch
expected_reduce_scatter_count = num_mlps + 1
@@ -794,7 +778,6 @@ def all_reduce_with_count(*args, **kwargs):
self.assertEqual(all_reduce_count, expected_reduce_scatter_count)
else:
self.assertEqual(all_reduce_count, 0)
reduce_scatter_count = all_reduce_count = 0

# Expect one all-gather per MLP plus one for the root's linear in
# the first microbatch's forward
@@ -817,7 +800,6 @@
# microbatch forward
expected_all_gather_count += num_mlps * (num_microbatches - 1)
self.assertEqual(all_gather_count, expected_all_gather_count)
all_gather_count = 0

# Average the ref model's gradients over the world size to match
# data parallel semantics

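For reference, a minimal sketch (not part of this commit) of the pattern the test now relies on: CommDebugMode records the c10d collectives issued under it, so counts can be read back afterwards instead of monkey-patching dist.all_gather_into_tensor, dist.reduce_scatter_tensor, and dist.all_reduce. The count_collectives helper, model, and inp below are hypothetical; the sketch assumes a process group is already initialized (e.g. via torchrun) and that model is wrapped with FSDP2's fully_shard.

import torch
from torch.distributed._tensor.debug.comm_mode import CommDebugMode

c10d_ops = torch.ops.c10d

def count_collectives(model, inp):
    # Run one forward/backward under CommDebugMode and read back collective counts.
    with CommDebugMode() as comm_mode:
        model(inp).sum().backward()
    counts = comm_mode.get_comm_counts()  # maps op (e.g. c10d_ops._allgather_base_) -> call count
    return (
        counts[c10d_ops._allgather_base_],       # all-gathers (parameter unshard)
        counts[c10d_ops._reduce_scatter_base_],  # reduce-scatters (gradient reduction)
        counts[c10d_ops.allreduce_],             # all-reduces
    )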