RFC-0042-torch-distributed-redesign #71
base: master
Conversation
Signed-off-by: youkaichao <[email protected]>
For a preview, please check https://github.com/youkaichao/rfcs/blob/master/RFC-0042-torch-distributed-redesign.md
An important use case for this is dynamic prefill/decode disaggregation: prefill instances and decode instances dynamically join the group according to the workload, and they send/recv KV caches to/from each other. There are other solutions, like using etcd for communicating metadata and directly using device communication libraries such as our own NCCL wrapper. That means completely dropping
The current global group is necessary for control-plane operations over the cluster. It conflates the notion of a cluster with that of communication groups, so it would be great to separate the two. One aspect of making this feasible is whether it's possible to implement
Do you mean we'd have a stateless version of the process group?
Thanks for posting this RFC! I want to see if we can make changes to the existing torch.distributed APIs first to solve some or all of your problems, and then, if needed, we can consider a new set of APIs (e.g. torch.distributed2). For the src/dst in send/recv, that is something that has been bugging us for a while, and I suppose we could fix it in the existing APIs without worrying about BC by simply adding new kwargs. For the global group, I think this might be harder to solve, but I'd like to get a document started with the different possibilities and their pros/cons. cc @kwen2501
I think a lot of what's being asked here can be done with just a new entrypoint (rather than just init_process_group) and avoid having to create a new package. That's largely what I'm doing in the torchft ProcessGroups: just initializing the underlying PG without setting the global state. It is definitely a bit clunky (since it operates on the store API), but it generally works just fine to instantiate a PG without calling it, i.e. in current PyTorch you can do:

```python
from torch.distributed import PrefixStore, ProcessGroupNCCL, TCPStore

store = TCPStore(
    host_name=host,
    port=int(port),
    is_master=False,
    wait_for_workers=False,
)
store = PrefixStore("my_custom_pg", store)
pg = ProcessGroupNCCL(store, rank=10, world_size=32)
pg.rank(), pg.size()
```

This can be used in a completely object-oriented way without relying on any "internal" APIs.
@youkaichao would you be happy to use the workflow @d4l3k proposed? Or is there still something missing? @d4l3k, is the PrefixStore needed so that each store can use a default UUID (does each store use UUID 0 or something)? I wonder if we should still provide a little bit of a helper here: (a) we could allow reusing the globally initialized TCPStore if it exists (or accept one as an optional kwarg as an alternative), and (b) we could deal with the UUID automatically somehow, ensuring that each PG still gets a unique UUID.
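A minimal sketch of what such a helper could look like (the name `next_pg_prefix` and the key scheme are assumptions for illustration, not an existing torch.distributed API): hand out a unique key prefix per process group so several groups can share one TCPStore, via PrefixStore, without key collisions.

```python
import itertools

# Hypothetical helper, not a real torch.distributed API. A process-local
# counter only yields matching prefixes across ranks if every rank creates
# its process groups in the same order, so that the n-th group created on
# each rank gets the same prefix.
_pg_counter = itertools.count()

def next_pg_prefix(tag: str = "pg") -> str:
    """Return prefixes 'pg:0', 'pg:1', ... in group-creation order."""
    return f"{tag}:{next(_pg_counter)}"
```

Each rank would then wrap the shared rendezvous store as `PrefixStore(next_pg_prefix(), store)` before constructing its ProcessGroup, instead of hard-coding a string like `"my_custom_pg"`.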
Addressing RFC 0042 (pytorch/rfcs#71) It's annoying that 'dst' (for send) must be a global rank even when a group is passed in. But we can't easily change 'dst' without breaking existing cases. [ghstack-poisoned]
@d4l3k that's a great idea. I actually tried it before. However, the problem is that you cannot use
I'm also exploring an idea of using the TCP store to directly implement a new set of send/recv/broadcast operations, in https://github.com/vllm-project/vllm/blob/377b74fe877c7eb4632c2ca0778b9da9a5db8ae6/vllm/distributed/utils.py#L127 . It works locally, but sometimes hangs during initialization in CI, though.
@youkaichao is the dst/src mapping issue the only reason that send/recv do not work? I started to prototype a possible fix for that today; I'll share it here shortly. Send/recv via TCPStore feels like it would require polling and become unscalable at large numbers of ranks, but for certain use cases it could work. We have also been thinking about better support for control-plane communication. cc @c-p-i-o
For send/recv, yes, kind of. There are other, more complicated cases, though. For example, and they are quite difficult to use if I have a standalone group that is not part of the global group.
For the TCP store (and any "store"), shouldn't it have polling by default? I don't see any polling in the example code at https://pytorch.org/docs/stable/distributed.html#torch.distributed.TCPStore .
That would be great.
OK, these look like the same thing to me. Basically, if we added support for 'group_src' and 'group_dst' to all our APIs, wherever there is currently a 'src' or 'dst', it would fix the issue. That's what it looks like to me, at least.
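To make the proposed kwargs concrete, here is a sketch of the argument canonicalization such APIs would need (the helper name `resolve_group_dst` and its exact error behavior are assumptions, not the actual PyTorch implementation): accept either the legacy global-rank `dst` or a new group-local `group_dst`, and resolve both to a group-local rank.

```python
def resolve_group_dst(group_ranks, dst=None, group_dst=None):
    """Resolve a destination to a group-local rank.

    group_ranks: list mapping group-local rank -> global rank.
    dst:         legacy argument, a global rank.
    group_dst:   proposed new kwarg, a group-local rank.

    Exactly one of dst / group_dst must be given, so existing callers keep
    working unchanged while new callers can avoid global ranks entirely
    (important for standalone groups that have no global ranks at all).
    """
    if (dst is None) == (group_dst is None):
        raise ValueError("exactly one of 'dst' and 'group_dst' is required")
    if group_dst is not None:
        if not 0 <= group_dst < len(group_ranks):
            raise ValueError(f"group_dst {group_dst} out of range")
        return group_dst
    try:
        return group_ranks.index(dst)  # translate global -> group-local
    except ValueError:
        raise ValueError(f"global rank {dst} is not in this group") from None
```

A standalone ProcessGroup would simply never pass `dst`, so no global-rank mapping is ever consulted.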
Well, I'm not sure what you mean about polling by default. But if I were to build send/recv on top of TCPStore, I think my two choices would be: (1) naive, make the recv op 'synchronous' on the CPU and rely on the TCP timeout; (2) implement a new polling thread on the recv side that keeps checking whether send-data has been posted. I was referring to path (2). I'm not sure if (1) is actually practical for performance reasons, but we could check.
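To make option (2) concrete, here is a minimal sketch using a dict-backed stand-in for TCPStore (the class `DictStore`, the function names, and the key scheme are all assumptions for illustration): the sender posts a payload under a unique key, and the receiver polls until the key appears or a deadline passes. The real TCPStore also exposes a blocking `wait` on keys, which is closer in spirit to option (1).

```python
import threading
import time

class DictStore:
    """Thread-safe dict standing in for a KV store such as TCPStore."""
    def __init__(self):
        self._data = {}
        self._lock = threading.Lock()

    def set(self, key, value):
        with self._lock:
            self._data[key] = value

    def get(self, key):
        with self._lock:
            return self._data.get(key)

def store_send(store, src, dst, seq, payload):
    # The sender only posts data; the (src, dst, seq) triple keys each message.
    store.set(f"send/{src}/{dst}/{seq}", payload)

def store_recv(store, src, dst, seq, timeout=5.0, poll_interval=0.001):
    # Option (2): a polling receive loop with a deadline.
    key = f"send/{src}/{dst}/{seq}"
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        payload = store.get(key)
        if payload is not None:
            return payload
        time.sleep(poll_interval)
    raise TimeoutError(f"no sender posted {key} within {timeout}s")
```

This also illustrates the scalability concern raised above: every pending recv burns its own polling loop against the store, which is unlikely to scale to large numbers of ranks.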
For my use case, (1) is enough.
We also need to take care of collective ops like allreduce and allgather. The goal is to support subgroups working on their own without any dependency on the global group.
Addressing RFC 0042 (pytorch/rfcs#71) It's annoying that 'dst' (for send) must be a global rank even when a group is passed in. But we can't easily change 'dst' without breaking existing cases. Furthermore, requiring use of 'global' dst breaks the less common usage pattern of creating a new ProcessGroup object that is not connected to the 'default group' and thus has no logical 'global' ranks. ghstack-source-id: 72264a21bf53bafd0b16b7cbb961aa91cc9b5992 Pull Request resolved: #140460
@youkaichao we don't document the ProcessGroup object APIs (I'm not sure why not; we really should), but if you use them directly it should work as expected for send/recv/broadcast, as the ranks are PG-local rather than global.
Doc updates: * This adds documentation for the object oriented ProcessGroup APIs that are being used in torchft as well as pytorch/rfcs#71 . * It also does some general cleanups to simplify the distributed.rst by using `:methods`. * It adds `__init__` definitions for the Stores * I've reordered things so the collective APIs are before the Store/PG apis Test plan: ``` lintrunner -a cd docs && sphinx-autobuild source build/ -j auto -WT --keep-going ``` Pull Request resolved: pytorch#140853 Approved by: https://github.com/kwen2501
Changes the semantics of __repr__ of P2POp: s, d are now group ranks instead of global ranks. I think this is OK since I also updated the field names to make this obvious. Also adds mypy annotations. Partially addresses RFC 0042 (pytorch/rfcs#71). See more details/motivation in #140460 cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 d4l3k c-p-i-o [ghstack-poisoned]
Changes semantic of __repr__ of P2POp: s, d are now group ranks instead of global ranks. I think this is OK since I also updated the field names to make this obvious. Also add mypy annotations Partially addresses RFC 0042 (pytorch/rfcs#71) See more details/motivation in #140460 Pull Request resolved: #141054 Approved by: https://github.com/kwen2501
) Also add missing mypy typing and a few asserts to make mypy happy. Partially addresses RFC 0042 (pytorch/rfcs#71). See more details/motivation in pytorch#140460. Note: the object collective version canonicalizes to global instead of group rank, simply because this left more of the original code intact and required fewer conversions overall. Pull Request resolved: pytorch#140827 Approved by: https://github.com/kwen2501
Also add mypy annotations Partially addresses RFC 0042 (pytorch/rfcs#71) See more details/motivation in pytorch#140460 Pull Request resolved: pytorch#140843 Approved by: https://github.com/kwen2501
…orch#140847) Also add mypy annotations Partially addresses RFC 0042 (pytorch/rfcs#71) See more details/motivation in pytorch#140460 Pull Request resolved: pytorch#140847 Approved by: https://github.com/H-Huang ghstack dependencies: pytorch#140843
Partly addressing RFC 0042 (pytorch/rfcs#71) It's annoying that 'dst' (for send) must be a global rank even when a group is passed in. But we can't easily change 'dst' without breaking existing cases. Furthermore, requiring use of 'global' dst breaks the less common usage pattern of creating a new ProcessGroup object that is not connected to the 'default group' and thus has no logical 'global' ranks. Pull Request resolved: pytorch#140460 Approved by: https://github.com/d4l3k, https://github.com/kwen2501, https://github.com/fduwjj