From 81738403a271b2c04bd2e9bba1bea67b0c50cc3f Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Sun, 20 Oct 2024 12:49:41 -0700
Subject: [PATCH] [Distributed] Fix extra context on device 0 (#135273)

This PR contains multiple fixes for issue https://github.com/pytorch/pytorch/issues/135279:

## First part:
Moves the GPU guard (`cudaSetDevice`) before the `currentStreamCaptureStatusMayInitCtx` call. As its name suggests, it May Init Ctx.

## Second part:
Even with the above fix, additional contexts are still observed during `Work` object destruction, e.g.
```
work = dist.all_reduce(tensor, async_op=True)
time.sleep(5) <-- no additional context yet
del work <-- additional context shows up
```
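For reference, a fuller standalone repro of the above -- a minimal sketch only, not part of this PR's diff; it assumes a 2-GPU host, a `torchrun --nproc-per-node=2 repro.py` launch, and that `pynvml` is installed:
```
# repro.py -- illustrative only; launch with: torchrun --nproc-per-node=2 repro.py
import os
import time

import pynvml
import torch
import torch.distributed as dist


def count_procs_on_device_0() -> int:
    # Number of processes that currently hold a CUDA context on device 0, per NVML.
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    return len(pynvml.nvmlDeviceGetComputeRunningProcesses(handle))


if __name__ == "__main__":
    rank = int(os.environ["RANK"])
    device = torch.device("cuda", rank)
    dist.init_process_group("nccl", device_id=device)

    tensor = torch.ones(1, device=device)
    work = dist.all_reduce(tensor, async_op=True)
    work.wait()
    time.sleep(5)  # no extra context on device 0 yet
    del work  # the Work's Future (and its CUDA events) is destroyed here

    dist.barrier()
    torch.cuda.synchronize(device)
    if rank == 0:
        # Expect 1 (rank 0 itself); >1 means another rank created a context on device 0.
        print("processes with a context on device 0:", count_procs_on_device_0())
    dist.destroy_process_group()
```
Without the fix, device 0's process list grows beyond rank 0's own process shortly after `del work`; with the fix, only rank 0 should appear there.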
### Debug process
I chased it down to the destruction of a `Future` object -- a member variable of `Work` -- and then further down to the following member of `Future`:
```
std::vector<c10::Event> events_;
```
When the `events_` are destroyed, we end up in:
https://github.com/pytorch/pytorch/blob/1f3a79379012b408e0375e81fe9205dcba5e34ba/c10/cuda/impl/CUDAGuardImpl.h#L106-L121

When there is no "preset" CUDA context (**which is the case for the Python garbage collector**), line 112, `c10::cuda::GetDevice(&orig_device)`, will set `orig_device` to 0. Then, at line 120, `c10::cuda::SetDevice(orig_device)` will "officially" set the context to device 0 -- **that's where rank 1, 2, ... can create extra context on device 0!**
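To pinpoint which ranks ended up owning a context on device 0 (rather than just counting processes), NVML's per-device process list can be cross-checked against each rank's PID. A rough diagnostic sketch, not part of this PR; it assumes `pynvml`, an already-initialized default process group, and the helper name `ranks_with_context_on_device_0` is illustrative:
```
# Illustrative diagnostic: which ranks currently hold a CUDA context on device 0?
import os

import pynvml
import torch.distributed as dist


def ranks_with_context_on_device_0(rank: int, world_size: int):
    # Each rank contributes its PID; rank 0 compares them against NVML's
    # list of processes that have a compute context on device 0.
    pids = [None] * world_size
    dist.all_gather_object(pids, os.getpid())
    if rank != 0:
        return []
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
    pids_on_dev0 = {p.pid for p in procs}
    return [r for r, pid in enumerate(pids) if pid in pids_on_dev0]
```
Before the fix, this list grows to include non-0 ranks shortly after their `Work` objects are garbage collected; with the fix it should stay at `[0]`.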
" + f"Extra CUDA context may have been created.", + ) + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_extra_cuda_context(self): + # Check if non-0 ranks would create extra CUDA context on device 0 + store = c10d.FileStore(self.file_name, self.world_size) + device = torch.device("cuda:%d" % self.rank) + c10d.init_process_group( + backend="nccl", + store=store, + rank=self.rank, + world_size=self.world_size, + device_id=device, + ) + try: + self._helper_test_extra_cuda_context_by_nvml() + except ModuleNotFoundError: + self._helper_test_extra_cuda_context_by_memory() + @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_destruct_before_terminate_pg(self): diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 50244ac2ead5b7..5adb14bd003387 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -2790,6 +2790,11 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( avoidRecordStreams |= avoidRecordStreams_; nanCheck &= enableNanCheck_; + auto device = getDevice(inputs[0]); + // Guard must be created before `currentStreamCaptureStatusMayInitCtx`; + // otherwise, extra CUDA context could be created on device 0. + at::cuda::OptionalCUDAGuard gpuGuard(device); + c10::cuda::CaptureStatus capture_status = c10::cuda::currentStreamCaptureStatusMayInitCtx(); errorIfCapturingNonCapturableNCCL(capture_status); @@ -2800,7 +2805,6 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( } op_id_++; - auto device = getDevice(inputs[0]); const auto key = getKeyFromDevice(device); auto ncclComm = getNCCLComm(key, device, opType); @@ -2842,8 +2846,6 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( std::make_shared>(inputs); } - at::cuda::OptionalCUDAGuard gpuGuard(device); - if (nanCheck) { for (const auto& input : inputs) { checkForNan(input, ncclStream); @@ -2968,6 +2970,19 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( bool avoidRecordStreams) { // Environment setting by the user may add onto collective call's option avoidRecordStreams |= avoidRecordStreams_; + + // Currently, the API permits one scenario where inputs.size() and + // outputs.size() are > 0. + // 1. If the call was a _coalesced call, all inputs must be on the same + // device. + // The group of nccl calls applies the collective separately to each input, + // but the group as a whole should be efficient, and might even execute as + // a single fused kernel. + auto device = getDevice(inputs[0]); + // Guard must be created before `currentStreamCaptureStatusMayInitCtx`; + // otherwise, extra CUDA context could be created on device 0. + at::cuda::OptionalCUDAGuard gpuGuard(device); + c10::cuda::CaptureStatus capture_status = c10::cuda::currentStreamCaptureStatusMayInitCtx(); errorIfCapturingNonCapturableNCCL(capture_status); @@ -2982,14 +2997,6 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( // op_id_ once per indvidual operation within the group op_id_++; - // Currently, the API permits one scenario where inputs.size() and - // outputs.size() are > 0. - // 1. If the call was a _coalesced call, all inputs must be on the same - // device. - // The group of nccl calls applies the collective separately to each input, - // but the group as a whole should be efficient, and might even execute as - // a single fused kernel. 
-  auto device = getDevice(inputs[0]);
   const auto key = getKeyFromDevice(device);
   auto ncclComm = getNCCLComm(key, device, opType);
 
@@ -3032,8 +3039,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
         std::make_shared<std::vector<at::Tensor>>(inputs);
   }
 
-  at::cuda::OptionalCUDAGuard gpuGuard(device);
-
   // Start event should only be recorded before the ncclGroupStart() (which
   // happens inside AutoNcclGroup guard below)
   if (work->timingEnabled_) {
@@ -3188,6 +3193,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
   }
 
   auto device = getDevice(tensor);
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
+
   std::string key;
   int p2pRank = 0, p2pTargetRank = 0;
   bool isSendRecvSelf = false;
@@ -3309,9 +3316,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
         /*isP2P=*/true);
   }
 
-  // is gpuGuard needed for the if block below, or can i swap them
-  at::cuda::OptionalCUDAGuard gpuGuard(device);
-
   // Only check for NaN for send ops, for recv ops `tensor` can be a random
   // placeholder
   if (enableNanCheck_ && opType == OpType::SEND) {