[C10D] support group_src/dst in broadcast/reduce ops (pytorch#140843)
Also add mypy annotations

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in pytorch#140460
Pull Request resolved: pytorch#140843
Approved by: https://github.com/kwen2501
wconstab authored and fmo-mt committed Dec 11, 2024
1 parent 793a0b4 commit 660e294
Showing 3 changed files with 89 additions and 40 deletions.
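For context, a minimal usage sketch of the group-rank keywords added here (the 4-rank NCCL job and the two-rank subgroup are illustrative assumptions; only the ``group_src``/``group_dst`` arguments themselves come from this PR):

# Minimal sketch: assumes an already-initialized 4-rank NCCL job with one GPU per rank.
import torch
import torch.distributed as dist

subgroup = dist.new_group([0, 1])  # must be created by every rank
device = torch.device(f"cuda:{dist.get_rank()}")
x = torch.ones(10, device=device)

if dist.get_rank() in (0, 1):
    # Old style: the root is a global rank, translated to a group rank internally.
    dist.broadcast(x, src=dist.get_global_rank(subgroup, 0), group=subgroup)
    # New style: the root is given directly as a rank within ``group``.
    dist.broadcast(x, group_src=0, group=subgroup)
    # reduce() gains the analogous ``group_dst`` keyword.
    dist.reduce(x, group_dst=0, group=subgroup)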
51 changes: 40 additions & 11 deletions test/distributed/test_c10d_nccl.py
@@ -3840,7 +3840,8 @@ def test_gather_object_subgroup(self, group_rank):

@requires_nccl()
@skip_if_lt_x_gpu(4)
def test_reduce_subgroup(self):
@parametrize("group_rank", [True, False])
def test_reduce_subgroup(self, group_rank):
world_size = 4
if self.rank >= world_size:
return
@@ -3849,10 +3850,16 @@ def test_reduce_subgroup(self):
x = torch.ones((10,), device=device) * self.rank
if self.rank == 0 or self.rank == 2:
expected = x + torch.ones((10,), device=device) * (self.rank + 1)
c10d.reduce(x, dst=self.rank, group=subgroup, async_op=False)
if group_rank:
c10d.reduce(x, group_dst=0, group=subgroup, async_op=False)
else:
c10d.reduce(x, dst=self.rank, group=subgroup, async_op=False)
self.assertEqual(x, expected)
else:
c10d.reduce(x, dst=self.rank - 1, group=subgroup, async_op=False)
if group_rank:
c10d.reduce(x, group_dst=0, group=subgroup, async_op=False)
else:
c10d.reduce(x, dst=self.rank - 1, group=subgroup, async_op=False)

@requires_nccl()
@skip_if_lt_x_gpu(4)
@@ -3893,20 +3900,27 @@ def test_send_recv_subgroup(self, async_op, group_rank):

@requires_nccl()
@skip_if_lt_x_gpu(4)
def test_broadcast_subgroup(self):
@parametrize("group_rank", [True, False])
def test_broadcast_subgroup(self, group_rank):
world_size = 4
if self.rank >= world_size:
return
subgroup = self._init_two_pg2_subgroups(world_size)
device = torch.device("cuda:%d" % self.rank)
if self.rank == 0 or self.rank == 2:
x = torch.empty((10,), device=device)
c10d.broadcast(x, src=self.rank + 1, group=subgroup)
if group_rank:
c10d.broadcast(x, group_src=1, group=subgroup)
else:
c10d.broadcast(x, src=self.rank + 1, group=subgroup)
expected = torch.ones((10,), device=device) * (self.rank + 1)
self.assertEqual(x, expected)
else:
x = torch.ones((10,), device=device) * self.rank
c10d.broadcast(x, src=self.rank, group=subgroup)
if group_rank:
c10d.broadcast(x, group_src=1, group=subgroup)
else:
c10d.broadcast(x, src=self.rank, group=subgroup)

@requires_nccl()
@skip_if_lt_x_gpu(4)
@@ -3939,7 +3953,10 @@ def test_send_recv_object_list_subgroup(self, set_device: SetDeviceMethod):
"set_device",
[SetDeviceMethod.TORCH_CUDA_SET, SetDeviceMethod.COLLECTIVE_ARGUMENT],
)
def test_broadcast_object_list_subgroup(self, set_device: SetDeviceMethod):
@parametrize("group_rank", [True, False])
def test_broadcast_object_list_subgroup(
self, set_device: SetDeviceMethod, group_rank
):
world_size = 4
if self.rank >= world_size:
return
@@ -3951,14 +3968,26 @@ def test_broadcast_object_list_subgroup(self, set_device: SetDeviceMethod):
device = torch.device("cuda:%d" % self.rank)
if self.rank == 0 or self.rank == 2:
x = [{}]
c10d.broadcast_object_list(
x, src=self.rank + 1, group=subgroup, device=device
)
if group_rank:
c10d.broadcast_object_list(
x, group_src=1, group=subgroup, device=device
)
else:
c10d.broadcast_object_list(
x, src=self.rank + 1, group=subgroup, device=device
)
expected = [{"rank": self.rank + 1}]
self.assertEqual(x, expected)
else:
x = [{"rank": self.rank}]
c10d.broadcast_object_list(x, src=self.rank, group=subgroup, device=device)
if group_rank:
c10d.broadcast_object_list(
x, group_src=1, group=subgroup, device=device
)
else:
c10d.broadcast_object_list(
x, src=self.rank, group=subgroup, device=device
)

@requires_nccl()
@skip_if_lt_x_gpu(4)
4 changes: 3 additions & 1 deletion torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py
@@ -88,9 +88,11 @@ def _broadcast_bucket(
for assigned_rank in assigned_ranks:
bucket_assignments = zero._bucket_assignments_per_rank[assigned_rank]
if bucket_index in bucket_assignments:
send_tensor = bucket_assignments[bucket_index].tensor
assert send_tensor is not None
overlap_info.broadcast_handles.append(
dist.broadcast(
bucket_assignments[bucket_index].tensor,
send_tensor,
src=dist.get_global_rank(zero.process_group, assigned_rank),
group=zero.process_group,
async_op=True,
74 changes: 46 additions & 28 deletions torch/distributed/distributed_c10d.py
@@ -2593,7 +2593,13 @@ def batch_isend_irecv(p2p_op_list):


@_exception_logger
def broadcast(tensor, src, group=None, async_op=False):
def broadcast(
tensor: torch.Tensor,
src: Optional[int] = None,
group: Optional[ProcessGroup] = None,
async_op: bool = False,
group_src: Optional[int] = None,
):
"""
Broadcasts the tensor to the whole group.
@@ -2607,29 +2613,26 @@ def broadcast(tensor, src, group=None, async_op=False):
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
async_op (bool, optional): Whether this op should be an async op
group_src (int): Source rank on ``group``. Must specify one of ``group_src``
and ``src`` but not both.
Returns:
Async work handle, if async_op is set to True.
None, if not async_op or if not part of the group
"""
group = _group_or_default_group(group)
group_src = _canonicalize_group_rank(group, src, group_src, return_global=False)
_check_single_tensor(tensor, "tensor")
if _rank_not_in_group(group):
_warn_not_in_group("broadcast")
return

opts = BroadcastOptions()
opts.rootRank = src
opts.rootRank = group_src
opts.rootTensor = 0
opts.asyncOp = async_op

if group is None or group is GroupMember.WORLD:
default_pg = _get_default_group()
work = default_pg.broadcast([tensor], opts)
else:
group_src_rank = get_group_rank(group, src)
opts.rootRank = group_src_rank
work = group.broadcast([tensor], opts)
work = group.broadcast([tensor], opts)
if async_op:
return work
else:
@@ -2783,7 +2786,14 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):


@_exception_logger
def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
def reduce(
tensor: torch.Tensor,
dst: Optional[int] = None,
op=ReduceOp.SUM,
group: Optional[ProcessGroup] = None,
async_op: bool = False,
group_dst: Optional[int] = None,
):
"""
Reduces the tensor data across all machines.
@@ -2799,29 +2809,25 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
async_op (bool, optional): Whether this op should be an async op
group_dst (int): Destination rank on ``group``. Must specify one of ``group_dst``
and ``dst`` but not both.
Returns:
Async work handle, if async_op is set to True.
None, if not async_op or if not part of the group
"""
group = _group_or_default_group(group)
group_dst = _canonicalize_group_rank(group, dst, group_dst, return_global=False)
_check_single_tensor(tensor, "tensor")
if _rank_not_in_group(group):
_warn_not_in_group("reduce")
return

opts = ReduceOptions()
opts.reduceOp = op
opts.rootRank = dst

if group is None or group is GroupMember.WORLD:
default_pg = _get_default_group()
work = default_pg.reduce([tensor], opts)
else:
group_dst_rank = get_group_rank(group, dst)
opts.rootRank = group_dst_rank
work = group.reduce([tensor], opts)

opts.rootRank = group_dst
work = group.reduce([tensor], opts)
if async_op:
return work
else:
@@ -3270,7 +3276,13 @@ def recv_object_list(object_list, src=None, group=None, device=None):


@_exception_logger
def broadcast_object_list(object_list, src=0, group=None, device=None):
def broadcast_object_list(
object_list: List[Any],
src: Optional[int] = None,
group: Optional[ProcessGroup] = None,
device: Optional[torch.device] = None,
group_src: Optional[int] = None,
):
"""
Broadcasts picklable objects in ``object_list`` to the whole group.
@@ -3289,6 +3301,8 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
device (``torch.device``, optional): If not None, the objects are
serialized and converted to tensors which are moved to the
``device`` before broadcasting. Default is ``None``.
group_src (int): Source rank on ``group``. Must specify one of ``group_src``
and ``src`` but not both.
Returns:
``None``. If rank is part of the group, ``object_list`` will contain the
@@ -3331,6 +3345,10 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
>>> objects
['foo', 12, {1: 2}]
"""
group = _group_or_default_group(group)
if src is None and group_src is None:
src = 0
global_src = _canonicalize_group_rank(group, src, group_src, return_global=True)
if _rank_not_in_group(group):
_warn_not_in_group("broadcast_object_list")
return
@@ -3342,9 +3360,9 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
# case it is not ``None`` we move the size and object tensors to be
# broadcasted to this device.
current_device = device or _get_object_coll_device(group)
my_rank = get_rank()
my_global_rank = get_rank()
# Serialize object_list elements to tensors on src rank.
if my_rank == src:
if my_global_rank == global_src:
tensor_list, size_list = zip(
*[_object_to_tensor(obj, current_device, group) for obj in object_list]
)
@@ -3355,12 +3373,12 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
)

# Broadcast object sizes
broadcast(object_sizes_tensor, src=src, group=group)
broadcast(object_sizes_tensor, src=global_src, group=group)

# Concatenate and broadcast serialized object tensors
# Note: torch.cat will do an extra memory copy to the current device, if the tensor_list
# has only one element, we can skip the copy.
if my_rank == src:
if my_global_rank == global_src:
if len(tensor_list) == 1: # type: ignore[possibly-undefined]
object_tensor = tensor_list[0]
else:
@@ -3372,10 +3390,10 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
device=current_device,
)

broadcast(object_tensor, src=src, group=group)
broadcast(object_tensor, src=global_src, group=group)
# Deserialize objects using their stored sizes.
offset = 0
if my_rank != src:
if my_global_rank != global_src:
for i, obj_size in enumerate(object_sizes_tensor):
obj_view = object_tensor[offset : offset + obj_size]
obj_view = obj_view.type(torch.uint8)
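The call sites above funnel both the legacy global-rank argument and the new group-rank argument through a single canonicalization helper. Its implementation is not part of this diff, so the following is only a rough sketch, written against public APIs, of the behavior the call sites rely on (not the actual ``_canonicalize_group_rank`` code):

from typing import Optional

import torch.distributed as dist

def canonicalize_group_rank_sketch(
    group: dist.ProcessGroup,
    global_rank: Optional[int],
    group_rank: Optional[int],
    return_global: bool = False,
) -> int:
    # Exactly one of the two rank arguments may be provided.
    if (global_rank is None) == (group_rank is None):
        raise ValueError("specify exactly one of the global rank and the group rank")
    if group_rank is None:
        # Translate the global rank into the group's local numbering.
        group_rank = dist.get_group_rank(group, global_rank)
    # Return the rank in either coordinate system, mirroring the ``return_global`` flag above.
    return dist.get_global_rank(group, group_rank) if return_global else group_rank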
