RFC-0042-torch-distributed-redesign #71

Open · wants to merge 1 commit into master

Conversation

@youkaichao

No description provided.

Signed-off-by: youkaichao <[email protected]>
@youkaichao (Author)

For a preview, please check https://github.com/youkaichao/rfcs/blob/master/RFC-0042-torch-distributed-redesign.md

@youkaichao (Author)

cc @ezyang @wconstab @kwen2501

@youkaichao (Author)

An important use case for this is dynamic prefill-decode disaggregation: prefill and decode instances dynamically join the group according to the workload, and they send/recv KV caches to/from each other.

There are other solutions, like using etcd for communicating metadata and directly using device communication libraries such as our own NCCL wrapper. That would mean completely dropping torch.distributed from our codebase, though, and would be our last resort. We do want to use PyTorch as much as we can.

@kumpera commented Nov 11, 2024

The current global group is necessary for control plane operations over the cluster.

It's conflating the notion of a cluster with that of communication groups, so it would be great to separate the two.

One aspect to make this feasible is whether it's possible to implement torch.distributed in terms of torch.distributed2.

@youkaichao (Author)

One aspect to make this feasible is whether it's possible to implement torch.distributed in terms of torch.distributed2.

Do you mean we'd have a stateless version of process groups in torch.distributed2, and re-implement the global group in torch.distributed on top of it? That could be a great idea!

@wconstab

Thanks for posting this RFC!

I want to see if we can make changes to the existing torch.distributed APIs first, to solve some or all of your problems. Then, if needed, we can consider a new set of APIs (e.g. torch.distributed2).

As for the src/dst arguments to send/recv, that is something that has been bugging us for a while, and I suppose we could fix it in the existing APIs without worrying about BC by simply adding new kwargs, group_src and group_dst, which would be mutually exclusive with src and dst - i.e. you can pass one or the other but not both.
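
To make this concrete, a sketch of the proposed calling convention (hypothetical at the time of this comment, not an existing API; `subgroup` is assumed to be an already-created group):

```
import torch
import torch.distributed as dist

tensor = torch.ones(4)

# Today: dst must be a *global* rank, even when a group is passed in.
dist.send(tensor, dst=dist.get_global_rank(subgroup, 1), group=subgroup)

# Proposed: group_dst would be interpreted relative to `group`.
dist.send(tensor, group=subgroup, group_dst=1)

# Passing both would raise, since they are mutually exclusive:
# dist.send(tensor, dst=3, group_dst=1, group=subgroup)  # error
```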

For the global group, I think this might be harder to solve but I'd like to get a document started with the different possibilities and pros/cons. cc @kwen2501

@d4l3k
Copy link
Member

d4l3k commented Nov 12, 2024

I think a lot of what's being asked for here can be done with just a new entrypoint (rather than init_process_group), which would avoid having to create a new package.

That's largely what I'm doing in the torchft ProcessGroups -- just initializing the underlying PG without setting the global state. It is definitely a bit clunky (since it operates on the store API), but it generally works just fine to instantiate a PG without calling init_process_group. https://github.com/pytorch-labs/torchft/blob/main/torchft/process_group.py

i.e. in current PyTorch you can do

```
from torch.distributed import PrefixStore, ProcessGroupNCCL, TCPStore

# Connect as a client to an existing store (the master was created elsewhere).
store = TCPStore(
    host_name=host,
    port=int(port),
    is_master=False,
    wait_for_workers=False,
)
# Namespace the keys so this PG doesn't collide with other users of the store.
store = PrefixStore("my_custom_pg", store)

pg = ProcessGroupNCCL(store, rank=10, world_size=32)

pg.rank(), pg.size()
```

This can be used in a completely object-oriented way without relying on any "internal" APIs.
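
For completeness, one process in the job still has to own the store; a minimal sketch of that side (the host and port here are placeholders):

```
# On exactly one process (e.g. the rank-0 host), create the master store:
master_store = TCPStore(
    host_name="0.0.0.0",
    port=29500,
    is_master=True,
    wait_for_workers=False,
)
```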

@wconstab commented Nov 12, 2024

@youkaichao would you be happy to use the workflow @d4l3k proposed? Or is there still something missing?

@d4l3k is the PrefixStore needed so that each store can use a default UUID (does each store use UUID 0 or something)? I wonder if we should still provide a bit of a helper here: (a) we could allow reusing the globally initialized TCPStore if it exists (or accept one as an optional kwarg as an alternative), and (b) we could deal with the UUID automatically, ensuring that each PG still has a unique UUID.

wconstab added a commit to pytorch/pytorch that referenced this pull request Nov 12, 2024
Addressing RFC 0042 (pytorch/rfcs#71)

It's annoying that 'dst' (for send) must be a global rank even when a
group is passed in. But we can't easily change 'dst' without breaking
existing cases.

Pull Request resolved: #140460
@youkaichao (Author)

@d4l3k that's a great idea, and I actually tried it before. However, the problem is that you cannot use send/recv: some functions like torch.distributed.all_reduce can work with these standalone groups, but torch.distributed.send/recv do not.

@youkaichao (Author)

I'm also exploring the idea of using the TCP store to directly implement a new set of send/recv/broadcast operations, in https://github.com/vllm-project/vllm/blob/377b74fe877c7eb4632c2ca0778b9da9a5db8ae6/vllm/distributed/utils.py#L127 . It works locally, but it sometimes hangs during initialization in CI.
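
In that spirit, a minimal sketch of store-based object send/recv (the key scheme, class name, and counters here are illustrative, not the actual vLLM implementation):

```
import pickle

class StoreChannel:
    """Point-to-point object passing over a c10d Store (illustrative)."""

    def __init__(self, store, rank):
        self.store = store
        self.rank = rank
        self.send_counters = {}  # per-destination message counters
        self.recv_counters = {}  # per-source message counters

    def send_obj(self, obj, dst):
        n = self.send_counters.get(dst, 0)
        self.send_counters[dst] = n + 1
        # The receiver's store.get() blocks until this key appears.
        self.store.set(f"send_{self.rank}_to_{dst}_{n}", pickle.dumps(obj))

    def recv_obj(self, src):
        n = self.recv_counters.get(src, 0)
        self.recv_counters[src] = n + 1
        # Blocks until the sender posts the key, or the store timeout fires.
        return pickle.loads(self.store.get(f"send_{src}_to_{self.rank}_{n}"))
```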

@wconstab

@youkaichao is the only reason that send/recv do not work the dst/src mapping issue? I started to prototype a possible fix for that today; I'll share it here shortly.

Send/recv via TCPStore feels like it would require polling and become unscalable at large numbers of ranks, but for certain use cases it could work. We have also been thinking about better support for control-plane communication. cc @c-p-i-o

@youkaichao (Author)

is the only reason that send/recv do not work because of the dst/src mapping issue?

For send/recv, yes, kind of. There are other more complicated cases, though. For example, broadcast:

https://github.com/pytorch/pytorch/blob/659d2132be469a86ea34dcb7f79224c34ebb1685/torch/distributed/distributed_c10d.py#L2580

and broadcast_object_list:

https://github.com/pytorch/pytorch/blob/659d2132be469a86ea34dcb7f79224c34ebb1685/torch/distributed/distributed_c10d.py#L3239C5-L3239C26

They are quite difficult to use if I have a standalone group that is not part of the global group.
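
For reference, the object-oriented overloads sidestep that translation, since the root rank there is group-local. A sketch, assuming the single-tensor broadcast overload and a standalone `pg` as in the earlier example:

```
import torch

t = torch.ones(4, device="cuda")
# Second argument is the root rank *within* pg; no global group is involved.
pg.broadcast(t, 0).wait()
```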

@youkaichao (Author) commented Nov 13, 2024

Send/recv via tcpstore feels like it would require polling and become unscalable at large numbers of ranks.

For the TCP store (and any "store"), shouldn't polling be handled internally by default? I don't see any explicit polling in the example code at https://pytorch.org/docs/stable/distributed.html#torch.distributed.TCPStore .

We have also been thinking about better support for control-plane communication

That would be great.

@wconstab

For send/recv, yes, kind of. There are other more complicated cases, though. For example, broadcast:

OK, these look like the same thing to me. Basically, if we added support for 'group_src' and 'group_dst' to all of our APIs wherever there is currently a 'src' or 'dst', it would fix the issue. That's what it looks like to me, at least.

For tcp store (and any "store"), it should have polling by default? I don't see any polling in the example code https://pytorch.org/docs/stable/distributed.html#torch.distributed.TCPStore .

Well, I'm not sure what you mean about polling by default. But if I were to build send/recv on top of TCPStore, I think my two choices would be (1) naive: make the recv op 'synchronous' on the CPU and rely on the TCP timeout, or (2) implement a new polling thread on the recv side that keeps checking whether send data has been posted. I was referring to path (2). I'm not sure if (1) is actually practical for performance reasons, but we could check.
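
Option (1) maps fairly directly onto the blocking Store API; a minimal sketch (the key name is illustrative, continuing the scheme above):

```
from datetime import timedelta

# store.get() blocks until the key appears or the store timeout elapses,
# so a 'synchronous' recv needs no explicit polling loop on the caller side.
store.set_timeout(timedelta(seconds=30))
payload = store.get("send_0_to_1_0")
```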

@youkaichao (Author)

(1) naive, make the recv op 'synchronous' on the CPU, and rely on the TCP timeout

For my use case, (1) is enough.

if we added support to all our APIs for 'group_src' and 'group_dst' wherever there is currently a 'src' and 'dst', it would fix the issue

We also need to take care of collective ops like allreduce and allgather. The goal is for subgroups to work on their own without any dependency on the global group.
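
For what it's worth, the OO API's collectives are already group-local; a sketch, assuming the single-tensor allreduce overload and `pg` as before:

```
import torch
from torch.distributed import ReduceOp

t = torch.ones(4, device="cuda")
# Reduces across the ranks of pg only; no global group is involved.
pg.allreduce(t, ReduceOp.SUM).wait()
```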

wconstab added a commit to pytorch/pytorch that referenced this pull request Nov 14, 2024

Addressing RFC 0042 (pytorch/rfcs#71)

It's annoying that 'dst' (for send) must be a global rank even when a
group is passed in. But we can't easily change 'dst' without breaking
existing cases.

Furthermore, requiring use of 'global' dst breaks the less common usage
pattern of creating a new ProcessGroup object that is not connected to
the 'default group' and thus has no logical 'global' ranks.

Pull Request resolved: #140460
@d4l3k (Member) commented Nov 14, 2024

@youkaichao we don't document the ProcessGroup object APIs (I'm not sure why not; we really should), but if you use them directly, they should work as expected for send/recv/broadcast, since the ranks are PG-local rather than global.

https://github.com/pytorch/pytorch/blob/f98c601efe9b426bf85d48d4949cddd01b744e55/torch/csrc/distributed/c10d/init.cpp#L2122-L2128

i.e.

```
pg = ProcessGroupNCCL(store, rank=10, world_size=32)
# OO API: takes a list of tensors, a rank *within* pg, and an integer tag.
pg.send([...], local_rank, 0).wait()
```

vs

```
import torch.distributed as dist

# Functional API: dst must be a *global* rank, even when group=pg is passed.
dist.send(tensor, global_rank, group=pg)
```
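
The recv side of the OO API is symmetric in this sketch: pg.recv([...], src_local_rank, 0).wait(), where src_local_rank (an illustrative name) is again the sender's rank within pg.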

cyyever pushed a commit to cyyever/pytorch that referenced this pull request Nov 20, 2024
Doc updates:

* This adds documentation for the object oriented ProcessGroup APIs that are being used in torchft as well as pytorch/rfcs#71 .
* It also does some general cleanups to simplify the distributed.rst by using `:methods`.
* It adds `__init__` definitions for the Stores
* I've reordered things so the collective APIs are before the Store/PG apis

Test plan:

```
lintrunner -a
cd docs && sphinx-autobuild source build/ -j auto -WT --keep-going
```

Pull Request resolved: pytorch#140853
Approved by: https://github.com/kwen2501
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Nov 21, 2024
Changes the semantics of __repr__ of P2POp: s and d are now group ranks instead
of global ranks. I think this is OK since I also updated the field names
to make this obvious.

Also add mypy annotations

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in #140460

Pull Request resolved: #141054
Approved by: https://github.com/kwen2501
youssef62 pushed a commit to youssef62/pytorch that referenced this pull request Nov 23, 2024

Also add missing mypy typing and a few asserts to make mypy happy

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in pytorch#140460

Note: the object collective version canonicalizes to global instead of group
rank, simply because this left more of the original code intact and
required fewer conversions overall.

Pull Request resolved: pytorch#140827
Approved by: https://github.com/kwen2501
youssef62 pushed a commit to youssef62/pytorch that referenced this pull request Nov 23, 2024
Also add mypy annotations

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in pytorch#140460
Pull Request resolved: pytorch#140843
Approved by: https://github.com/kwen2501
youssef62 pushed a commit to youssef62/pytorch that referenced this pull request Nov 23, 2024

Also add mypy annotations

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in pytorch#140460

Pull Request resolved: pytorch#140847
Approved by: https://github.com/H-Huang
ghstack dependencies: pytorch#140843