Skip to content

Commit

Permalink
add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Chao1Han committed Dec 30, 2024
1 parent 2a80dce commit 106adb5
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions src/xccl/ProcessGroupXCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
{at::kFloat8_e5m2fnuz, ccl::datatype::uint8},
};

bool check_same_size(const std::vector<at::Tensor>& input_tensors) {
bool checkSameSize(const std::vector<at::Tensor>& input_tensors) {
for (const auto& input_tensor : input_tensors) {
if (!input_tensors[0].is_same_size(input_tensor)) {
return false;
Expand Down Expand Up @@ -109,7 +109,8 @@ ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
// Map sum to max for bool tensors to avoid overflow issues with sum.
return ccl::reduction::max;
}
// Use SUM emu AVG due to oneCCL not support AVG
// Emulate AVG with SUM because oneCCL does not support AVG yet.
// oneCCL is expected to support AVG in the basekit 2025.2 release.
if (reduceOp == ReduceOp::AVG) {
return ccl::reduction::sum;
}
Expand Down Expand Up @@ -454,6 +455,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
comm,
ccl::create_stream(stream.queue()));
// Emulate AVG because oneCCL does not support it yet: divide the result by getSize().
// oneCCL is expected to support AVG in the basekit 2025.2 release.
if (opts.reduceOp == ReduceOp::AVG) {
auto divisor = getSize();
output.div_(divisor);
Expand Down Expand Up @@ -507,6 +509,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
comm,
ccl::create_stream(stream.queue()));
// Emulate AVG because oneCCL does not support it yet: divide the result by getSize().
// oneCCL is expected to support AVG in the basekit 2025.2 release.
if (opts.reduceOp == ReduceOp::AVG) {
auto divisor = getSize();
output.div_(divisor);
Expand Down Expand Up @@ -558,6 +561,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
comm,
ccl::create_stream(stream.queue()));
// Emulate AVG because oneCCL does not support it yet: divide the result by getSize().
// oneCCL is expected to support AVG in the basekit 2025.2 release.
if (opts.reduceOp == ReduceOp::AVG) {
auto divisor = getSize();
output.div_(divisor);
Expand Down Expand Up @@ -676,6 +680,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
comm,
ccl::create_stream(stream.queue()));
// Emulate AVG because oneCCL does not support it yet: the root rank divides the result by getSize().
// oneCCL is expected to support AVG in the basekit 2025.2 release.
if (opts.reduceOp == ReduceOp::AVG && getRank() == root) {
auto divisor = getSize();
output.div_(divisor);
Expand Down Expand Up @@ -715,7 +720,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather(
-1, // globalRankStride
this->getSize()); // worldSize

bool same_size = check_same_size(outputTensors_);
bool same_size = checkSameSize(outputTensors_);
if (same_size) {
// Flatten a vector of tensors into a single, stacked tensor.
at::Tensor outputFlattened = newLikeFlat(outputTensors_);
Expand Down Expand Up @@ -877,7 +882,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
-1, // globalRankStride
this->getSize()); // worldSize

bool same_size = check_same_size(inputTensors_);
bool same_size = checkSameSize(inputTensors_);
if (same_size) {
// Flatten a vector of tensors into a single, stacked tensor.
at::Tensor inputFlattened = newLikeFlat(inputTensors_);
Expand All @@ -901,6 +906,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
comm,
ccl::create_stream(stream.queue()));
// Emulate AVG because oneCCL does not support it yet: divide the result by getSize().
// oneCCL is expected to support AVG in the basekit 2025.2 release.
if (opts.reduceOp == ReduceOp::AVG) {
auto divisor = getSize();
output.div_(divisor);
Expand Down Expand Up @@ -989,6 +995,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_scatter_base(
comm,
ccl::create_stream(stream.queue()));
// Emulate AVG because oneCCL does not support it yet: divide the result by getSize().
// oneCCL is expected to support AVG in the basekit 2025.2 release.
if (opts.reduceOp == ReduceOp::AVG) {
auto divisor = getSize();
output.div_(divisor);
Expand Down Expand Up @@ -1023,6 +1030,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
comm,
ccl::create_stream(stream.queue()));
// Emulate AVG because oneCCL does not support it yet: divide the result by getSize().
// oneCCL is expected to support AVG in the basekit 2025.2 release.
if (opts.reduceOp == ReduceOp::AVG) {
auto divisor = getSize();
output.div_(divisor);
Expand Down

0 comments on commit 106adb5

Please sign in to comment.