[C10D] Support group ranks in P2POp and batch_isend_irecv (pytorch#141054)

Changes the semantics of P2POp's __repr__: s and d are now group ranks
instead of global ranks. I think this is OK since I also updated the
field names (group_src, group_dst) to make this obvious.

Also add mypy annotations

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in pytorch#140460

Pull Request resolved: pytorch#141054
Approved by: https://github.com/kwen2501
wconstab authored and pobin6 committed Dec 5, 2024
1 parent cd2be98 commit e1622b7
Showing 2 changed files with 68 additions and 14 deletions.
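
As an illustrative sketch of the user-facing change (not part of the diff below): the peer of a P2POp can now be given either as a global rank via peer or as a rank within the op's process group via group_peer. Assuming a 4-rank job in which every rank creates subgroup = dist.new_group([2, 3]) and the snippet runs on a member of that subgroup (say, global rank 2), the two constructions below describe the same receive:

```python
import torch
import torch.distributed as dist

# Sketch only: assumes dist.init_process_group(...) already ran with
# world_size=4 and that every rank calls new_group (it is collective).
subgroup = dist.new_group([2, 3])

x = torch.empty(10)
# Old style (still supported): peer is a global rank.
recv_by_global_rank = dist.P2POp(dist.irecv, x, peer=3, group=subgroup)
# New style: group_peer is a rank within `subgroup` (group rank 1 == global rank 3).
recv_by_group_rank = dist.P2POp(dist.irecv, x, group=subgroup, group_peer=1)
# Either op can be handed to dist.batch_isend_irecv.
```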
test/distributed/test_c10d_nccl.py: 34 additions, 0 deletions
@@ -3869,6 +3869,40 @@ def test_send_recv_subgroup(self, async_op, group_rank):
                 else:
                     c10d.send(x, dst=self.rank - 1, group=subgroup)

+    @requires_nccl()
+    @skip_if_lt_x_gpu(4)
+    @parametrize("group_rank", [True, False])
+    def test_batch_send_recv_subgroup(self, group_rank):
+        world_size = 4
+        if self.rank >= world_size:
+            return
+        subgroup = self._init_two_pg2_subgroups(world_size)
+        device = torch.device("cuda:%d" % self.rank)
+        ops = []
+        if self.rank == 0 or self.rank == 2:
+            x = torch.empty((10,), device=device)
+            if group_rank:
+                ops.append(c10d.P2POp(dist.irecv, x, group=subgroup, group_peer=1))
+            else:
+                ops.append(
+                    c10d.P2POp(dist.irecv, x, peer=self.rank + 1, group=subgroup)
+                )
+
+            for work in dist.batch_isend_irecv(ops):
+                work.wait()
+            expected = torch.ones((10,), device=device) * (self.rank + 1)
+            self.assertEqual(x, expected)
+        else:
+            x = torch.ones((10,), device=device) * self.rank
+            if group_rank:
+                ops.append(c10d.P2POp(dist.isend, x, group=subgroup, group_peer=0))
+            else:
+                ops.append(
+                    c10d.P2POp(dist.isend, x, peer=self.rank - 1, group=subgroup)
+                )
+            for work in dist.batch_isend_irecv(ops):
+                work.wait()
+
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
     @parametrize("group_rank", [True, False])
torch/distributed/distributed_c10d.py: 34 additions, 14 deletions
@@ -469,57 +469,61 @@ class P2POp:
             The type of ``op`` is either ``torch.distributed.isend`` or
             ``torch.distributed.irecv``.
         tensor (Tensor): Tensor to send or receive.
-        peer (int): Destination or source rank.
+        peer (int, optional): Destination or source rank.
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         tag (int, optional): Tag to match send with recv.
+        group_peer (int, optional): Destination or source rank.
     """

     def __init__(
         self,
         op: Callable,
         tensor: torch.Tensor,
-        peer: int,
+        peer: Optional[int] = None,
         group: Optional[ProcessGroup] = None,
         tag: int = 0,
+        group_peer: Optional[int] = None,
     ):
         """Init."""
         self.op = op
         self.tensor = tensor
-        self.peer = peer
-        self.group = group
+        self.group = _group_or_default_group(group)
+        self.peer = _canonicalize_group_rank(
+            self.group, peer, group_peer, return_global=True
+        )
         self.tag = tag
+        self.group_peer = _canonicalize_group_rank(self.group, peer, group_peer)

     def __new__(
         cls,
         op: Callable,
         tensor: torch.Tensor,
-        peer: int,
+        peer: Optional[int] = None,
         group: Optional[ProcessGroup] = None,
         tag: int = 0,
+        group_peer: Optional[int] = None,
     ):
         """Create and return a new instance of the class."""
         _check_op(op)
         _check_single_tensor(tensor, "tensor")

         return object.__new__(cls)

     def __repr__(self):
         my_group_rank = get_rank(self.group)
-        peer_group_rank = (
-            get_group_rank(self.group, self.peer) if self.group else self.peer
-        )
         op_name = self.op.__name__
         group_name = self.group.group_name if self.group else "default_pg"
         if "send" in op_name:
             s = my_group_rank
-            d = peer_group_rank
+            d = self.group_peer
         elif "recv" in op_name:
-            s = peer_group_rank
+            s = self.group_peer
             d = my_group_rank
         else:
             return super().__repr__()

-        return f"P2POp({op_name} pg={group_name}, s={s}, d={d}, {self.tensor.shape}, {self.tensor.dtype})"
+        return f"P2POp({op_name} pg={group_name}, group_src={s}, group_dst={d}, {self.tensor.shape}, {self.tensor.dtype})"


 class _CollOp:
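
Since the commit message calls out the __repr__ change, here is a sketch of what the new format string reports, assuming an initialized default process group with at least two ranks; the group name and values shown are illustrative, not output captured from this commit:

```python
import torch
import torch.distributed as dist

# Sketch only: run on group rank 0 of an already-initialized default group.
op = dist.P2POp(dist.isend, torch.ones(4), group_peer=1)
print(op)
# Prints something like:
# P2POp(isend pg=<group name>, group_src=0, group_dst=1, torch.Size([4]), torch.float32)
```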
@@ -2545,7 +2549,7 @@ def _coalescing_manager(
         work.wait()  # type: ignore[possibly-undefined]


-def batch_isend_irecv(p2p_op_list):
+def batch_isend_irecv(p2p_op_list: List[P2POp]) -> List[Work]:
     """
     Send or Receive a batch of tensors asynchronously and return a list of requests.
@@ -2588,17 +2592,33 @@ def batch_isend_irecv(p2p_op_list):
     _check_p2p_op_list(p2p_op_list)
     group = p2p_op_list[0].group
     device = p2p_op_list[0].tensor.device
+
+    def peer_kwarg(op: P2POp) -> Dict[str, int]:
+        key = "group_dst" if op.op == isend else "group_src"
+        return {key: op.group_peer}
+
     if device.type == "cuda":
         # NCCL style coalescing
         with _coalescing_manager(group, device, async_ops=True) as cm:
             for p2p_op in p2p_op_list:
-                p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag)
+                p2p_op.op(
+                    p2p_op.tensor,
+                    group=p2p_op.group,
+                    tag=p2p_op.tag,
+                    **peer_kwarg(p2p_op),
+                )

         return cm.works
     else:
         # Backward support for Gloo
         reqs = []
         for p2p_op in p2p_op_list:
-            work = p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag)
+            work = p2p_op.op(
+                p2p_op.tensor,
+                group=p2p_op.group,
+                tag=p2p_op.tag,
+                **peer_kwarg(p2p_op),
+            )
             if work:
                 reqs.append(work)
         return reqs
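
As a usage sketch of the rewritten dispatch above (assuming a build that contains this change, two processes launched with torchrun, and the gloo backend so it runs on CPU), batch_isend_irecv forwards each op's group_peer through the group_dst / group_src keywords; the file name below is hypothetical:

```python
# batch_group_peer_demo.py
# Run with: torchrun --nproc_per_node=2 batch_group_peer_demo.py
import torch
import torch.distributed as dist


def main() -> None:
    dist.init_process_group("gloo")
    rank = dist.get_rank()
    t = torch.full((4,), float(rank))
    if rank == 0:
        # Send this rank's tensor to group rank 1.
        op = dist.P2POp(dist.isend, t, group_peer=1)
    else:
        # Receive into t from group rank 0 (on the default group,
        # group ranks and global ranks coincide).
        op = dist.P2POp(dist.irecv, t, group_peer=0)
    for work in dist.batch_isend_irecv([op]):
        work.wait()
    print(f"rank {rank}: {t.tolist()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```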
