diff --git a/csrc/python_frontend/fusion_definition.cpp b/csrc/python_frontend/fusion_definition.cpp index 9efdebc17ad..9003ec09351 100644 --- a/csrc/python_frontend/fusion_definition.cpp +++ b/csrc/python_frontend/fusion_definition.cpp @@ -314,6 +314,17 @@ void FusionDefinition::finalizeSchedule( // Users can access schedule objects after scheduling the fusion. } +void FusionDefinition::setupMultideviceSchedule() { + // FusionDefinition.multidevice_schedule may create new Exprs (e.g. DID + // splits), which will be added to the presched fusion. + prev_fusion_ = FusionGuard::getCurFusion(); + FusionGuard::setCurFusion(preschedFusion()); +} + +void FusionDefinition::finalizeMultideviceSchedule() { + FusionGuard::setCurFusion(prev_fusion_); +} + void FusionDefinition::print(std::ostream& os) const { if (id().has_value()) { os << "\ndef nvfuser_fusion_id" << id().value(); diff --git a/csrc/python_frontend/fusion_definition.h b/csrc/python_frontend/fusion_definition.h index c359352c565..6157704f86b 100644 --- a/csrc/python_frontend/fusion_definition.h +++ b/csrc/python_frontend/fusion_definition.h @@ -184,6 +184,11 @@ class NVF_API FusionDefinition : public FusionState { //! Finalized use scheduling of a fusion //! resets FusionGuard, lowers IR to a kernel, compiles kernel NVF_API void finalizeSchedule(const at::ArrayRef& inputs); + //! A hook that gets called right before + //! FusionDefinition.multidevice_schedule. + NVF_API void setupMultideviceSchedule(); + //! A hook that gets called right after FusionDefinition.multidevice_schedule. + NVF_API void finalizeMultideviceSchedule(); //! Prints a python function representing the definition NVF_API void print(std::ostream& os) const; //! Executes a fusion if a valid definition or cache lookup occurred prior diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index d72115b8043..ce597e6b1a3 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -1091,6 +1091,12 @@ void initNvFuserPythonBindings(PyObject* module) { // Mark the end of a schedule inst::Trace::instance()->endEvent(nullptr); }) + .def( + "_setup_multidevice_schedule", + [](FusionDefinition& self) { self.setupMultideviceSchedule(); }) + .def( + "_finalize_multidevice_schedule", + [](FusionDefinition& self) { self.finalizeMultideviceSchedule(); }) .def("inputs", [](FusionDefinition& self) { return self.inputs(); }) .def("outputs", [](FusionDefinition& self) { return self.outputs(); }) .def("extents", [](FusionDefinition& self) { return self.extents(); }) @@ -3596,7 +3602,6 @@ void initNvFuserPythonBindings(PyObject* module) { }, py::arg("tensor"), py::arg("mesh")); - //! experimental API for multidevice support nvf_sched.def( "parallelize", [](FusionDefinition::SchedOperators& self, @@ -3683,6 +3688,18 @@ void initNvFuserPythonBindings(PyObject* module) { py::arg("dim"), py::arg("factor"), py::arg("inner_split") = true); + nvf_sched.def( + "set_allocation_as_loop", + [](FusionDefinition::SchedOperators& self, Tensor arg) { + FUSER_PERF_SCOPE("SchedOperators.set_allocation_as_loop"); + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + auto* tv = fd->getFusionState(arg.index)->template as(); + tv->setAllocationDomain(tv->getLoopDomain(), true); + }, + py::arg("arg")); nvf_sched.def( "cache_after", [](FusionDefinition::SchedOperators& self, diff --git a/nvfuser/__init__.py b/nvfuser/__init__.py index 967be681b7f..d65198c21a7 100644 --- a/nvfuser/__init__.py +++ b/nvfuser/__init__.py @@ -284,7 +284,9 @@ def execute( # # Note: there's a plan to embed multidevice schedules into FusionDefinition # as annotating nodes. This may eventually replace `multidevice_schedule`. + self._setup_multidevice_schedule() self.multidevice_schedule() + self._finalize_multidevice_schedule() # If schedule is defined by child class and schedule is not defined for # inputs, make a schedule. diff --git a/tests/python/mpi_fixtures.py b/tests/python/mpi_fixtures.py index d0f80f20a62..0915bc84f20 100644 --- a/tests/python/mpi_fixtures.py +++ b/tests/python/mpi_fixtures.py @@ -38,6 +38,14 @@ def local_rank(self): def barrier(self): self._communicator.barrier() + def shard_tensor(self, t: torch.Tensor, dim: int) -> torch.Tensor: + assert t.is_cpu, ( + "This is not strictly required but it's a general good practice " + "for unit tests to create unsharded data on CPU to reduce GPU " + "memory footprint." + ) + return t.tensor_split(self.size, dim)[self.rank].cuda(self.local_rank) + @pytest.fixture(scope="session") def mpi_test(): diff --git a/tests/python/test_communication.py b/tests/python/test_communication.py new file mode 100644 index 00000000000..259b433c985 --- /dev/null +++ b/tests/python/test_communication.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import torch + +import mpi_fixtures +import nvfuser +from nvfuser import DataType, FusionDefinition + + +mpi_test = mpi_fixtures.mpi_test + + +@pytest.mark.mpi +def test_allgather(mpi_test): + num_devices = mpi_test.size + rank = mpi_test.rank + + unsharded = torch.randn(num_devices * 4) + sharded = mpi_test.shard_tensor(unsharded, 0) + + class Model(FusionDefinition): + def definition(self): + self.inp = self.define_tensor( + (num_devices * 4,), contiguity=True, dtype=DataType.Float + ) + self.out = self.ops.set(self.inp) + self.add_output(self.out) + + def multidevice_schedule(self): + mesh = self.sched._create_device_mesh(range(num_devices)) + self.sched._set_device_mesh(self.inp, mesh) + self.sched._set_device_mesh(self.out, mesh) + + self.sched.split(self.inp, 0, num_devices, False) + self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x) + self.sched.set_allocation_as_loop(self.inp) + + self.sched.split(self.out, 0, num_devices, False) + self.sched.set_allocation_as_loop(self.out) + + fd = Model() + outputs = fd.execute([sharded]) + torch.testing.assert_close(outputs[0].cpu(), unsharded) diff --git a/tests/python/test_multidevice.py b/tests/python/test_multidevice.py index a54353e38a5..06686c07143 100644 --- a/tests/python/test_multidevice.py +++ b/tests/python/test_multidevice.py @@ -35,10 +35,8 @@ def test_pointwise(mpi_test): num_devices = mpi_test.size rank = mpi_test.rank - torch.cuda.set_device(mpi_test.local_rank) - - unsharded_input = torch.randn(num_devices, 4, device="cuda") - sharded_input = unsharded_input[rank : rank + 1] + unsharded_input = torch.randn(num_devices, 4) + sharded_input = mpi_test.shard_tensor(unsharded_input, 0) class Model(FusionDefinition): def definition(self): @@ -58,7 +56,7 @@ def multidevice_schedule(self): fd = Model() outputs = fd.execute([sharded_input]) - torch.testing.assert_close(outputs[0], unsharded_input.relu() * 2) + torch.testing.assert_close(outputs[0].cpu(), unsharded_input.relu() * 2) @pytest.mark.mpi