From 95f762869e766ba223dcfaa0ac022aa9ce1a695a Mon Sep 17 00:00:00 2001
From: Tristan Rice
Date: Tue, 19 Nov 2024 02:06:29 +0000
Subject: [PATCH] [distributed] add PG APIs and general doc cleanups (#140853)

Doc updates:

* This adds documentation for the object-oriented ProcessGroup APIs that are being used in torchft as well as https://github.com/pytorch/rfcs/pull/71.
* It also does some general cleanups to simplify distributed.rst by using `:members:`.
* It adds `__init__` definitions for the Stores.
* I've reordered things so the collective APIs come before the Store/PG APIs.

Test plan:

```
lintrunner -a
cd docs && sphinx-autobuild source build/ -j auto -WT --keep-going
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140853
Approved by: https://github.com/kwen2501
---
 docs/source/distributed.rst          |  61 ++++----
 torch/csrc/distributed/c10d/init.cpp | 212 +++++++++++++++++++++------
 2 files changed, 201 insertions(+), 72 deletions(-)

diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst
index 5b3f60f97af42d..35c6c7c2a39c82 100644
--- a/docs/source/distributed.rst
+++ b/docs/source/distributed.rst
@@ -329,32 +329,6 @@ a github issue or RFC if this is a use case that's blocking you.
 
 --------------------------------------------------------------------------------
 
-Distributed Key-Value Store
----------------------------
-
-The distributed package comes with a distributed key-value store, which can be
-used to share information between processes in the group as well as to
-initialize the distributed package in
-:func:`torch.distributed.init_process_group` (by explicitly creating the store
-as an alternative to specifying ``init_method``.) There are 3 choices for
-Key-Value Stores: :class:`~torch.distributed.TCPStore`,
-:class:`~torch.distributed.FileStore`, and :class:`~torch.distributed.HashStore`.
-
-.. autoclass:: Store
-.. autoclass:: TCPStore
-.. autoclass:: HashStore
-.. autoclass:: FileStore
-.. autoclass:: PrefixStore
-
-.. autofunction:: torch.distributed.Store.set
-.. autofunction:: torch.distributed.Store.get
-.. autofunction:: torch.distributed.Store.add
-.. autofunction:: torch.distributed.Store.compare_set
-.. autofunction:: torch.distributed.Store.wait
-.. autofunction:: torch.distributed.Store.num_keys
-.. autofunction:: torch.distributed.Store.delete_key
-.. autofunction:: torch.distributed.Store.set_timeout
-
 Groups
 ------
 
@@ -386,6 +360,7 @@ distributed process group easily. :func:`~torch.distributed.device_mesh.init_dev
 used to create new DeviceMesh, with a mesh shape describing the device topology.
 
 .. autoclass:: torch.distributed.device_mesh.DeviceMesh
+    :members:
 
 Point-to-point communication
 ----------------------------
@@ -506,6 +481,7 @@ Collective functions
 .. autofunction:: monitored_barrier
 
 .. autoclass:: Work
+    :members:
 
 .. autoclass:: ReduceOp
 
@@ -516,6 +492,39 @@ Collective functions
     :class:`~torch.distributed.ReduceOp` is recommended to use instead.
 
 
+
+Distributed Key-Value Store
+---------------------------
+
+The distributed package comes with a distributed key-value store, which can be
+used to share information between processes in the group as well as to
+initialize the distributed package in
+:func:`torch.distributed.init_process_group` (by explicitly creating the store
+as an alternative to specifying ``init_method``.) There are 3 choices for
+Key-Value Stores: :class:`~torch.distributed.TCPStore`,
+:class:`~torch.distributed.FileStore`, and :class:`~torch.distributed.HashStore`.
+
+.. 
autoclass:: Store + :members: + :special-members: + +.. autoclass:: TCPStore + :members: + :special-members: __init__ + +.. autoclass:: HashStore + :members: + :special-members: __init__ + +.. autoclass:: FileStore + :members: + :special-members: __init__ + +.. autoclass:: PrefixStore + :members: + :special-members: __init__ + + Profiling Collective Communication ----------------------------------------- diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index bbcfd10e58e919..4a2701170b6a0a 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1539,7 +1539,8 @@ Example:: .def( py::init(), py::arg("file_name"), - py::arg("world_size") = -1) + py::arg("world_size") = -1, + R"(Creates a new FileStore.)") .def_property_readonly( "path", &::c10d::FileStore::getPath, @@ -1561,7 +1562,7 @@ Example:: >>> # Use any of the store methods after initialization >>> store.set("first_key", "first_value") )") - .def(py::init<>()); + .def(py::init<>(), R"(Creates a new HashStore.)"); #endif intrusive_ptr_class_<::c10d::TCPStore>( @@ -1639,7 +1640,8 @@ Example:: py::arg("multi_tenant") = false, py::arg("master_listen_fd") = py::none(), py::arg("use_libuv") = true, - py::call_guard()) + py::call_guard(), + R"(Creates a new TCPStore.)") .def_property_readonly( "host", &::c10d::TCPStore::getHost, @@ -1680,7 +1682,8 @@ that adds a prefix to each key inserted to the store. return new ::c10d::PrefixStore(prefix, std::move(store)); }), py::arg("prefix"), - py::arg("store")) + py::arg("store"), + R"(Creates a new PrefixStore.)") .def_property_readonly( "underlying_store", &::c10d::PrefixStore::getUnderlyingStore, @@ -1900,17 +1903,31 @@ communication mechanism. py::class_< ::c10d::ProcessGroup, c10::intrusive_ptr<::c10d::ProcessGroup>, - ::c10d::PyProcessGroup>(module, "ProcessGroup") - .def(py::init()) + ::c10d::PyProcessGroup>(module, "ProcessGroup", + R"(A ProcessGroup is a communication primitive that allows for + collective operations across a group of processes. + + This is a base class that provides the interface for all + ProcessGroups. It is not meant to be used directly, but rather + extended by subclasses.)") + .def( + py::init(), + py::arg("rank"), + py::arg("size"), + R"(Create a new ProcessGroup instance.)") .def( py::init< const c10::intrusive_ptr<::c10d::Store>&, int, int>(), - py::call_guard()) - .def("rank", &::c10d::ProcessGroup::getRank) - .def("size", &::c10d::ProcessGroup::getSize) - .def("name", &::c10d::ProcessGroup::getBackendName) + py::arg("store"), + py::arg("rank"), + py::arg("size"), + py::call_guard(), + R"(Create a new ProcessGroup instance.)") + .def("rank", &::c10d::ProcessGroup::getRank, R"(Get the rank of this process group.)") + .def("size", &::c10d::ProcessGroup::getSize, R"(Get the size of this process group.)") + .def("name", &::c10d::ProcessGroup::getBackendName, R"(Get the name of this process group.)") .def("_id", &::c10d::ProcessGroup::getID) .def( "_backend_id", @@ -1921,7 +1938,10 @@ communication mechanism. &::c10d::ProcessGroup::broadcast, py::arg("tensors"), py::arg("opts") = ::c10d::BroadcastOptions(), - py::call_guard()) + py::call_guard(), + R"(Broadcasts the tensor to all processes in the process group. + + See :func:`torch.distributed.broadcast for more details.)") .def( "broadcast", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, @@ -1934,13 +1954,19 @@ communication mechanism. 
}, py::arg("tensor"), py::arg("root"), - py::call_guard()) + py::call_guard(), + R"(Broadcasts the tensor to all processes in the process group. + + See :func:`torch.distributed.broadcast` for more details.)") .def( "allreduce", &::c10d::ProcessGroup::allreduce, py::arg("tensors"), py::arg("opts") = ::c10d::AllreduceOptions(), - py::call_guard()) + py::call_guard(), + R"(Allreduces the provided tensors across all processes in the process group. + + See :func:`torch.distributed.all_reduce` for more details.)") .def( "allreduce", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, @@ -1952,7 +1978,10 @@ communication mechanism. }, py::arg("tensors"), py::arg("op") = ::c10d::ReduceOp::SUM, - py::call_guard()) + py::call_guard(), + R"(Allreduces the provided tensors across all processes in the process group. + + See :func:`torch.distributed.all_reduce` for more details.)") .def( "allreduce", @@ -1966,20 +1995,29 @@ communication mechanism. }, py::arg("tensor"), py::arg("op") = ::c10d::ReduceOp::SUM, - py::call_guard()) + py::call_guard(), + R"(Allreduces the provided tensors across all processes in the process group. + + See :func:`torch.distributed.all_reduce` for more details.)") .def( "allreduce_coalesced", &::c10d::ProcessGroup::allreduce_coalesced, py::arg("tensors"), py::arg("opts") = ::c10d::AllreduceCoalescedOptions(), - py::call_guard()) + py::call_guard(), + R"(Allreduces the provided tensors across all processes in the process group. + + See :func:`torch.distributed.all_reduce` for more details.)") .def( "reduce", &::c10d::ProcessGroup::reduce, py::arg("tensors"), py::arg("opts") = ::c10d::ReduceOptions(), - py::call_guard()) + py::call_guard(), + R"(Reduces the provided tensors across all processes in the process group. + + See :func:`torch.distributed.reduce` for more details.)") .def( "reduce", @@ -1996,14 +2034,20 @@ communication mechanism. py::arg("tensor"), py::arg("root"), py::arg("op") = ::c10d::ReduceOp::SUM, - py::call_guard()) + py::call_guard(), + R"(Reduces the provided tensors across all processes in the process group. + + See :func:`torch.distributed.reduce` for more details.)") .def( "allgather", &::c10d::ProcessGroup::allgather, py::arg("output_tensors"), py::arg("input_tensors"), py::arg("opts") = ::c10d::AllgatherOptions(), - py::call_guard()) + py::call_guard(), + R"(Allgathers the input tensors from all processes across the process group. + + See :func:`torch.distributed.all_gather` for more details.)") .def( "allgather", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, @@ -2016,7 +2060,10 @@ communication mechanism. }, py::arg("output_tensors"), py::arg("input_tensor"), - py::call_guard()) + py::call_guard(), + R"(Allgathers the input tensors from all processes across the process group. + + See :func:`torch.distributed.all_gather: for more details.)") .def( "_allgather_base", &::c10d::ProcessGroup::_allgather_base, @@ -2030,21 +2077,30 @@ communication mechanism. py::arg("output_lists"), py::arg("input_list"), py::arg("opts") = ::c10d::AllgatherOptions(), - py::call_guard()) + py::call_guard(), + R"(Allgathers the input tensors from all processes across the process group. + + See :func:`torch.distributed.all_gather` for more details.)") .def( "allgather_into_tensor_coalesced", &::c10d::ProcessGroup::allgather_into_tensor_coalesced, py::arg("outputs"), py::arg("inputs"), py::arg("opts") = ::c10d::AllgatherOptions(), - py::call_guard()) + py::call_guard(), + R"(Allgathers the input tensors from all processes across the process group. 
+ + See :func:`torch.distributed.all_gather` for more details.)") .def( "gather", &::c10d::ProcessGroup::gather, py::arg("output_tensors"), py::arg("input_tensors"), py::arg("opts") = ::c10d::GatherOptions(), - py::call_guard()) + py::call_guard(), + R"(Gathers the input tensors from all processes across the process group. + + See :func:`torch.distributed.gather` for more details.)") .def( "gather", @@ -2061,14 +2117,20 @@ communication mechanism. py::arg("output_tensors"), py::arg("input_tensor"), py::arg("root"), - py::call_guard()) + py::call_guard(), + R"(Gathers the input tensors from all processes across the process group. + + See :func:`torch.distributed.gather` for more details.)") .def( "scatter", &::c10d::ProcessGroup::scatter, py::arg("output_tensors"), py::arg("input_tensors"), py::arg("opts") = ::c10d::ScatterOptions(), - py::call_guard()) + py::call_guard(), + R"(Scatters the input tensors from all processes across the process group. + + See :func:`torch.distributed.scatter` for more details.)") .def( "scatter", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, @@ -2084,14 +2146,20 @@ communication mechanism. py::arg("output_tensor"), py::arg("input_tensors"), py::arg("root"), - py::call_guard()) + py::call_guard(), + R"(Scatters the input tensors from all processes across the process group. + + See :func:`torch.distributed.scatter` for more details.)") .def( "reduce_scatter", &::c10d::ProcessGroup::reduce_scatter, py::arg("output_tensors"), py::arg("input_tensors"), py::arg("opts") = ::c10d::ReduceScatterOptions(), - py::call_guard()) + py::call_guard(), + R"(Reduces and scatters the input tensors from all processes across the process group. + + See :func:`torch.distributed.reduce_scatter` for more details.)") .def( "reduce_scatter", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, @@ -2107,7 +2175,10 @@ communication mechanism. py::arg("output"), py::arg("input"), py::arg("op") = ::c10d::ReduceOp::SUM, - py::call_guard()) + py::call_guard(), + R"(Reduces and scatters the input tensors from all processes across the process group. + + See :func:`torch.distributed.reduce_scatter` for more details.)") .def( "_reduce_scatter_base", &::c10d::ProcessGroup::_reduce_scatter_base, @@ -2121,7 +2192,10 @@ communication mechanism. py::arg("outputs"), py::arg("inputs"), py::arg("opts") = ::c10d::ReduceScatterOptions(), - py::call_guard()) + py::call_guard(), + R"(Reduces and scatters the input tensors from all processes across the process group. + + See :func:`torch.distributed.reduce_scatter` for more details.)") .def( "alltoall_base", &::c10d::ProcessGroup::alltoall_base, @@ -2130,37 +2204,56 @@ communication mechanism. py::arg("output_split_sizes"), py::arg("input_split_sizes"), py::arg("opts") = ::c10d::AllToAllOptions(), - py::call_guard()) + py::call_guard(), + R"(Alltoalls the input tensors from all processes across the process group. + + See :func:`torch.distributed.all_to_all for more details.)") .def( "alltoall", &::c10d::ProcessGroup::alltoall, py::arg("output_tensors"), py::arg("input_tensors"), py::arg("opts") = ::c10d::AllToAllOptions(), - py::call_guard()) + py::call_guard(), + R"(Alltoalls the input tensors from all processes across the process group. + + See :func:`torch.distributed.all_to_all` for more details.)") .def( "send", &::c10d::ProcessGroup::send, py::arg("tensors"), py::arg("dstRank"), py::arg("tag"), - py::call_guard()) + py::call_guard(), + R"(Sends the tensor to the specified rank. 
+ + See :func:`torch.distributed.send` for more details.)") .def( "recv", &::c10d::ProcessGroup::recv, py::arg("tensors"), py::arg("srcRank"), py::arg("tag"), - py::call_guard()) + py::call_guard(), + R"(Receives the tensor from the specified rank. + + See :func:`torch.distributed.recv` for more details.)") .def( "recv_anysource", &::c10d::ProcessGroup::recvAnysource, - py::call_guard()) + py::call_guard(), + R"(Receives the tensor from any source. + + See :func:`torch.distributed.recv` for more details.)") .def( "barrier", &::c10d::ProcessGroup::barrier, py::arg("opts") = ::c10d::BarrierOptions(), - py::call_guard()) + py::call_guard(), + R"(Blocks until all processes in the group enter the call, and + then all leave the call together. + + See :func:`torch.distributed.barrier` for more details.)") .def( "_set_sequence_number_for_group", &::c10d::ProcessGroup::setSequenceNumberForGroup, @@ -2180,7 +2273,11 @@ communication mechanism. }, py::arg("timeout") = ::c10d::kUnsetTimeout, py::arg("wait_all_ranks") = false, - py::call_guard()) + py::call_guard(), + R"(Blocks until all processes in the group enter the call, and + then all leave the call together. + + See :func:`torch.distributed.monitored_barrier` for more details.)") .def_property_readonly( "_device_types", &::c10d::ProcessGroup::getDeviceTypes) .def( @@ -2326,7 +2423,10 @@ The hook must have the following signature: return ivalue.toCustomClass<::c10d::ProcessGroup>(); }); - py::enum_<::c10d::ProcessGroup::BackendType>(processGroup, "BackendType") + py::enum_<::c10d::ProcessGroup::BackendType>( + processGroup, + "BackendType", + R"(The type of the backend used for the process group.)") .value("UNDEFINED", ::c10d::ProcessGroup::BackendType::UNDEFINED) .value("GLOO", ::c10d::ProcessGroup::BackendType::GLOO) .value("NCCL", ::c10d::ProcessGroup::BackendType::NCCL) @@ -2702,7 +2802,12 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). int, int, c10::intrusive_ptr<::c10d::ProcessGroupGloo::Options>>(), - py::call_guard()) + py::call_guard(), + py::arg("store"), + py::arg("rank"), + py::arg("size"), + py::arg("options"), + R"(Create a new ProcessGroupGloo instance.)") .def( py::init([](const c10::intrusive_ptr<::c10d::Store>& store, int rank, @@ -2735,7 +2840,8 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). py::arg("rank"), py::arg("size"), py::arg("timeout") = kProcessGroupDefaultTimeout, - py::call_guard()) + py::call_guard(), + R"(Create a new ProcessGroupGloo instance.)") .def( "_set_default_timeout", [](const c10::intrusive_ptr<::c10d::ProcessGroupGloo>& self, @@ -2744,7 +2850,10 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). }, py::arg("timeout"), py::call_guard()) - .def_property_readonly("options", &::c10d::ProcessGroupGloo::getOptions); + .def_property_readonly( + "options", + &::c10d::ProcessGroupGloo::getOptions, + R"(Return the options used to create this ProcessGroupGloo instance.)"); // ProcessGroupWrapper is a wrapper pg that includes a helper gloo process // group. It can be used to validate collective calls across processes by @@ -2777,7 +2886,12 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). 
int, int, c10::intrusive_ptr<::c10d::ProcessGroupNCCL::Options>>(), - py::call_guard()) + py::call_guard(), + py::arg("store"), + py::arg("rank"), + py::arg("size"), + py::arg("options"), + R"(Create a new ProcessGroupNCCL instance.)") .def( py::init([](const c10::intrusive_ptr<::c10d::Store>& store, int rank, @@ -2793,7 +2907,8 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). py::arg("rank"), py::arg("size"), py::arg("timeout") = ::c10d::kProcessGroupNCCLDefaultTimeout, - py::call_guard()) + py::call_guard(), + R"(Create a new ProcessGroupNCCL instance.)") .def( "_shutdown", [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) { @@ -2830,12 +2945,16 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). py::arg("work"), py::arg("timeout")) .def_property_readonly( - "options", &::c10d::ProcessGroupNCCL::getOptions) - .def_property_readonly("uid", &::c10d::ProcessGroupNCCL::getUid) + "options", + &::c10d::ProcessGroupNCCL::getOptions, + R"(Return the options used to create this ProcessGroupNCCL instance.)") + .def_property_readonly( + "uid", &::c10d::ProcessGroupNCCL::getUid, R"(Return the uid.)") .def_property( "bound_device_id", &::c10d::ProcessGroupNCCL::getBoundDeviceId, - &::c10d::ProcessGroupNCCL::setBoundDeviceId) + &::c10d::ProcessGroupNCCL::setBoundDeviceId, + R"(Return the bound device id.)") .def( "perform_nocolor_split", &::c10d::ProcessGroupNCCL::performNocolorSplit) @@ -2846,7 +2965,8 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). .def( "abort", &::c10d::ProcessGroupNCCL::abort, - py::call_guard()) + py::call_guard(), + R"(Abort the process group.)") .def( "_is_initialized", &::c10d::ProcessGroupNCCL::isInitialized,
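Usage sketch for the relocated "Distributed Key-Value Store" section (illustrative only, not part of the patch): a TCPStore can be used directly as a key-value store and, as an alternative to an ``init_method`` URL, to bootstrap ``init_process_group``. The host, port, rank, and world size below are placeholder values a launcher would normally supply.

```
# Illustrative sketch (not part of this patch). The host, port, rank and
# world_size values are placeholders for a real launcher's configuration.
from datetime import timedelta

import torch
import torch.distributed as dist

rank, world_size = 0, 1  # single-process placeholder values

# Rank 0 hosts the store; the remaining ranks connect to it.
store = dist.TCPStore(
    host_name="127.0.0.1",
    port=29500,
    world_size=world_size,
    is_master=(rank == 0),
    timeout=timedelta(seconds=30),
)

# Plain key-value usage, as documented on Store.
store.set("first_key", "first_value")
assert store.get("first_key") == b"first_value"

# The same store can initialize the default process group in place of an
# init_method URL.
dist.init_process_group("gloo", store=store, rank=rank, world_size=world_size)
dist.all_reduce(torch.ones(2))
dist.destroy_process_group()
```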
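A similar sketch (again illustrative, with a placeholder temporary file) for FileStore and PrefixStore, which namespaces keys so several components can share one underlying store:

```
# Illustrative sketch (not part of this patch); the temporary file path is a
# placeholder.
import tempfile

import torch.distributed as dist

path = tempfile.NamedTemporaryFile(delete=False).name
store = dist.FileStore(path, 1)  # world_size of 1 for a single process

# PrefixStore prepends "trainer/" to every key before forwarding it to the
# underlying store, so independent components do not collide.
trainer_store = dist.PrefixStore("trainer", store)
trainer_store.set("step", "42")
assert store.get("trainer/step") == b"42"
```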
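Finally, a sketch of driving the object-oriented ProcessGroup API directly, the pattern this patch documents for torchft. It assumes a Gloo-enabled build; the store address, port, rank, and world size are the same kind of placeholders as above.

```
# Illustrative sketch (not part of this patch). Assumes PyTorch was built with
# Gloo; host, port, rank and world_size are placeholders.
from datetime import timedelta

import torch
import torch.distributed as dist

rank, world_size = 0, 1  # single-process placeholder values
store = dist.TCPStore("127.0.0.1", 29501, world_size, is_master=(rank == 0))

# Construct the group directly from a store, rank and size instead of going
# through init_process_group.
pg = dist.ProcessGroupGloo(store, rank, world_size, timedelta(seconds=30))
assert pg.rank() == rank and pg.size() == world_size

# Collectives return a Work object; wait() blocks until the result is ready.
t = torch.ones(4)
work = pg.allreduce([t])
work.wait()
# After the allreduce every element of t equals world_size.
```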