From 08c260bc292649e007d1e2964ca225bd9dd61acb Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Tue, 21 May 2024 16:02:38 -0700
Subject: [PATCH] [pipelining] Test schedules against manual stage (#126735)

Added manual stage in test_schedule.py so that we can test various
schedules against it.

In this file we now have:
- test_schedule_with_tracer
- test_schedule_with_manual
- test_grad_with_tracer
- test_grad_with_manual

Tested schedules are:
- ScheduleGPipe
- Schedule1F1B

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126735
Approved by: https://github.com/wconstab, https://github.com/H-Huang
ghstack dependencies: #126812, #126721
---
 test/distributed/pipelining/test_schedule.py | 81 +++++++++++++++++++-
 1 file changed, 78 insertions(+), 3 deletions(-)

diff --git a/test/distributed/pipelining/test_schedule.py b/test/distributed/pipelining/test_schedule.py
index c1fb6b075f7662..48e7300edd6c63 100644
--- a/test/distributed/pipelining/test_schedule.py
+++ b/test/distributed/pipelining/test_schedule.py
@@ -10,6 +10,7 @@
 from model_registry import ModelWithKwargs, MultiMLP
 
 from torch.distributed.pipelining import (
+    ManualPipelineStage,
     pipeline,
     PipelineStage,
     Schedule1F1B,
@@ -53,7 +54,7 @@ def setUpClass(cls):
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
     @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
-    def test_ec_backward(self, ScheduleClass):
+    def test_kwargs_with_tracer(self, ScheduleClass):
         mod = ModelWithKwargs(d_hid)
         mod.to(self.device)
 
@@ -100,8 +101,9 @@ def test_ec_backward(self, ScheduleClass):
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
     @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
-    def test_grad(self, ScheduleClass):
-        mod = MultiMLP(d_hid)
+    @parametrize("ModelClass", [MultiMLP])
+    def test_grad_with_tracer(self, ScheduleClass, ModelClass):
+        mod = ModelClass(d_hid)
         mod.to(self.device)
 
         ref_mod = copy.deepcopy(mod)
@@ -170,6 +172,79 @@ def test_grad(self, ScheduleClass):
                 print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                 raise
 
+    @requires_nccl()
+    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
+    def test_grad_with_manual(self, ScheduleClass):
+        full_mod = MultiMLP(d_hid)
+        full_mod.to(self.device)
+
+        ref_mod = copy.deepcopy(full_mod)
+        x = torch.randn(batch_size, d_hid, device=self.device)
+        with torch.no_grad():
+            y = ref_mod(x)
+            # Add a small perturbation
+            target = y + torch.randn(batch_size, d_hid, device=self.device)
+
+        loss_fn = torch.nn.MSELoss(reduction="sum")
+
+        # Run reference
+        for _ in range(2):
+            ref_mod.zero_grad()
+            ref_out = ref_mod(x)
+            ref_loss = loss_fn(ref_out, target)
+            ref_loss.backward()
+
+        # Get a submodule, e.g. `mlp0` or `mlp1`
+        submod_name = f"mlp{self.rank}"
+        stage_module = full_mod.get_submodule(submod_name)
+        # Create a pipeline stage to wrap that submodule
+        stage = ManualPipelineStage(
+            stage_module,
+            self.rank,
+            self.world_size,
+            self.device,
+            chunks,
+            input_args=x.chunk(chunks)[0],
+        )
+
+        # Attach to a schedule
+        schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn)
+
+        # Run
+        for _ in range(2):
+            # Zero gradients
+            stage_module.zero_grad()
+            if self.rank == 0:
+                schedule.step(x)
+            elif self.rank == self.world_size - 1:
+                losses = []
+                out = schedule.step(target=target, losses=losses)
+            else:
+                schedule.step()
+
+        dist.barrier()
+
+        # Last rank checks result
+        if self.rank == self.world_size - 1:
+            # Check output
+            torch.testing.assert_close(out, ref_out)
+            # Check loss
+            # Since the reduction used in the loss function above is "sum", we use
+            # "sum" here to reduce microbatch losses into a single value too.
+            pipe_loss = sum(losses)
+            torch.testing.assert_close(pipe_loss, ref_loss)
+
+        # Every rank checks gradients
+        ref_submod = ref_mod.get_submodule(submod_name)
+        for name, p in stage_module.named_parameters():
+            ref_p = ref_submod.get_parameter(name)
+            try:
+                torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
+            except AssertionError:
+                print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
+                raise
+
 
 instantiate_parametrized_tests(ScheduleTest)
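
Usage note: the pattern exercised by test_grad_with_manual can be reduced to the standalone sketch below. Only the ManualPipelineStage, ScheduleGPipe, and schedule.step() calls mirror the patch; the torchrun launch, the two-rank/two-GPU layout, the single nn.Linear stage per rank, and the d_hid/batch_size/chunks values are illustrative assumptions, and the constructor signatures reflect torch.distributed.pipelining as of this PR, so later releases may differ.

# Minimal sketch (assumptions: launched with `torchrun --nproc_per_node=2 sketch.py`,
# one CUDA GPU per rank; API as used in the patch above, which may change later).
import os

import torch
import torch.distributed as dist
from torch.distributed.pipelining import ManualPipelineStage, ScheduleGPipe

d_hid, batch_size, chunks = 512, 256, 4

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
dist.init_process_group("nccl", rank=rank, world_size=world_size)

# Each rank owns one stage module (a stand-in for the test's mlp0 / mlp1 submodules).
stage_module = torch.nn.Linear(d_hid, d_hid).to(device)
x = torch.randn(batch_size, d_hid, device=device)

# Wrap the local module in a manual pipeline stage, as the test does.
stage = ManualPipelineStage(
    stage_module,
    rank,          # stage index
    world_size,    # number of stages
    device,
    chunks,        # number of microbatches
    input_args=x.chunk(chunks)[0],  # example microbatch used to infer comm shapes
)

# Attach the stage to a schedule; the loss is only computed on the last stage.
loss_fn = torch.nn.MSELoss(reduction="sum")
schedule = ScheduleGPipe(stage, chunks, loss_fn=loss_fn)

if rank == 0:
    schedule.step(x)  # first stage feeds the full input batch
elif rank == world_size - 1:
    target = torch.randn(batch_size, d_hid, device=device)
    losses = []  # receives one loss per microbatch
    out = schedule.step(target=target, losses=losses)
else:
    schedule.step()  # middle stages only relay activations

dist.destroy_process_group()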