[pipelining] Test schedules against manual stage (pytorch#126735)
Added a manual pipeline stage to test_schedule.py so that we can test various schedules against it.

In this file we now have:
- test_schedule_with_tracer
- test_schedule_with_manual
- test_grad_with_tracer
- test_grad_with_manual

Tested schedules are:
- ScheduleGPipe
- Schedule1F1B

Pull Request resolved: pytorch#126735
Approved by: https://github.com/wconstab, https://github.com/H-Huang
ghstack dependencies: pytorch#126812, pytorch#126721
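
For orientation, here is a minimal standalone sketch of the manual-stage wiring that the new test exercises, distilled from the diff below; it is not part of the commit. It assumes a 2-rank process group (e.g. NCCL) is already initialized, that `rank`, `world_size`, and `device` are available, that `MultiMLP` (from the test's model_registry) exposes submodules named `mlp0`, `mlp1`, ..., and that `d_hid`, `batch_size`, and `chunks` are defined as in the test.

import torch
from torch.distributed.pipelining import ManualPipelineStage, ScheduleGPipe

# Assumed to be set up elsewhere, as in the test harness:
# process group init, rank, world_size, device, d_hid, batch_size, chunks, MultiMLP

full_mod = MultiMLP(d_hid).to(device)
stage_module = full_mod.get_submodule(f"mlp{rank}")  # this rank's slice of the model

x = torch.randn(batch_size, d_hid, device=device)
target = torch.randn(batch_size, d_hid, device=device)
loss_fn = torch.nn.MSELoss(reduction="sum")

# Wrap the submodule in a manually constructed pipeline stage ...
stage = ManualPipelineStage(
    stage_module,
    rank,                           # stage index
    world_size,                     # total number of stages
    device,
    chunks,                         # number of microbatches
    input_args=x.chunk(chunks)[0],  # example microbatch for shape inference
)

# ... and drive it with a schedule (Schedule1F1B is used the same way in the test).
schedule = ScheduleGPipe(stage, chunks, loss_fn=loss_fn)

if rank == 0:
    schedule.step(x)                # first stage feeds the input
elif rank == world_size - 1:
    losses = []                     # collects one loss per microbatch
    out = schedule.step(target=target, losses=losses)
else:
    schedule.step()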
kwen2501 authored and pytorchmergebot committed May 22, 2024
1 parent 6a539e8 commit 08c260b
Showing 1 changed file with 78 additions and 3 deletions: test/distributed/pipelining/test_schedule.py
@@ -10,6 +10,7 @@

from model_registry import ModelWithKwargs, MultiMLP
from torch.distributed.pipelining import (
+    ManualPipelineStage,
    pipeline,
    PipelineStage,
    Schedule1F1B,
@@ -53,7 +54,7 @@ def setUpClass(cls):
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
-    def test_ec_backward(self, ScheduleClass):
+    def test_kwargs_with_tracer(self, ScheduleClass):
        mod = ModelWithKwargs(d_hid)
        mod.to(self.device)

@@ -100,8 +101,9 @@ def test_ec_backward(self, ScheduleClass):
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
-    def test_grad(self, ScheduleClass):
-        mod = MultiMLP(d_hid)
+    @parametrize("ModelClass", [MultiMLP])
+    def test_grad_with_tracer(self, ScheduleClass, ModelClass):
+        mod = ModelClass(d_hid)
        mod.to(self.device)

        ref_mod = copy.deepcopy(mod)
@@ -170,6 +172,79 @@ def test_grad(self, ScheduleClass):
                print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                raise

+    @requires_nccl()
+    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
+    def test_grad_with_manual(self, ScheduleClass):
+        full_mod = MultiMLP(d_hid)
+        full_mod.to(self.device)
+
+        ref_mod = copy.deepcopy(full_mod)
+        x = torch.randn(batch_size, d_hid, device=self.device)
+        with torch.no_grad():
+            y = ref_mod(x)
+            # Add a small perturbation
+            target = y + torch.randn(batch_size, d_hid, device=self.device)
+
+        loss_fn = torch.nn.MSELoss(reduction="sum")
+
+        # Run reference
+        for _ in range(2):
+            ref_mod.zero_grad()
+            ref_out = ref_mod(x)
+            ref_loss = loss_fn(ref_out, target)
+            ref_loss.backward()
+
+        # Get a submodule, e.g. mlp0 or mlp1
+        submod_name = f"mlp{self.rank}"
+        stage_module = full_mod.get_submodule(submod_name)
+        # Create a pipeline stage to wrap that submodule
+        stage = ManualPipelineStage(
+            stage_module,
+            self.rank,
+            self.world_size,
+            self.device,
+            chunks,
+            input_args=x.chunk(chunks)[0],
+        )
+
+        # Attach to a schedule
+        schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn)
+
+        # Run
+        for _ in range(2):
+            # Zero gradients
+            stage_module.zero_grad()
+            if self.rank == 0:
+                schedule.step(x)
+            elif self.rank == self.world_size - 1:
+                losses = []
+                out = schedule.step(target=target, losses=losses)
+            else:
+                schedule.step()
+
+        dist.barrier()
+
+        # Last rank checks result
+        if self.rank == self.world_size - 1:
+            # Check output
+            torch.testing.assert_close(out, ref_out)
+            # Check loss
+            # Since the reduction used in the loss function above is "sum", we use
+            # "sum" here to reduce microbatch losses into a single value too.
+            pipe_loss = sum(losses)
+            torch.testing.assert_close(pipe_loss, ref_loss)
+
+        # Every rank checks gradients
+        ref_submod = ref_mod.get_submodule(submod_name)
+        for name, p in stage_module.named_parameters():
+            ref_p = ref_submod.get_parameter(name)
+            try:
+                torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
+            except AssertionError:
+                print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
+                raise
+

instantiate_parametrized_tests(ScheduleTest)

