Allgather with DID loop split #3284
Conversation
3c5c68e to 5ef8ff4 (compare)
!test
!test
cc @xwang233 the testing infra appears to be problematic for H100: https://gitlab-master.nvidia.com/dl/pytorch/fuser-gh-mirror/-/jobs/125617985
Thank you for this PR! I am still trying to fully understand how the logic works, but let me post a series of minor comments in the meantime.
csrc/multidevice/utils.cpp (outdated)
const auto iter = std::find(
    tv->getLogicalDomain().begin(), tv->getLogicalDomain().end(), inputs[0]);
NVF_ERROR(
I am not sure I understand why this check is needed. Isn't it true that, by assumption, what is returned by getInputsTo is an element of tv->getLogicalDomain()?
Also, I am not sure what is meant by "dominate" in the error message.
Re dominate: see https://en.wikipedia.org/wiki/Dominator_(graph_theory); I extended the concept to a set of nodes dominating another set.
Re the check: I heard from @naoyam that logical won't always dominate allocation with "the new indexing system".
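For readers following the thread, here is a rough sketch of the kind of check being discussed. It is my own simplified illustration with assumed names (IterVisitor::getInputsTo, getMaybeAllocationDomain), not the actual nvFuser code: every IterDomain feeding the allocation domain should be found in the logical domain, i.e. logical should dominate allocation.
// Simplified sketch with assumed names; not the actual implementation.
// Each producer IterDomain of the allocation domain must appear in the
// logical domain, otherwise the logical domain does not dominate it.
std::vector<Val*> allocation_ids(
    tv->getMaybeAllocationDomain().begin(),
    tv->getMaybeAllocationDomain().end());
for (Val* input : IterVisitor::getInputsTo(allocation_ids)) {
  const auto iter = std::find(
      tv->getLogicalDomain().begin(), tv->getLogicalDomain().end(), input);
  NVF_ERROR(
      iter != tv->getLogicalDomain().end(),
      "Expected the logical domain to dominate the allocation domain, but ",
      input->toString(),
      " is not a logical IterDomain of ",
      tv->toString());
}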
@@ -196,7 +196,7 @@ void lowerToReduceScatter(
     std::vector<Communication*>& comms) {
   const DeviceMesh& mesh = input_tv->getDeviceMesh();
   auto reduction_axis = output_tv->getReductionAxis().value();
-  auto scattered_axis = getShardedAxis(output_tv);
+  auto scattered_axis = getShardedAxis(output_tv, ParallelType::DIDx);
OK, however, if the sharded dimension is split, then scattered_axis is not valid here, right?
I can't think of an immediate problem and #3504 apparently works fine. Could be incidental and I'm happy to hear what you think is problematic.
Correct me if I'm wrong, but I think this is an example where we see the problem:
d = num_devices;
tv0: [d, i1]
tv1 = sum(tv0, {0}); // tv1: [r{i0}, i1]
tv0->axis(0)->parallelize(DIDx);
tv1->split(1, d); // tv1: [r{i0}, i1/d, d]
tv1->axis(2)->parallelize(DIDx);
In this case, the scattered axis is 2, but getShardedAxis returns 1.
In your case:
tv0:
  logical: [iDID{i0}, i{i1}]
tv1:
  logical: [r{i0}, i{i1}]
  allocation: [r{i0}, i{i1/d}, iDID{d}]
getShardedLogicalAxis will return 0, the tensor axis being sharded. This is correct because the output at::Tensor for tv1 will be of shape [i1/d], and indeed axis 0 is the sharded dimension. Then, scattered_axis=0 will be used to compute which input tensor axis will be sharded (which will be 1). Finally, that input scattered axis (1) will be used to split the input tensor of shape [1, i1].
Caveat: With 7cf2384, DID'ing an inner split is disallowed by code. So the above case will actually throw an exception. But what I said should be correct after we lift that limitation.
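To make the shapes concrete, here is my own worked illustration (not from the PR), assuming d = 4 and i1 = 8:
unsharded tv0: [i0, 8]  ->  per-device input tensor: [1, 8]  (i0 is DIDx-sharded)
unsharded tv1: [8]      ->  per-device output tensor: [2]    (i1/d = 8/4, the scattered slice)
output scattered_axis = 0  ->  corresponding input axis = 1, split into d pieces of extent 2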
at::randn({num_devices * kTensorSize}, at::kFloat);
at::Tensor in_tensor =
    shardTensor(unsharded_tensor, in).to(communicator_->device());
Suggested change:
std::vector<int64_t> ref_in_tensor_shape = {kTensorSize};
EXPECT_EQ(in_tensor.sizes(), ref_in_tensor_shape);
I don't understand how shardTensor can be correct here if it never replays the split backwards... But I might be missing something.
Thanks for the review! I think there are two problems with the PR as is:
- shardTensor may slice wrong numbers. For example, if an inner split is DID'ed, the slicing needs to be strided per the outer split.
- nvFuser doesn't error out when Allgather is not along the outermost allocated dimension. This was guaranteed by ReorderShardedAxisPass by checking isInnerResharding. However, getShardingChanges, one of its dependencies, hasn't been updated to read loop/allocation:
Fuser/csrc/multidevice/utils.cpp
Line 77 in 67127c9
auto rootmap = PairwiseLogicalDomainMap(input, output).mapBroadcast(false);
Re the suggested change: I manually checked the shape is as expected. I added some extra unit tests for shardTensor alone, so we don't have to verify it here.
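For illustration, a minimal standalone shape check of the kind described could look like the following. This is a hypothetical sketch reusing names from the snippet above (shardTensor, kTensorSize, in); it is not one of the tests actually added.
// Hypothetical sketch: after sharding the DIDx axis, each rank should hold
// a slice of kTensorSize elements.
at::Tensor unsharded_tensor =
    at::randn({num_devices * kTensorSize}, at::kFloat);
at::Tensor in_tensor = shardTensor(unsharded_tensor, in);
EXPECT_EQ(in_tensor.sizes(), std::vector<int64_t>({kTensorSize}));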
I made a couple of changes to address the problems I mentioned in #3284 (comment).
- 7cf2384. It's overkill but will probably be OK for quite some time. I had a hard time finding a concrete use case that has to mix DID and host ID within one logical dimension. I agree that to properly support inner splits we'll need to "replay the split backwards". It's not a trivial change anyhow, so I'll postpone it to a separate PR.
- I wrote #3531 (Harden assertBuffersHaveSameSize to check shapes) to harden runtime checks for allgather, and added one more allgather test (Allgather_LoopSplit_Noncontiguous) to this PR. These extra checks will fire when we trigger the most common limitations, before ReorderShardedAxisPass is properly fixed, which will take several decent-size PRs.
I had a hard time finding a concrete use case that has to mix DID and host ID within one logical dimension.
In fact, there's
Fuser/tests/cpp/test_multidevice_overlap.cpp
Line 681 in 64bc560
// A has shape (S, sharded(D), M/(S*D), K)
at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0];
assertIsCompiledToHostIrContainer(fec);

EXPECT_TRUE(at::equal(out_tensor.cpu(), unsharded_tensor));
Why not use validate here?
I noticed allgather's lowering was not changed... I'm a bit surprised it didn't need any modifications for inputs with DID loop split! I might have missed a few earlier PRs, though.
Why not use validate here?

Since validate allows for (small) differences, if two tensors are supposed to be exactly the same, using the simpler check, i.e., at::equal, is preferable.
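For illustration, the difference in plain ATen (the tolerances shown are just the ATen defaults):
// Exact element-wise check: appropriate for an allgather, which only moves
// data around and introduces no floating-point error.
EXPECT_TRUE(at::equal(out_tensor.cpu(), unsharded_tensor));
// A tolerance-based check, roughly what a validate-style helper amounts to;
// it would also pass here but is looser than necessary.
EXPECT_TRUE(at::allclose(
    out_tensor.cpu(), unsharded_tensor, /*rtol=*/1e-5, /*atol=*/1e-8));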
I'm a bit surprised it didn't need any modifications for inputs with DID loop split!
Whether we call lowerToAllgather depends on the I/O meshes and whether the I/O is sharded:
lowerToAllgather(input_tv, output_tv, comms);
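Roughly, the decision looks like this; a sketch with assumed helper names (isSharded, shouldLowerToAllgather), not the actual dispatch code:
// Sketch with assumed names: a resharding Set is lowered to an Allgather when
// the producer is DIDx-sharded, the consumer is not, and both use the same
// device mesh.
bool shouldLowerToAllgather(TensorView* input_tv, TensorView* output_tv) {
  const bool same_mesh =
      input_tv->getDeviceMesh() == output_tv->getDeviceMesh();
  return same_mesh && isSharded(input_tv) && !isSharded(output_tv);
}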
That being said, I think this PR as is is a bit too permissive and may lower a set to Allgather without properly checking its allocation domain. For example,
Fuser/csrc/multidevice/utils.cpp
Line 77 in 67127c9
auto rootmap = PairwiseLogicalDomainMap(input, output).mapBroadcast(false);
That being said, I think this PR as is is a bit too permissive and may lower a set to Allgather without properly checking its allocation domain.
I tried to address this in #3284 (comment).
I see, thanks for the clarifications. Is it still ready for review, or would you like to fix the points you mentioned above? On my side, I understand why the PR works as is, but I am a little concerned because, in my opinion, it relies on a weak contract. To be more precise, the tests only work here because, when splitting the innermost loop axis and sharding one split, the allgather still operates on the innermost dimension (index 0), which has the same index as the logical axis used to produce that loop axis. The fact that this works feels a bit incidental and not robust; in particular, it should fail for variants like merging axes, transposition (even of a non-sharded axis), multidimensional parallelization (e.g. DIDy), etc. Wdyt?
I'll try to fix these points. I'd like nvFuser to at least error out on cases it doesn't support. This way, when we trigger a known limitation in the future, it'll show up as an exception instead of silently generating wrong numbers.
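One shape such a guard could take, sketched here with an assumed variable name (gathered_axis); this is not necessarily what the PR adds:
// Sketch: fail loudly when the gathered dimension is not the outermost
// allocated dimension, instead of silently producing wrong numbers.
NVF_ERROR(
    gathered_axis == 0,
    "Allgather is only supported along the outermost allocated dimension, ",
    "but the gathered axis is ",
    gathered_axis);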
I wrote this to make the allgather-related issue discovered in #3284 (comment) easier to expose. It also seems like a good extra runtime check to have, because `_allgather_base` treats I/O tensors as flat buffers and ignores their shapes.
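In spirit, the added check does something like the following; this is a sketch with an assumed helper name, not the actual assertBuffersHaveSameSize implementation:
// Sketch: _allgather_base only sees flat buffers, so checking numel() alone
// can hide shape mismatches; comparing full sizes() catches them early.
void assertSameShapes(
    const std::vector<at::Tensor>& buffers,
    at::IntArrayRef expected_sizes) {
  for (const at::Tensor& buffer : buffers) {
    NVF_ERROR(
        buffer.sizes().equals(expected_sizes),
        "Expected buffer shape ",
        expected_sizes,
        ", but got ",
        buffer.sizes());
  }
}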
LGTM
!test
Another baby step towards #2563