ReduceScatter with DID loop split (#3504)
wujingyue authored Dec 11, 2024
1 parent 8c82f30 commit d178c2a
Showing 3 changed files with 141 additions and 23 deletions.
33 changes: 21 additions & 12 deletions csrc/multidevice/communication.cpp
@@ -429,27 +429,36 @@ c10::intrusive_ptr<c10d::Work> postReduceScatter(
scattered_axis >= 0,
"scattered_axis is expected to be non-negative: ",
scattered_axis);
// reduce_scatter primitive in c10d induces extra buffering time to copy the
// user's input tensors to an internal source buffer. It is therefore always
// preferable to use _reduce_scatter_base (which does not perform any extra
// copy) when the tensors are stored contiguously (i.e., when
// scattered_axis==0). Note however that only nccl supports
// _reduce_scatter_base, not ucc.

std::vector<at::Tensor> input_tensors = at::tensor_split(
input_tensor, communication->team_size(), scattered_axis);
// We could also have checked the output shape if the reduction axis were
// available. It's not always available via
// `communication->out()->getReductionAxis()` for manually constructed host
// IRs like
// https://github.com/NVIDIA/Fuser/blob/89c47f695b296eb4ffd27984bd4c953fc3f3264b/tests/cpp/test_multidevice_overlap.cpp#L347.
assertBuffersHaveSameSize(input_tensors, {});

// reduce_scatter primitive in c10d induces extra buffering time to copy the
// user's input tensors to an internal source buffer. It is therefore always
// preferable to use _reduce_scatter_base (which does not perform any extra
// copy) when the tensors are stored contiguously (i.e., when
// scattered_axis==0). Note however that only nccl supports
// _reduce_scatter_base, not ucc. (A standalone sketch contrasting the two
// calls follows this file's diff.)
#if defined(NVFUSER_DISTRIBUTED) && defined(USE_C10D_NCCL)
if (scattered_axis == 0 &&
backend->getBackendName() == c10d::NCCL_BACKEND_NAME) {
return backend->_reduce_scatter_base(
output_tensor, input_tensor, {.reduceOp = communication->reduceOp()});
}
#endif
std::vector<std::vector<at::Tensor>> input_tensors(1);
input_tensors[0] = at::split(input_tensor, /*split_size=*/1, scattered_axis);

std::vector<at::Tensor> output_tensors({output_tensor});

assertBufferCount(input_tensors[0], communication->team().size());
std::vector<std::vector<at::Tensor>> input_tensors_vec({input_tensors});
std::vector<at::Tensor> output_tensor_vec({output_tensor});
return backend->reduce_scatter(
output_tensors, input_tensors, {.reduceOp = communication->reduceOp()});
output_tensor_vec,
input_tensors_vec,
{.reduceOp = communication->reduceOp()});
}

c10::intrusive_ptr<c10d::Work> postSendRecv(
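A standalone torch.distributed sketch (illustrative only, not part of this commit) of the trade-off described in the reduce_scatter comment above. It assumes a process group with d ranks has already been initialized with the NCCL backend; reduce_scatter_tensor is used as the public Python counterpart of the _reduce_scatter_base call, and the helper names are invented for the example.

# Illustrative only -- not from the commit.
import torch
import torch.distributed as dist

def reduce_scatter_outermost(inp: torch.Tensor, d: int) -> torch.Tensor:
    # Scattered axis 0: every rank's chunk is a contiguous slice of `inp`, so the
    # tensor-based variant can consume the input buffer directly, with no staging copy.
    out = torch.empty((inp.shape[0] // d, *inp.shape[1:]), device=inp.device)
    dist.reduce_scatter_tensor(out, inp, op=dist.ReduceOp.SUM)
    return out

def reduce_scatter_strided(inp: torch.Tensor, d: int, axis: int) -> torch.Tensor:
    # Non-outermost scattered axis: the chunks are strided views, so the list-based
    # reduce_scatter has to pack them into an internal source buffer first.
    chunks = list(torch.tensor_split(inp, d, dim=axis))
    out = torch.empty(chunks[0].shape, device=inp.device)
    dist.reduce_scatter(out, chunks, op=dist.ReduceOp.SUM)
    return out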
19 changes: 14 additions & 5 deletions tests/cpp/test_multidevice_lower_communication.cpp
@@ -7,6 +7,7 @@
// clang-format on

#include <gmock/gmock-matchers.h>
#include <gmock/gmock-more-matchers.h>
#include <gtest/gtest.h>

#include <ops/all_ops.h>
@@ -16,15 +17,23 @@

namespace nvfuser {

using testing::Each;
using testing::IsTrue;
using testing::Pointer;
using testing::Property;

namespace {
void assertIsCompiledToHostIrContainer(
const FusionExecutorCache& executor_cache) {
FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
EXPECT_TRUE(runtime->executors().size() == 1);
for (const auto& ea : runtime->executors()) {
EXPECT_TRUE(ea->isA<HostIrExecutor>())
<< "failed to compile to a HostIrContainer with Communications";
}
EXPECT_EQ(runtime->executors().size(), 1);
EXPECT_THAT(
runtime->executors(),
Each(Pointer(Property(
"is a HostIrExecutor",
&ExecutorAbstract::isA<HostIrExecutor>,
IsTrue()))))
<< "failed to compile to a HostIrContainer with Communications";
}
} // namespace

112 changes: 106 additions & 6 deletions tests/python/test_communication.py
@@ -15,13 +15,13 @@

@pytest.mark.mpi
def test_allgather(mpi_test):
num_devices = mpi_test.size
mesh = nvfuser.DeviceMesh(range(num_devices))
d = mpi_test.size
mesh = nvfuser.DeviceMesh(range(d))

class Model(FusionDefinition):
def definition(self):
self.inp = self.define_tensor(
(num_devices * 4,), contiguity=True, dtype=DataType.Float
(d * 4,), contiguity=True, dtype=DataType.Float
)
self.out = self.ops.set(self.inp)
self.add_output(self.out)
@@ -30,16 +30,116 @@ def multidevice_schedule(self):
self.sched._set_device_mesh(self.inp, mesh)
self.sched._set_device_mesh(self.out, mesh)

self.sched.split(self.inp, 0, num_devices, False)
self.sched.split(self.inp, 0, d, False)
self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)
self.sched.set_allocation_as_loop(self.inp)

self.sched.split(self.out, 0, num_devices, False)
self.sched.split(self.out, 0, d, False)
self.sched.set_allocation_as_loop(self.out)

unsharded = torch.randn(num_devices * 4)
unsharded = torch.randn(d * 4)
sharded = mpi_test.shard_tensor(unsharded, 0, mesh)

fd = Model()
outputs = fd.execute([sharded])
torch.testing.assert_close(outputs[0].cpu(), unsharded)


@pytest.mark.mpi
def test_allreduce(mpi_test):
d = mpi_test.size
mesh = nvfuser.DeviceMesh(range(d))

class Model(FusionDefinition):
def definition(self):
self.inp = self.define_tensor((d, 4), contiguity=True, dtype=DataType.Float)
self.out = self.ops.sum(self.inp, [0])
self.add_output(self.out)

def multidevice_schedule(self):
self.sched._set_device_mesh(self.inp, mesh)
self.sched._set_device_mesh(self.out, mesh)

self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)

unsharded = torch.randn(d, 4)
sharded = mpi_test.shard_tensor(unsharded, 0, mesh)

fd = Model()
outputs = fd.execute([sharded])
torch.testing.assert_close(outputs[0].cpu(), unsharded.sum(0))


@pytest.mark.mpi
def test_reduce_scatter(mpi_test):
d = mpi_test.size
mesh = nvfuser.DeviceMesh(range(d))

class Model(FusionDefinition):
def definition(self):
self.inp = self.define_tensor(
(d, d * 4), contiguity=True, dtype=DataType.Float
)
self.out = self.ops.sum(self.inp, [0])
self.add_output(self.out)

def multidevice_schedule(self):
self.sched._set_device_mesh(self.inp, mesh)
self.sched._set_device_mesh(self.out, mesh)

self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)

self.sched.split(self.out, -1, d, False)
self.sched.parallelize(self.out, -2, nvfuser.ParallelType.mesh_x)
self.sched.set_allocation_as_loop(self.out)

unsharded = torch.randn(d, d * 4)
sharded = mpi_test.shard_tensor(unsharded, 0, mesh)

fd = Model()
outputs = fd.execute([sharded])
torch.testing.assert_close(
outputs[0], mpi_test.shard_tensor(unsharded.sum(0), 0, mesh)
)
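
For reference, a single-process sketch (illustrative, not part of the test suite) of the arithmetic the assertion above encodes, assuming mpi_test.shard_tensor takes the rank's equal slice along the given axis: rank r should receive the r-th of d slices of unsharded.sum(0).

# Illustrative only: the expected per-rank result of the reduce-scatter above.
import torch

d, r = 2, 0                          # assumed world size and rank for the example
unsharded = torch.randn(d, d * 4)
full_reduction = unsharded.sum(0)    # shape (d * 4,): what an allreduce would leave on every rank
expected_on_rank_r = torch.tensor_split(full_reduction, d, dim=0)[r]   # shape (4,): rank r's shard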


@pytest.mark.mpi
def test_reduce_scatter_noncontiguous(mpi_test):
d = mpi_test.size
mesh = nvfuser.DeviceMesh(range(d))

class Model(FusionDefinition):
def definition(self):
self.inp = self.define_tensor(
(d, 3, d * 4), contiguity=True, dtype=DataType.Float
)
self.out = self.ops.sum(self.inp, [0])
self.add_output(self.out)

def multidevice_schedule(self):
self.sched._set_device_mesh(self.inp, mesh)
self.sched._set_device_mesh(self.out, mesh)

# inp: [iDID{d}, i{3}, i{d*4}]
# out: [r{d},    i{3}, i{d*4}]
#                        /    \
#                   iDID{d}   i{4}
#
# Unlike test_reduce_scatter, this leads to an extra data copy because
# the scattered axis is not outermost in allocation.
# ProcessGroupNCCL::reduce_scatter handles non-contiguous scattering in a
# functional but suboptimal way. (See the sketch after this test for why
# the chunks are non-contiguous.)
self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)

self.sched.split(self.out, -1, d, False)
self.sched.parallelize(self.out, -2, nvfuser.ParallelType.mesh_x)
self.sched.set_allocation_as_loop(self.out)

unsharded = torch.randn(d, 3, d * 4)
sharded = mpi_test.shard_tensor(unsharded, 0, mesh)

fd = Model()
outputs = fd.execute([sharded])
torch.testing.assert_close(
outputs[0], mpi_test.shard_tensor(unsharded.sum(0), 1, mesh)
)
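
A single-process sketch (illustrative, not from the commit) of why the scattered chunks in this test are non-contiguous: splitting the innermost axis of a row-major tensor yields strided views, which is what forces the extra copy mentioned in the comment above.

# Illustrative only: scattering the innermost axis produces non-contiguous chunks.
import torch

d = 2
t = torch.randn(3, d * 4)                   # per-rank extent once the iDID{d} axis is removed
chunks = torch.tensor_split(t, d, dim=-1)   # the d pieces a reduce-scatter would distribute
print(chunks[0].is_contiguous())            # False: each chunk is a strided column slice,
                                            # so c10d must pack it into a staging buffer first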
