Skip to content

Commit

Permalink
ReduceScatter with DID loop split
Browse files Browse the repository at this point in the history
For #2563
  • Loading branch information
wujingyue committed Dec 10, 2024
1 parent 98352c4 commit 016eb60
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 8 deletions.
5 changes: 3 additions & 2 deletions csrc/multidevice/communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,11 +443,12 @@ c10::intrusive_ptr<c10d::Work> postReduceScatter(
}
#endif
std::vector<std::vector<at::Tensor>> input_tensors(1);
input_tensors[0] = at::split(input_tensor, /*split_size=*/1, scattered_axis);
input_tensors[0] = at::tensor_split(
input_tensor, communication->team_size(), scattered_axis);

std::vector<at::Tensor> output_tensors({output_tensor});

assertBufferCount(input_tensors[0], communication->team().size());
assertBuffersHaveSameSize(input_tensors[0], {});
return backend->reduce_scatter(
output_tensors, input_tensors, {.reduceOp = communication->reduceOp()});
}
Expand Down
68 changes: 62 additions & 6 deletions tests/python/test_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@

@pytest.mark.mpi
def test_allgather(mpi_test):
num_devices = mpi_test.size
mesh = nvfuser.DeviceMesh(range(num_devices))
d = mpi_test.size
mesh = nvfuser.DeviceMesh(range(d))

class Model(FusionDefinition):
def definition(self):
self.inp = self.define_tensor(
(num_devices * 4,), contiguity=True, dtype=DataType.Float
(d * 4,), contiguity=True, dtype=DataType.Float
)
self.out = self.ops.set(self.inp)
self.add_output(self.out)
Expand All @@ -30,16 +30,72 @@ def multidevice_schedule(self):
self.sched._set_device_mesh(self.inp, mesh)
self.sched._set_device_mesh(self.out, mesh)

self.sched.split(self.inp, 0, num_devices, False)
self.sched.split(self.inp, 0, d, False)
self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)
self.sched.set_allocation_as_loop(self.inp)

self.sched.split(self.out, 0, num_devices, False)
self.sched.split(self.out, 0, d, False)
self.sched.set_allocation_as_loop(self.out)

unsharded = torch.randn(num_devices * 4)
unsharded = torch.randn(d * 4)
sharded = mpi_test.shard_tensor(unsharded, 0, mesh)

fd = Model()
outputs = fd.execute([sharded])
torch.testing.assert_close(outputs[0].cpu(), unsharded)


@pytest.mark.mpi
def test_allreduce(mpi_test):
d = mpi_test.size

class Model(FusionDefinition):
def definition(self):
self.inp = self.define_tensor((d, 4), contiguity=True, dtype=DataType.Float)
self.out = self.ops.sum(self.inp, [0])
self.add_output(self.out)

def multidevice_schedule(self):
mesh = self.sched._create_device_mesh(range(d))
self.sched._set_device_mesh(self.inp, mesh)
self.sched._set_device_mesh(self.out, mesh)

self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)

unsharded = torch.randn(d, 4)
sharded = mpi_test.shard_tensor(unsharded, 0)

fd = Model()
outputs = fd.execute([sharded])
torch.testing.assert_close(outputs[0].cpu(), unsharded.sum(0))


@pytest.mark.mpi
def test_reduce_scatter(mpi_test):
d = mpi_test.size

class Model(FusionDefinition):
def definition(self):
self.inp = self.define_tensor(
(d, d * 4), contiguity=True, dtype=DataType.Float
)
self.out = self.ops.sum(self.inp, [0])
self.add_output(self.out)

def multidevice_schedule(self):
mesh = self.sched._create_device_mesh(range(d))
self.sched._set_device_mesh(self.inp, mesh)
self.sched._set_device_mesh(self.out, mesh)

self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)

self.sched.split(self.out, -1, d, False)
self.sched.parallelize(self.out, -2, nvfuser.ParallelType.mesh_x)
self.sched.set_allocation_as_loop(self.out)

unsharded = torch.randn(d, d * 4)
sharded = mpi_test.shard_tensor(unsharded, 0)

fd = Model()
outputs = fd.execute([sharded])
torch.testing.assert_close(outputs[0], mpi_test.shard_tensor(unsharded.sum(0), 0))

0 comments on commit 016eb60

Please sign in to comment.