diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 82d52efd4b1ac8..e907565753a322 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -3928,7 +3928,10 @@ def test_broadcast_subgroup(self, group_rank):
         "set_device",
         [SetDeviceMethod.TORCH_CUDA_SET, SetDeviceMethod.COLLECTIVE_ARGUMENT],
     )
-    def test_send_recv_object_list_subgroup(self, set_device: SetDeviceMethod):
+    @parametrize("group_rank", [True, False])
+    def test_send_recv_object_list_subgroup(
+        self, set_device: SetDeviceMethod, group_rank
+    ):
         world_size = 4
         if self.rank >= world_size:
             return
@@ -3940,12 +3943,22 @@ def test_send_recv_object_list_subgroup(self, set_device: SetDeviceMethod):
             device = torch.device("cuda:%d" % self.rank)
         if self.rank == 0 or self.rank == 2:
             x = [{}]
-            c10d.recv_object_list(x, src=self.rank + 1, group=subgroup, device=device)
+            if group_rank:
+                c10d.recv_object_list(x, group_src=1, group=subgroup, device=device)
+            else:
+                c10d.recv_object_list(
+                    x, src=self.rank + 1, group=subgroup, device=device
+                )
             expected = [{"rank": self.rank + 1}]
             self.assertEqual(x, expected)
         else:
             x = [{"rank": self.rank}]
-            c10d.send_object_list(x, dst=self.rank - 1, group=subgroup, device=device)
+            if group_rank:
+                c10d.send_object_list(x, group_dst=0, group=subgroup, device=device)
+            else:
+                c10d.send_object_list(
+                    x, dst=self.rank - 1, group=subgroup, device=device
+                )
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 319a8e942b04f8..3e5c91984535e1 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -3087,7 +3087,13 @@ def gather_object(
 
 
 @_exception_logger
-def send_object_list(object_list, dst, group=None, device=None):
+def send_object_list(
+    object_list: List[Any],
+    dst: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    device: Optional[torch.device] = None,
+    group_dst: Optional[int] = None,
+):
     """
     Sends picklable objects in ``object_list`` synchronously.
 
@@ -3105,7 +3111,8 @@ def send_object_list(object_list, dst, group=None, device=None):
         device (``torch.device``, optional): If not None, the objects are
             serialized and converted to tensors which are moved to the
             ``device`` before sending. Default is ``None``.
-
+        group_dst (int, optional): Destination rank on ``group``.
+            Must specify one of ``dst`` and ``group_dst`` but not both.
     Returns:
         ``None``.
 
@@ -3143,11 +3150,9 @@ def send_object_list(object_list, dst, group=None, device=None):
         >>> objects
         ['foo', 12, {1: 2}]
     """
-    if get_rank() == dst:
-        raise ValueError(
-            "Invalid destination rank: destination rank should not be the same as "
-            "the rank of the current process."
-        )
+    group = _group_or_default_group(group)
+    group_dst = _canonicalize_group_rank(group, dst, group_dst)
+    _check_not_self_rank(group, group_dst, "destination")
 
     if _rank_not_in_group(group):
         _warn_not_in_group("send_object_list")
@@ -3167,7 +3172,7 @@ def send_object_list(object_list, dst, group=None, device=None):
     object_sizes_tensor = torch.cat(size_list)
 
     # Send object sizes
-    send(object_sizes_tensor, dst=dst, group=group)
+    send(object_sizes_tensor, group_dst=group_dst, group=group)
 
     # Concatenate and send serialized object tensors
     # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list
@@ -3177,11 +3182,17 @@ def send_object_list(object_list, dst, group=None, device=None):
     else:
         object_tensor = torch.cat(tensor_list)
 
-    send(object_tensor, dst=dst, group=group)
+    send(object_tensor, group_dst=group_dst, group=group)
 
 
 @_exception_logger
-def recv_object_list(object_list, src=None, group=None, device=None):
+def recv_object_list(
+    object_list: List[Any],
+    src: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    device: Optional[torch.device] = None,
+    group_src: Optional[int] = None,
+):
     """
     Receives picklable objects in ``object_list`` synchronously.
 
@@ -3197,6 +3208,7 @@ def recv_object_list(object_list, src=None, group=None, device=None):
             the default process group will be used. Default is ``None``.
         device (``torch.device``, optional): If not None, receives on this device.
             Default is ``None``.
+        group_src (int, optional): Source rank on ``group``. Invalid to specify both ``src`` and ``group_src``.
 
     Returns:
         Sender rank. -1 if rank is not part of the group. If rank is part of the group,
@@ -3252,7 +3264,7 @@ def recv_object_list(object_list, src=None, group=None, device=None):
     )
 
     # Receive object sizes
-    rank_sizes = recv(object_sizes_tensor, src=src, group=group)
+    rank_sizes = recv(object_sizes_tensor, src=src, group=group, group_src=group_src)
 
     # Tensor to receive serialized objects into.
     object_tensor = torch.empty(  # type: ignore[call-overload]
@@ -3261,7 +3273,7 @@ def recv_object_list(object_list, src=None, group=None, device=None):
         device=current_device,
     )
 
-    rank_objects = recv(object_tensor, src=src, group=group)
+    rank_objects = recv(object_tensor, src=src, group=group, group_src=group_src)
     assert (
         rank_sizes == rank_objects
    ), "Mismatch in return ranks for object sizes and objects."
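For context, a minimal usage sketch (not part of the patch) of the new ``group_dst``/``group_src`` keyword arguments, which take subgroup-local ranks instead of global ranks. It assumes a 4-process NCCL job (e.g. launched via torchrun) with one GPU per rank; building the two 2-rank subgroups with ``new_subgroups`` is an illustrative assumption that mirrors the {0, 1}/{2, 3} split in the test above.

# Illustrative sketch only; assumes e.g. `torchrun --nproc_per_node=4 this_script.py`
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
device = torch.device("cuda:%d" % rank)
torch.cuda.set_device(device)

# Two 2-rank subgroups, {0, 1} and {2, 3}, as in the test above.
subgroup, _ = dist.new_subgroups(group_size=2)

if rank % 2 == 0:
    # Even global ranks are subgroup-local rank 0: receive from subgroup-local rank 1.
    objects = [{}]
    dist.recv_object_list(objects, group_src=1, group=subgroup, device=device)
else:
    # Odd global ranks are subgroup-local rank 1: send to subgroup-local rank 0.
    dist.send_object_list([{"rank": rank}], group_dst=0, group=subgroup, device=device)

dist.destroy_process_group()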