Skip to content

Commit

Permalink
[C10D] Support group_dst/group_src in c10d send/recv (pytorch#140460)
Browse files Browse the repository at this point in the history
Partly addressing RFC 0042 (pytorch/rfcs#71)

It's annoying that 'dst' (for send) ust be a global rank even when a
group is passed in.  But we can't easily change 'dst' without breaking
existing cases.

Furthermore, requiring use of 'global' dst breaks the less common usage
pattern of creating a new ProcessGroup object that is not connected to
the 'default group' and thus has no logical 'global' ranks.
Pull Request resolved: pytorch#140460
Approved by: https://github.com/d4l3k, https://github.com/kwen2501, https://github.com/fduwjj
  • Loading branch information
wconstab authored and fmo-mt committed Dec 11, 2024
1 parent 6ff2903 commit 08825be
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 52 deletions.
9 changes: 9 additions & 0 deletions test/distributed/test_c10d_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1775,11 +1775,20 @@ def test_send_recv(self):

with self.assertRaises(ValueError):
dist.send(input_tensor, dist.get_rank())
with self.assertRaises(ValueError):
dist.send(input_tensor, group_dst=dist.get_rank())

with self.assertRaises(ValueError):
dist.send(input_tensor, dist.get_rank(), group_dst=dist.get_rank())
with self.assertRaises(ValueError):
dist.send(input_tensor)

# test recv
input_tensor = torch.zeros(2, 2)
dist.recv(input_tensor, (self.rank + 1) % self.world_size)
self.assertEqual(input_tensor, torch.zeros(2, 2) + 2)
with self.assertRaises(ValueError):
dist.recv(input_tensor, src=0, group_src=0)

dist.barrier()
# intentionally not calling into `destroy_process_group` as not all
Expand Down
23 changes: 18 additions & 5 deletions test/distributed/test_c10d_nccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3825,8 +3825,9 @@ def test_reduce_subgroup(self):

@requires_nccl()
@skip_if_lt_x_gpu(4)
@parametrize("group_rank", [True, False])
@parametrize("async_op", [True, False])
def test_send_recv_subgroup(self, async_op):
def test_send_recv_subgroup(self, async_op, group_rank):
world_size = 4
if self.rank >= world_size:
return
Expand All @@ -3835,17 +3836,29 @@ def test_send_recv_subgroup(self, async_op):
if self.rank == 0 or self.rank == 2:
x = torch.empty((10,), device=device)
if async_op:
c10d.irecv(x, src=self.rank + 1, group=subgroup).wait()
if group_rank:
c10d.irecv(x, group_src=1, group=subgroup).wait()
else:
c10d.irecv(x, src=self.rank + 1, group=subgroup).wait()
else:
c10d.recv(x, src=self.rank + 1, group=subgroup)
if group_rank:
c10d.recv(x, group_src=1, group=subgroup)
else:
c10d.recv(x, src=self.rank + 1, group=subgroup)
expected = torch.ones((10,), device=device) * (self.rank + 1)
self.assertEqual(x, expected)
else:
x = torch.ones((10,), device=device) * self.rank
if async_op:
c10d.isend(x, dst=self.rank - 1, group=subgroup).wait()
if group_rank:
c10d.isend(x, group_dst=0, group=subgroup).wait()
else:
c10d.isend(x, dst=self.rank - 1, group=subgroup).wait()
else:
c10d.send(x, dst=self.rank - 1, group=subgroup)
if group_rank:
c10d.send(x, group_dst=0, group=subgroup)
else:
c10d.send(x, dst=self.rank - 1, group=subgroup)

@requires_nccl()
@skip_if_lt_x_gpu(4)
Expand Down
119 changes: 72 additions & 47 deletions torch/distributed/distributed_c10d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,6 +1112,38 @@ def _check_tensor_list(param, param_name) -> None:
)


def _group_or_default_group(group: Optional[ProcessGroup] = None) -> ProcessGroup:
if group is None or group is GroupMember.WORLD:
group = _get_default_group()
return group


def _canonicalize_group_rank(
group: ProcessGroup,
global_rank: Optional[int] = None,
group_rank: Optional[int] = None,
) -> int:
"""
Helper method to take _either_ a global rank or a group rank and produce a group rank.
"""
if group_rank is not None:
if global_rank is not None:
raise ValueError("Can't specify both group_rank and global_rank")
else:
if global_rank is None:
raise ValueError("Must specify global_rank or group_rank")
group_rank = get_group_rank(group, global_rank)
return group_rank


def _check_not_self_rank(group: ProcessGroup, rank: int, rank_type: str):
if group.rank() == rank:
raise ValueError(
f"Invalid {rank_type} rank: {rank_type} rank should not be the same as "
"the rank of the current process."
)


def _as_iterable(obj) -> collections.abc.Iterable:
return obj if isinstance(obj, list) else (obj,)

Expand Down Expand Up @@ -2217,7 +2249,11 @@ def get_world_size(group: Optional[ProcessGroup] = None) -> int:


def isend(
tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0
tensor: torch.Tensor,
dst: Optional[int] = None,
group: Optional[ProcessGroup] = None,
tag: int = 0,
group_dst: Optional[int] = None,
) -> Optional[Work]:
"""
Send a tensor asynchronously.
Expand All @@ -2229,18 +2265,23 @@ def isend(
.. warning::
``tag`` is not supported with the NCCL backend.
Unlike send, which is blocking, isend allows src == dst rank, i.e. send to self.
Args:
tensor (Tensor): Tensor to send.
dst (int): Destination rank on global process group (regardless of ``group`` argument)
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
tag (int, optional): Tag to match send with remote recv
group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst``
Returns:
A distributed request object.
None, if not part of the group
"""
group = _group_or_default_group(group)
group_dst = _canonicalize_group_rank(group, dst, group_dst)
_check_single_tensor(tensor, "tensor")
if _rank_not_in_group(group):
_warn_not_in_group("isend")
Expand All @@ -2249,34 +2290,32 @@ def isend(
if tensor.is_complex():
tensor = torch.view_as_real(tensor)

if group is None or group is GroupMember.WORLD:
pg = _get_default_group()
else:
pg = group
dst = get_group_rank(pg, dst)

return pg.send([tensor], dst, tag)
return group.send([tensor], group_dst, tag)


def irecv(
tensor: torch.Tensor,
src: Optional[int] = None,
group: Optional[ProcessGroup] = None,
tag: int = 0,
group_src: Optional[int] = None,
) -> Optional[Work]:
"""
Receives a tensor asynchronously.
.. warning::
``tag`` is not supported with the NCCL backend.
Unlike recv, which is blocking, irecv allows src == dst rank, i.e. recv from self.
Args:
tensor (Tensor): Tensor to fill with received data.
src (int, optional): Source rank on global process group (regardless of ``group`` argument).
Will receive from any process if unspecified.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
tag (int, optional): Tag to match recv with remote send
group_src (int, optional): Destination rank on ``group``. Invalid to specify both ``src`` and ``group_src``.
Returns:
A distributed request object.
Expand All @@ -2291,24 +2330,21 @@ def irecv(
if tensor.is_complex():
tensor = torch.view_as_real(tensor)

if group is None or group is GroupMember.WORLD:
pg = _get_default_group()
else:
pg = group

if src is None:
return pg.recv_anysource([tensor], tag)
group = _group_or_default_group(group)
if src is None and group_src is None:
return group.recv_anysource([tensor], tag)
else:
if pg is GroupMember.WORLD:
return pg.recv([tensor], src, tag)
else:
group_src_rank = get_group_rank(pg, src)
return pg.recv([tensor], group_src_rank, tag)
group_src = _canonicalize_group_rank(group, src, group_src)
return group.recv([tensor], group_src, tag)


@_exception_logger
def send(
tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0
tensor: torch.Tensor,
dst: Optional[int] = None,
group: Optional[ProcessGroup] = None,
tag: int = 0,
group_dst: Optional[int] = None,
) -> None:
"""
Send a tensor synchronously.
Expand All @@ -2323,14 +2359,12 @@ def send(
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
tag (int, optional): Tag to match send with remote recv
group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst``.
"""
if get_rank() == dst:
raise ValueError(
"Invalid destination rank: destination rank should not be the same as "
"the rank of the current process."
)

group = _group_or_default_group(group)
group_dst = _canonicalize_group_rank(group, dst, group_dst)
_check_not_self_rank(group, group_dst, "destination")
_check_single_tensor(tensor, "tensor")
if _rank_not_in_group(group):
_warn_not_in_group("send")
Expand All @@ -2339,12 +2373,7 @@ def send(
if tensor.is_complex():
tensor = torch.view_as_real(tensor)

if group is None or group is GroupMember.WORLD:
default_pg = _get_default_group()
default_pg.send([tensor], dst, tag).wait()
else:
group_dst_rank = get_group_rank(group, dst)
group.send([tensor], group_dst_rank, tag).wait()
group.send([tensor], group_dst, tag).wait()


@_exception_logger
Expand All @@ -2353,6 +2382,7 @@ def recv(
src: Optional[int] = None,
group: Optional[ProcessGroup] = None,
tag: int = 0,
group_src: Optional[int] = None,
) -> int:
"""
Receives a tensor synchronously.
Expand All @@ -2367,7 +2397,7 @@ def recv(
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
tag (int, optional): Tag to match recv with remote send
group_src (int, optional): Destination rank on ``group``. Invalid to specify both ``src`` and ``group_src``.
Returns:
Sender rank
-1, if not part of the group
Expand All @@ -2381,23 +2411,18 @@ def recv(
if tensor.is_complex():
tensor = torch.view_as_real(tensor)

pg = group or _get_default_group()
group = _group_or_default_group(group)

if src is None:
work = pg.recv_anysource([tensor], tag)
if src is None and group_src is None:
work = group.recv_anysource([tensor], tag)
work.wait()
src_rank = work._source_rank()
if group is None or group is GroupMember.WORLD:
return src_rank
else:
return get_global_rank(pg, src_rank)
return get_global_rank(group, src_rank)
else:
if group is None or group is GroupMember.WORLD:
pg.recv([tensor], src, tag).wait()
else:
group_src_rank = get_group_rank(pg, src)
pg.recv([tensor], group_src_rank, tag).wait()
return src
group_src = _canonicalize_group_rank(group, src, group_src)
_check_not_self_rank(group, group_src, "source")
group.recv([tensor], group_src, tag).wait()
return get_global_rank(group, group_src)


class _IllegalWork(Work):
Expand Down

0 comments on commit 08825be

Please sign in to comment.