use send/recv to implement all2all
Chao1Han committed Dec 30, 2024
1 parent 6f732b1 commit 3d94dd5
Showing 1 changed file with 64 additions and 69 deletions.
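This commit replaces the single ccl::alltoallv collective in the alltoall_base and alltoall paths with batched point-to-point ccl::send/ccl::recv calls issued between ccl::group_start() and ccl::group_end(). The sketch below is a minimal illustration of that pattern, not the committed code: the function name all_to_all_via_send_recv and its parameters are invented for the example, the call shapes simply mirror those visible in the diff, and communicator/stream setup, datatype mapping, and error handling are left to the caller. Grouping the point-to-point calls lets the backend post every send and recv before making progress, which is what keeps the symmetric per-rank ordering from deadlocking.

// Minimal sketch (assumption: a oneCCL-style API matching the calls in this diff).
// One buffer and element count per peer rank; the caller owns communicator,
// stream, and buffer lifetimes.
#include <oneapi/ccl.hpp>
#include <cstddef>
#include <vector>

void all_to_all_via_send_recv(
    const std::vector<void*>& send_bufs,    // send_bufs[r] goes to rank r
    const std::vector<void*>& recv_bufs,    // recv_bufs[r] comes from rank r
    const std::vector<size_t>& send_counts, // element counts per peer
    const std::vector<size_t>& recv_counts,
    ccl::datatype dtype,
    ccl::communicator& comm,
    ccl::stream& stream) {
  const int world_size = comm.size();
  // Batch every point-to-point call so the backend can schedule them together,
  // as the commit does with ccl::group_start()/ccl::group_end().
  ccl::group_start();
  for (int r = 0; r < world_size; ++r) {
    if (send_counts[r] != 0) {
      ccl::send(send_bufs[r], send_counts[r], dtype, r, comm, stream);
    }
    if (recv_counts[r] != 0) {
      ccl::recv(recv_bufs[r], recv_counts[r], dtype, r, comm, stream);
    }
  }
  ccl::group_end();
}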
133 changes: 64 additions & 69 deletions src/xccl/ProcessGroupXCCL.cpp
@@ -1669,33 +1669,46 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
at::Tensor& output,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
std::vector<size_t> sendCounts(size_);
std::vector<size_t> recvCounts(size_);
bool inputSplitsEqual = inputSplitSizes.size() == 0;
bool outputSplitsEqual = outputSplitSizes.size() == 0;

size_t inLen = input.numel();
size_t outLen = output.numel();
if (inLen)
inLen /= (inputSplitsEqual ? size_ : input.size(0));
if (outLen)
outLen /= (outputSplitsEqual ? size_ : output.size(0));

for (int i = 0; i < size_; i++) {
sendCounts[i] =
(inputSplitsEqual ? inLen : inputSplitSizes[i] * inLen);
recvCounts[i] =
(outputSplitsEqual ? outLen : outputSplitSizes[i] * outLen);
}
std::vector<size_t> send_lengths(size_);
std::vector<size_t> recv_lengths(size_);
std::vector<size_t> send_offsets(size_);
std::vector<size_t> recv_offsets(size_);

c10d::computeLengthsAndOffsets(
inputSplitSizes, input, &send_lengths, &send_offsets);
c10d::computeLengthsAndOffsets(
outputSplitSizes, output, &recv_lengths, &recv_offsets);

c10::xpu::XPUCachingAllocator::recordStream(
output.storage().data_ptr(), stream);

auto xcclDataType = getXcclDataType(output.scalar_type());
ccl::alltoallv(
input.data_ptr(),
sendCounts,
output.data_ptr(),
recvCounts,
xcclDataType,
comm,
ccl::create_stream(stream.queue()));
size_t size = input.element_size();

ccl::group_start();
for (const auto r : c10::irange(size_)) {
if (send_lengths[r] != 0) {
ccl::send(
static_cast<char*>(input.data_ptr()) + send_offsets[r] * size,
send_lengths[r],
xcclDataType,
r,
comm,
ccl::create_stream(stream.queue()));
}
if (recv_lengths[r] != 0) {
ccl::recv(
static_cast<char*>(output.data_ptr()) + recv_offsets[r] * size,
recv_lengths[r],
xcclDataType,
r,
comm,
ccl::create_stream(stream.queue()));
}
}
ccl::group_end();

return;
},
OpType::ALLTOALL_BASE,
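In the alltoall_base hunk above, the per-rank element counts and offsets now come from c10d::computeLengthsAndOffsets instead of the hand-rolled sendCounts/recvCounts logic that was removed. As a reading aid, here is a stand-alone illustration of that computation, reconstructed from the equal-split fallback visible in the removed code; it is an assumption about what the helper produces, not c10d's actual implementation, and the names numel, dim0, and world_size are invented for the example. The cumulative offsets are what the new loop scales by the element size (send_offsets[r] * size) to index into the flat buffer.

// Illustrative stand-in (assumption) for the lengths/offsets computation.
// An empty split-size list means the tensor is divided evenly across ranks;
// otherwise each split size counts rows along dim 0.
#include <cstddef>
#include <cstdint>
#include <vector>

void compute_lengths_and_offsets(
    const std::vector<int64_t>& split_sizes, // empty => equal split
    size_t numel,      // total element count of the flat tensor
    size_t dim0,       // size of the tensor's first dimension
    int world_size,    // assumed > 0
    std::vector<size_t>& lengths,
    std::vector<size_t>& offsets) {
  lengths.assign(world_size, 0);
  offsets.assign(world_size, 0);
  const size_t row_size = (dim0 != 0) ? numel / dim0 : 0; // elements per row
  size_t offset = 0;
  for (int r = 0; r < world_size; ++r) {
    lengths[r] = split_sizes.empty()
        ? numel / world_size                              // equal split
        : static_cast<size_t>(split_sizes[r]) * row_size; // rows -> elements
    offsets[r] = offset;
    offset += lengths[r];
  }
}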
@@ -1743,52 +1756,34 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall(
at::Tensor& /* unused */,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
c10::OptionalStreamGuard stream_guard(stream.unwrap());
at::Tensor flatInput;
at::Tensor flatOutput;

std::vector<size_t> sendCounts(size_);
std::vector<size_t> recvCounts(size_);

int64_t flatSendCount;
int64_t flatRecvCount;

bool isInputFlat = computeLengthsAndCheckAndGetFlat(
inputTensors, sendCounts, flatInput, flatSendCount);
bool isOutputFlat = computeLengthsAndCheckAndGetFlat(
outputTensors, recvCounts, flatOutput, flatRecvCount);
if (!isInputFlat) {
auto flatInputSplits = flatInput.split_with_sizes(
c10::IntArrayRef((int64_t*)sendCounts.data(), sendCounts.size()),
0);

for (int i = 0; i < size_; i++) {
flatInputSplits[i].copy_(inputTensors[i].view({-1}));
}
}

auto xcclDataType = getXcclDataType(flatOutput.scalar_type());
ccl::event ret_evt;
ret_evt = ccl::alltoallv(
flatInput.data_ptr(),
sendCounts,
flatOutput.data_ptr(),
recvCounts,
xcclDataType,
comm,
ccl::create_stream(stream.queue()));

if (!isOutputFlat) {
ret_evt.wait();
auto flatOutputSplits = flatOutput.split_with_sizes(
c10::IntArrayRef((int64_t*)recvCounts.data(), recvCounts.size()),
0);
auto xcclDataType = getXcclDataType(input.scalar_type());

for (int i = 0; i < size_; i++) {
outputTensors[i].view({-1}).copy_(flatOutputSplits[i]);
ccl::group_start();
for (const int r :
c10::irange(static_cast<int>(outputTensors.size()))) {
at::Tensor& input = inputTensors[r];
at::Tensor& output = outputTensors[r];
if (input.numel() != 0) {
ccl::send(
input.data_ptr(),
input.numel(),
xcclDataType,
r,
comm,
ccl::create_stream(stream.queue()));
}
if (output.numel() != 0) {
ccl::recv(
output.data_ptr(),
output.numel(),
xcclDataType,
r,
comm,
ccl::create_stream(stream.queue()));
}
}
stream.synchronize();
ccl::group_end();

return;
},
OpType::ALLTOALL,
