From 660e294a158fa9ac83df27dcd478d97072fc7a1b Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Mon, 18 Nov 2024 11:33:21 -0800
Subject: [PATCH] [C10D] support group_src/dst in broadcast/reduce ops (#140843)

Also add mypy annotations

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in #140460

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140843
Approved by: https://github.com/kwen2501
---
 test/distributed/test_c10d_nccl.py     | 51 ++++++++++---
 .../ddp_comm_hooks/ddp_zero_hook.py    |  4 +-
 torch/distributed/distributed_c10d.py  | 74 ++++++++++++-------
 3 files changed, 89 insertions(+), 40 deletions(-)

diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 504b18944e704..82d52efd4b1ac 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -3840,7 +3840,8 @@ def test_gather_object_subgroup(self, group_rank):
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
-    def test_reduce_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_reduce_subgroup(self, group_rank):
         world_size = 4
         if self.rank >= world_size:
             return
@@ -3849,10 +3850,16 @@ def test_reduce_subgroup(self):
         x = torch.ones((10,), device=device) * self.rank
         if self.rank == 0 or self.rank == 2:
             expected = x + torch.ones((10,), device=device) * (self.rank + 1)
-            c10d.reduce(x, dst=self.rank, group=subgroup, async_op=False)
+            if group_rank:
+                c10d.reduce(x, group_dst=0, group=subgroup, async_op=False)
+            else:
+                c10d.reduce(x, dst=self.rank, group=subgroup, async_op=False)
             self.assertEqual(x, expected)
         else:
-            c10d.reduce(x, dst=self.rank - 1, group=subgroup, async_op=False)
+            if group_rank:
+                c10d.reduce(x, group_dst=0, group=subgroup, async_op=False)
+            else:
+                c10d.reduce(x, dst=self.rank - 1, group=subgroup, async_op=False)
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
@@ -3893,7 +3900,8 @@ def test_send_recv_subgroup(self, async_op, group_rank):
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
-    def test_broadcast_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_broadcast_subgroup(self, group_rank):
         world_size = 4
         if self.rank >= world_size:
             return
@@ -3901,12 +3909,18 @@ def test_broadcast_subgroup(self):
         device = torch.device("cuda:%d" % self.rank)
         if self.rank == 0 or self.rank == 2:
             x = torch.empty((10,), device=device)
-            c10d.broadcast(x, src=self.rank + 1, group=subgroup)
+            if group_rank:
+                c10d.broadcast(x, group_src=1, group=subgroup)
+            else:
+                c10d.broadcast(x, src=self.rank + 1, group=subgroup)
             expected = torch.ones((10,), device=device) * (self.rank + 1)
             self.assertEqual(x, expected)
         else:
             x = torch.ones((10,), device=device) * self.rank
-            c10d.broadcast(x, src=self.rank, group=subgroup)
+            if group_rank:
+                c10d.broadcast(x, group_src=1, group=subgroup)
+            else:
+                c10d.broadcast(x, src=self.rank, group=subgroup)
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
@@ -3939,7 +3953,10 @@ def test_send_recv_object_list_subgroup(self, set_device: SetDeviceMethod):
         "set_device",
         [SetDeviceMethod.TORCH_CUDA_SET, SetDeviceMethod.COLLECTIVE_ARGUMENT],
     )
-    def test_broadcast_object_list_subgroup(self, set_device: SetDeviceMethod):
+    @parametrize("group_rank", [True, False])
+    def test_broadcast_object_list_subgroup(
+        self, set_device: SetDeviceMethod, group_rank
+    ):
         world_size = 4
         if self.rank >= world_size:
             return
@@ -3951,14 +3968,26 @@ def test_broadcast_object_list_subgroup(self, set_device: SetDeviceMethod):
         device = torch.device("cuda:%d" % self.rank)
         if self.rank == 0 or self.rank == 2:
             x = [{}]
-            c10d.broadcast_object_list(
-                x, src=self.rank + 1, group=subgroup, device=device
-            )
+            if group_rank:
+                c10d.broadcast_object_list(
+                    x, group_src=1, group=subgroup, device=device
+                )
+            else:
+                c10d.broadcast_object_list(
+                    x, src=self.rank + 1, group=subgroup, device=device
+                )
             expected = [{"rank": self.rank + 1}]
             self.assertEqual(x, expected)
         else:
             x = [{"rank": self.rank}]
-            c10d.broadcast_object_list(x, src=self.rank, group=subgroup, device=device)
+            if group_rank:
+                c10d.broadcast_object_list(
+                    x, group_src=1, group=subgroup, device=device
+                )
+            else:
+                c10d.broadcast_object_list(
+                    x, src=self.rank, group=subgroup, device=device
+                )
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py
index 6db6d1831b1fd..2f1000618337d 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py
@@ -88,9 +88,11 @@ def _broadcast_bucket(
     for assigned_rank in assigned_ranks:
         bucket_assignments = zero._bucket_assignments_per_rank[assigned_rank]
         if bucket_index in bucket_assignments:
+            send_tensor = bucket_assignments[bucket_index].tensor
+            assert send_tensor is not None
             overlap_info.broadcast_handles.append(
                 dist.broadcast(
-                    bucket_assignments[bucket_index].tensor,
+                    send_tensor,
                     src=dist.get_global_rank(zero.process_group, assigned_rank),
                     group=zero.process_group,
                     async_op=True,
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index a2a3eeaf2b538..319a8e942b04f 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -2593,7 +2593,13 @@ def batch_isend_irecv(p2p_op_list):
 
 
 @_exception_logger
-def broadcast(tensor, src, group=None, async_op=False):
+def broadcast(
+    tensor: torch.Tensor,
+    src: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    async_op: bool = False,
+    group_src: Optional[int] = None,
+):
     """
     Broadcasts the tensor to the whole group.
 
@@ -2607,29 +2613,26 @@ def broadcast(tensor, src, group=None, async_op=False):
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         async_op (bool, optional): Whether this op should be an async op
+        group_src (int): Source rank on ``group``. Must specify one of ``group_src``
+            and ``src`` but not both.
 
     Returns:
         Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group
 
     """
+    group = _group_or_default_group(group)
+    group_src = _canonicalize_group_rank(group, src, group_src, return_global=False)
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
         _warn_not_in_group("broadcast")
         return
 
     opts = BroadcastOptions()
-    opts.rootRank = src
+    opts.rootRank = group_src
     opts.rootTensor = 0
     opts.asyncOp = async_op
-
-    if group is None or group is GroupMember.WORLD:
-        default_pg = _get_default_group()
-        work = default_pg.broadcast([tensor], opts)
-    else:
-        group_src_rank = get_group_rank(group, src)
-        opts.rootRank = group_src_rank
-        work = group.broadcast([tensor], opts)
+    work = group.broadcast([tensor], opts)
     if async_op:
         return work
     else:
@@ -2783,7 +2786,14 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):
 
 
 @_exception_logger
-def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
+def reduce(
+    tensor: torch.Tensor,
+    dst: Optional[int] = None,
+    op=ReduceOp.SUM,
+    group: Optional[ProcessGroup] = None,
+    async_op: bool = False,
+    group_dst: Optional[int] = None,
+):
     """
     Reduces the tensor data across all machines.
 
@@ -2799,12 +2809,16 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         async_op (bool, optional): Whether this op should be an async op
+        group_dst (int): Destination rank on ``group``. Must specify one of ``group_dst``
+            and ``dst`` but not both.
 
     Returns:
         Async work handle, if async_op is set to True.
         None, if not async_op or if not part of the group
 
     """
+    group = _group_or_default_group(group)
+    group_dst = _canonicalize_group_rank(group, dst, group_dst, return_global=False)
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
         _warn_not_in_group("reduce")
@@ -2812,16 +2826,8 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
 
     opts = ReduceOptions()
     opts.reduceOp = op
-    opts.rootRank = dst
-
-    if group is None or group is GroupMember.WORLD:
-        default_pg = _get_default_group()
-        work = default_pg.reduce([tensor], opts)
-    else:
-        group_dst_rank = get_group_rank(group, dst)
-        opts.rootRank = group_dst_rank
-        work = group.reduce([tensor], opts)
-
+    opts.rootRank = group_dst
+    work = group.reduce([tensor], opts)
     if async_op:
         return work
     else:
@@ -3270,7 +3276,13 @@ def recv_object_list(object_list, src=None, group=None, device=None):
 
 
 @_exception_logger
-def broadcast_object_list(object_list, src=0, group=None, device=None):
+def broadcast_object_list(
+    object_list: List[Any],
+    src: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    device: Optional[torch.device] = None,
+    group_src: Optional[int] = None,
+):
     """
     Broadcasts picklable objects in ``object_list`` to the whole group.
 
@@ -3289,6 +3301,8 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
         device (``torch.device``, optional): If not None, the objects are
             serialized and converted to tensors which are moved to the
             ``device`` before broadcasting. Default is ``None``.
+        group_src (int): Source rank on ``group``. Must specify one of ``group_src``
+            and ``src`` but not both.
 
     Returns:
         ``None``. If rank is part of the group, ``object_list`` will contain the
@@ -3331,6 +3345,10 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
         >>> objects
         ['foo', 12, {1: 2}]
     """
+    group = _group_or_default_group(group)
+    if src is None and group_src is None:
+        src = 0
+    global_src = _canonicalize_group_rank(group, src, group_src, return_global=True)
     if _rank_not_in_group(group):
         _warn_not_in_group("broadcast_object_list")
         return
@@ -3342,9 +3360,9 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
     # case it is not ``None`` we move the size and object tensors to be
     # broadcasted to this device.
     current_device = device or _get_object_coll_device(group)
-    my_rank = get_rank()
+    my_global_rank = get_rank()
     # Serialize object_list elements to tensors on src rank.
-    if my_rank == src:
+    if my_global_rank == global_src:
         tensor_list, size_list = zip(
             *[_object_to_tensor(obj, current_device, group) for obj in object_list]
         )
@@ -3355,12 +3373,12 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
         )
 
     # Broadcast object sizes
-    broadcast(object_sizes_tensor, src=src, group=group)
+    broadcast(object_sizes_tensor, src=global_src, group=group)
 
     # Concatenate and broadcast serialized object tensors
     # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list
     # has only one element, we can skip the copy.
-    if my_rank == src:
+    if my_global_rank == global_src:
         if len(tensor_list) == 1:  # type: ignore[possibly-undefined]
             object_tensor = tensor_list[0]
         else:
@@ -3372,10 +3390,10 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
             device=current_device,
         )
 
-    broadcast(object_tensor, src=src, group=group)
+    broadcast(object_tensor, src=global_src, group=group)
     # Deserialize objects using their stored sizes.
     offset = 0
-    if my_rank != src:
+    if my_global_rank != global_src:
         for i, obj_size in enumerate(object_sizes_tensor):
             obj_view = object_tensor[offset : offset + obj_size]
             obj_view = obj_view.type(torch.uint8)
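
A minimal usage sketch of the group-rank addressing this patch adds, assuming a torchrun launch with 4 ranks and one CUDA device per rank; the two-rank subgroup layout (mirroring the tests above), the tensor/object values, and the helper name main() are illustrative and not part of the patch:

# Sketch only: exercises the new group_src/group_dst keywords added in this patch.
# Assumes `torchrun --nproc-per-node=4` so env:// init works and 4 GPUs are visible.
import torch
import torch.distributed as dist


def main() -> None:
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    device = torch.device("cuda", rank)

    # Two 2-rank subgroups: {0, 1} and {2, 3}, as in the tests above.
    subgroup, _ = dist.new_subgroups(group_size=2)

    x = torch.ones(10, device=device) * rank

    # Pre-existing behavior: src is a *global* rank, so each subgroup
    # needs per-rank arithmetic to name its own root.
    dist.broadcast(x, src=2 * (rank // 2), group=subgroup)

    # New in this patch: group_src is a rank *within* subgroup, so the
    # same literal value works on every subgroup.
    dist.broadcast(x, group_src=0, group=subgroup)

    # reduce gains the analogous group_dst keyword.
    dist.reduce(x, group_dst=0, group=subgroup, op=dist.ReduceOp.SUM)

    # broadcast_object_list accepts group_src as well.
    objs = [{"rank": rank}] if dist.get_rank(subgroup) == 0 else [None]
    dist.broadcast_object_list(objs, group_src=0, group=subgroup, device=device)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

The distinction matters for the test changes above: src/dst remain global ranks (hence the self.rank + 1 arithmetic in the old code paths), while group_src/group_dst are resolved against group, which is why the parametrized branches can pass a constant group-local rank regardless of which subgroup the caller is in.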