ReduceScatter with DID loop split #3504

Merged: 6 commits merged into main from wjy/rs on Dec 11, 2024

Conversation

@wujingyue (Collaborator, author) commented on Dec 2, 2024:

For #2563

Tested: http://nv/eoZ



The review thread below is attached to this code in the diff:

@pytest.mark.mpi
def test_allreduce(mpi_test):

@wujingyue (author) commented:

This allreduce test is merely for DID logical split. I don't think allreduce can support DID loop split because sum's reduction axes can only be logical. But I'd be happy to know otherwise.
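
To spell out the vocabulary (a rough illustration; the tensor names, d, and the exact calls below are only for illustration and not taken from this PR): with a logical split, the device dimension is part of the tensor's logical domain, so sum can reduce exactly that axis; with a loop split, the logical domain is untouched and only the loop domain is split for scheduling.

// Logical split: assuming in is a TensorView with logical shape [D*2, 3] and
// d holds the number of devices.
TensorView* in_3d = reshape(in, {d * 2, 3}, {d, 2, 3}); // logical [D, 2, 3]
TensorView* out = sum(in_3d, {0});                      // logical [2, 3]
in_3d->axis(0)->parallelize(ParallelType::DIDx);        // shard the device axis
                                                        // (rest of the schedule omitted)

// Loop split: the logical domain stays [D*2, 3]; only the loop domain becomes
// [D, 2, 3]. sum(in, {0}) would still reduce the whole logical axis, so it
// cannot produce a [2, 3] output by itself.
in->split(0, d, /*inner_split=*/false);
in->axis(0)->parallelize(ParallelType::DIDx);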

@naoyam (reviewer) replied:

> This allreduce test is merely for DID logical split.

Just to be clear, you meant DID parallelization of logical domains, right? I'm not sure what you meant by DID logical split otherwise.

@naoyam:

Assuming I understand what you meant correctly, I think this is where TensorView::rFactor could be used. That's what we use for intra-device hierarchical reductions. For example, I'd think that for multi-GPU reductions, we would have something like:

(I'm mixing the C++ and Python APIs)

self.out->split(0, num_devices, /*inner=*/false);
auto intermediate_result = self.out->rFactor({1});
intermediate_result->axis(0)->parallelize(DIDx);
self.out->axis(0)->parallelize(DIDx);

Here, intermediate_result would be the partial result of per-device reduction, which would be then reduced between all the devices and saved to self.out.
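
For reference, here is the same sketch written against just the C++ API (assuming out is the reduction TensorView and num_devices is known at scheduling time; the exact argument names may differ):

out->split(0, num_devices, /*inner_split=*/false);      // reduction axes become [D, N/D, ...]
TensorView* intermediate = out->rFactor({1});           // per-device partial reduction over the inner axis
intermediate->axis(0)->parallelize(ParallelType::DIDx); // each device computes its own partial result
out->axis(0)->parallelize(ParallelType::DIDx);          // the remaining reduction crosses devices (the allreduce)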

@wujingyue:

I agree it's something like rFactor, and I did look into how TensorView::rFactor works in an existing use (TensorView* tv2 = tv1_copy->rFactor({0});). However, I failed to see how it applies here.

If we want to loop (but not logical) split an allreduce, the input would have a logical shape like [D*2, 3] and the output would have a logical shape like [2, 3]. Regardless of scheduling, what ops in fusion IR could do that? (Not sum, because that reduces an entire dimension to 1.)
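
To make the shape mismatch concrete (placeholder names; assuming in has logical shape [D*2, 3]):

TensorView* s = sum(in, {0}); // logical shape [3] (or [1, 3] with keep_dim) -- not [2, 3]

Getting a [2, 3] output requires the device factor to appear in the logical domain first, e.g. by reshaping [D*2, 3] to [D, 2, 3] and summing over axis 0, which is exactly the logical split.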

@naoyam:

Let's talk offline. It seems we are not using the same vocabulary (e.g., I don't understand what "loop split" and "logical split" mean).

@wujingyue:

#3543 is my failed attempt. It triggered the assertion NVF_THROW("Unexpected producer RF ID: ", producer_rf_id->toString()). Is it because the code there has been assuming that reductions are innermost?

@wujingyue:

Anyhow, this isn't a blocker. As we discussed yesterday, we'll probably stick with logical split for reductions in Allreduce and ReduceScatter due to MatmulOp's implementation.

@wujingyue

!test

@naoyam left a comment:

LGTM

Base automatically changed from wjy/split to main December 10, 2024 15:32
@wujingyue

!test

@wujingyue

!test

@wujingyue

!test

@wujingyue

!build

@wujingyue merged commit d178c2a into main on Dec 11, 2024 (12 of 13 checks passed).
@wujingyue deleted the wjy/rs branch on December 11, 2024 at 02:26.