Allgather with DID loop split #3284

Merged: 23 commits from wjy/comm into main on Dec 9, 2024
Conversation

wujingyue (Collaborator) commented Oct 25, 2024

Another baby step towards #2563

@wujingyue force-pushed the wjy/comm branch 2 times, most recently from 3c5c68e to 5ef8ff4 on November 18, 2024
@wujingyue changed the base branch from main to wjy/forward on November 27, 2024
@wujingyue changed the title from "Fix communication lowering to support DID loop parallelization" to "Allgather with DID loop split" on November 27, 2024
@wujingyue marked this pull request as ready for review on November 28, 2024
wujingyue (Collaborator, Author) commented:

!test

Base automatically changed from wjy/forward to main November 30, 2024 05:28
wujingyue (Collaborator, Author) commented:

!test

wujingyue (Collaborator, Author) commented:

cc @xwang233 the testing infra appears to be problematic for H100: https://gitlab-master.nvidia.com/dl/pytorch/fuser-gh-mirror/-/jobs/125617985

samnordmann (Collaborator) left a comment:

Thank you for this PR! I am still trying to fully understand how the logic works, but let me post a series of minor comments in the meantime.


  const auto iter = std::find(
      tv->getLogicalDomain().begin(), tv->getLogicalDomain().end(), inputs[0]);
  NVF_ERROR(
Collaborator commented:

I am not sure I understand why this check is needed. Isn't it true that, by assumption, what is returned by getInputsTo is an element of tv->getLogicalDomain()?

Collaborator commented:

Also, I am not sure what is meant by "dominate" in the error message.

wujingyue (Collaborator, Author) commented:

Re "dominate": see https://en.wikipedia.org/wiki/Dominator_(graph_theory); I extended the concept to a set of nodes dominating another set.

Re the check: I heard from @naoyam that logical won't always dominate allocation with "the new indexing system".

@@ -196,7 +196,7 @@ void lowerToReduceScatter(
     std::vector<Communication*>& comms) {
   const DeviceMesh& mesh = input_tv->getDeviceMesh();
   auto reduction_axis = output_tv->getReductionAxis().value();
-  auto scattered_axis = getShardedAxis(output_tv);
+  auto scattered_axis = getShardedAxis(output_tv, ParallelType::DIDx);
Collaborator commented:

OK, however, if the sharded dimension is split, then scattered_axis is not valid here, right?

wujingyue (Collaborator, Author) commented:

I can't think of an immediate problem, and #3504 apparently works fine. It could be incidental, though, so I'm happy to hear what you think is problematic.

Collaborator commented:

Correct me if I'm wrong, but I think this is an example where we see the problem:

  d = num_devices;

  tv0: [i{i0}, i{i1}]     // i0 has extent d
  tv1 = sum(tv0, axis=0); // tv1: [r{i0}, i{i1}]

  tv0->axis(0)->parallelize(ParallelType::DIDx);
  tv1->split(1, d);       // [r{i0}, i{i1/d}, i{d}]
  tv1->axis(2)->parallelize(ParallelType::DIDx);

In this case, the scattered axis is 2 but getShardedAxis returns 1.

wujingyue (Collaborator, Author) commented on Dec 6, 2024:

In your case,

tv0:
  logical: [iDID{i0}, i{i1}]
tv1:
  logical: [r{i0}, i{i1}]
  allocation: [r{i0}, i{i1/d}, iDID{d}]

getShardedLogicalAxis will return 0, the tensor axis being sharded. This is correct because the output at::Tensor for tv1 will be of shape [i1/d] and indeed axis 0 is the sharded dimension. Then, scattered_axis=0 will be used to compute which input tensor axis will be sharded (which will be 1). Finally, that input scattered axis (1) will be used to split the input tensor of shape [1, i1].

Caveat: with 7cf2384, DID'ing an inner split is explicitly disallowed, so the above case will actually throw an exception. But what I said should be correct after we lift that limitation.
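
To make the shapes concrete, here is a rough sketch using plain ATen calls; the function, variable names, and the use of chunk are illustrative assumptions, not the actual lowering or test code:

  #include <ATen/ATen.h>

  // Sketch only: d devices; tv0 has logical shape [d, i1] and is sharded on axis 0,
  // so its per-device shard is [1, i1]. tv1 = sum(tv0, 0) with i1 inner-split by d
  // and the inner axis DID'ed, so its per-device output is [i1 / d].
  void shapeSketch(int64_t d, int64_t i1, int64_t my_rank) {
    at::Tensor in_local = at::randn({1, i1});
    // scattered_axis = 0 on the output maps to axis 1 of the input,
    // which is then chunked across the d devices.
    at::Tensor in_chunk = in_local.chunk(d, /*dim=*/1)[my_rank]; // [1, i1 / d]
    at::Tensor out_local = at::empty({i1 / d});                  // [i1 / d]
    (void)in_chunk;
    (void)out_local;
  }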

  at::Tensor unsharded_tensor =
      at::randn({num_devices * kTensorSize}, at::kFloat);
  at::Tensor in_tensor =
      shardTensor(unsharded_tensor, in).to(communicator_->device());

samnordmann (Collaborator) commented on Dec 3, 2024:

Suggested change:

  std::vector<int64_t> ref_in_tensor_shape = {kTensorSize};
  EXPECT_EQ(in_tensor.sizes(), ref_in_tensor_shape);

samnordmann (Collaborator) commented on Dec 3, 2024:

I don't understand how shardTensor can be correct here if it never replays the split backwards... But I might be missing something.

wujingyue (Collaborator, Author) commented:

Thanks for the review! I think there are two problems with the PR as is:

  1. shardTensor may slice the wrong numbers. For example, if an inner split is DID'ed, the slicing needs to be strided per the outer split (see the sketch after this list).
  2. nvFuser doesn't error out when an allgather is not along the outermost allocated dimension. This used to be guaranteed by ReorderShardedAxisPass via its isInnerResharding check. However, getShardingChanges, one of its dependencies, hasn't been updated to read loop/allocation:

     auto rootmap = PairwiseLogicalDomainMap(input, output).mapBroadcast(false);
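
A minimal sketch of problem 1 using plain ATen; the function and names are assumptions for illustration, not nvFuser's shardTensor:

  #include <ATen/ATen.h>

  // Sketch only: a logical axis of extent n is inner-split into [n/d, iDID{d}].
  // Device r then owns every d-th element starting at r, so a contiguous chunk
  // (which would be right for an outer split) is the wrong slice.
  at::Tensor shardInnerSplit(const at::Tensor& unsharded, int64_t d, int64_t r) {
    // Strided slicing: start at r, step by d, along dim 0.
    return unsharded.slice(/*dim=*/0, /*start=*/r, /*end=*/c10::nullopt, /*step=*/d);
  }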

wujingyue (Collaborator, Author) commented:

Re the suggested change: I manually checked the shape is as expected. I added some extra unit tests for shardTensor alone, so we don't have to verify it here.

wujingyue (Collaborator, Author) commented:

I made a couple of changes to address the problems I mentioned in #3284 (comment).

  1. 7cf2384. It's overkill but will probably be OK for quite some time. I had a hard time finding a concrete use case that has to mix DID and host ID within one logical dimension. I agree that to properly support inner splits we'll need to "replay the split backwards". It's not a trivial change, so I'll postpone it to a separate PR.
  2. I wrote #3531 (Harden assertBuffersHaveSameSize to check shapes) to harden runtime checks for allgather, and added one more allgather test (Allgather_LoopSplit_Noncontiguous) to this PR; see the rough sketch after this list for the kind of sharding it exercises. These extra checks will fire when we hit the most common limitations, until ReorderShardedAxisPass is properly fixed, which will take several decent-size PRs.
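
For context, a rough sketch of the kind of DID loop split an allgather test sets up; the helper names and exact calls are assumptions based on nvFuser's multidevice API and test utilities (fusion and d are assumed to be in scope), not the actual test:

  // Sketch only: shard a 1-D input across d devices via a loop split,
  // then copy it to an unsharded output, which lowers to an allgather.
  TensorView* in = makeContigTensor(1);
  TensorView* out = set(in);
  fusion->addInput(in);
  fusion->addOutput(out);

  const DeviceMesh mesh = DeviceMesh::createForNumDevices(d);
  in->setDeviceMesh(mesh);
  out->setDeviceMesh(mesh);

  // Outer-split the sole logical axis by d and parallelize the outer split on DIDx.
  in->split(0, d, /*inner_split=*/false);
  in->axis(0)->parallelize(ParallelType::DIDx);
  in->setAllocationDomain(in->getLoopDomain(), true);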

wujingyue (Collaborator, Author) commented:

I had a hard time finding a concrete use case that has to mix DID and host ID within one logical dimension.

In fact, there is one:

  // A has shape (S, sharded(D), M/(S*D), K)

So I'll try to file a feature request after this PR.

  at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0];
  assertIsCompiledToHostIrContainer(fec);

  EXPECT_TRUE(at::equal(out_tensor.cpu(), unsharded_tensor));
Collaborator commented:

Why not use validate here?

Collaborator commented:

I noticed allgather's lowering was not changed... I'm a bit surprised it didn't need any modifications for inputs with DID loop split! I might have missed a few earlier PRs, though.

Collaborator commented:

Why not use validate here?

Since validate allows for (small) differences, if two tensors are supposed to be exactly the same, the simpler check, i.e. at::equal, is preferable.

wujingyue (Collaborator, Author) commented:

I'm a bit surprised it didn't need any modifications for inputs with DID loop split!

Whether we call lowerToAllgather depends on the I/O meshes and whether the I/O is sharded:

  lowerToAllgather(input_tv, output_tv, comms);

isSharded has been reading the allocation domain since #3444.

That being said, I think this PR as is is a bit too permissive and may lower a set to an allgather without properly checking its allocation domain. For example,

  auto rootmap = PairwiseLogicalDomainMap(input, output).mapBroadcast(false);

reads root and logical and needs to be updated. I'll try to fix that.

wujingyue (Collaborator, Author) commented:

That being said, I think this PR as is is a bit too permissive and may lower a set to an allgather without properly checking its allocation domain.

I tried to address this in #3284 (comment).

samnordmann (Collaborator) commented:

Thanks for the review! I think there are two problems with the PR as is:

  1. shardTensor may slice the wrong numbers. For example, if an inner split is DID'ed, the slicing needs to be strided per the outer split.
  2. nvFuser doesn't error out when an allgather is not along the outermost allocated dimension. This used to be guaranteed by ReorderShardedAxisPass via its isInnerResharding check. However, getShardingChanges, one of its dependencies, hasn't been updated to read loop/allocation:

     auto rootmap = PairwiseLogicalDomainMap(input, output).mapBroadcast(false);

I see, thanks for the clarifications. Is it still ready for review or would you like to fix the points you mentioned above?

On my side, I understand why the PR works as is, but I am a little bit concerned because, in my opinion, it relies on a weak contract. To be more precise, the tests only work here because, when splitting the innermost loop axis and sharding one split, the allgather still operates on the innermost dimension (index 0), which has the same index as the logical axis used to produce that loop axis. The fact that this works feels a bit incidental and not robust; in particular, it would likely fail for variants like merging axes, transposition (even of non-sharded axes), multidimensional parallelization (e.g. DIDy), etc. Wdyt?

wujingyue (Collaborator, Author) commented:

Is it still ready for review or would you like to fix the points you mentioned above?

I'll try to fix these points. I'd like nvFuser to at least error out on cases it doesn't support. This way, when we trigger a known limitation in the future, it'll show up as an exception instead of silently generating wrong numbers.

@wujingyue wujingyue requested a review from naoyam December 5, 2024 17:39
wujingyue added a commit that referenced this pull request Dec 6, 2024
I wrote this to make the allgather-related issue discovered in #3284 (comment) easier to expose. And it seems like a good extra runtime check to have, because `_allgather_base` treats I/O tensors as flat buffers and ignores the shapes.
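
A minimal sketch of the kind of buffer-size check that commit describes; the function name and messages are assumptions, not nvFuser's actual assertBuffersHaveSameSize:

  #include <ATen/ATen.h>
  #include <vector>

  // Sketch only: because _allgather_base treats buffers as flat memory, an explicit
  // size check surfaces mismatches with a clearer error before the collective runs.
  void checkBuffersHaveSameNumel(
      const std::vector<at::Tensor>& ins,
      const std::vector<at::Tensor>& outs) {
    for (const at::Tensor& t : ins) {
      TORCH_CHECK(
          t.numel() == ins.front().numel(),
          "All input buffers are expected to have the same number of elements.");
    }
    for (const at::Tensor& t : outs) {
      TORCH_CHECK(
          t.numel() == outs.front().numel(),
          "All output buffers are expected to have the same number of elements.");
    }
  }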
naoyam (Collaborator) left a comment:

LGTM

wujingyue (Collaborator, Author) commented:

!test

@wujingyue wujingyue merged commit 4a897a4 into main Dec 9, 2024
34 of 35 checks passed
@wujingyue wujingyue deleted the wjy/comm branch December 9, 2024 23:30