From 70a1e8571802c22c0f09279b77876e6e85c81325 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Mon, 10 Jun 2024 20:19:35 -0700 Subject: [PATCH] [Traceable FSDP2] Use custom ops for AllGather copy-in / copy-out and ReduceScatter copy-in (#127856) Making these operations into custom ops helps Inductor identify these ops and enforce the FSDP communication op ordering. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127856 Approved by: https://github.com/awgu --- .../_composable/fsdp/_fsdp_collectives.py | 111 ++++++++++++++++-- 1 file changed, 102 insertions(+), 9 deletions(-) diff --git a/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/torch/distributed/_composable/fsdp/_fsdp_collectives.py index ac5084813ee164..1423cfd600fc88 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_collectives.py +++ b/torch/distributed/_composable/fsdp/_fsdp_collectives.py @@ -26,6 +26,98 @@ class AllGatherResult(NamedTuple): all_gather_input_split_sizes: List[int] +lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901 + +lib.define( + """ + all_gather_copy_in( + Tensor[] all_gather_inputs, + SymInt[] inp_split_sizes, + SymInt all_gather_input_numel, + SymInt world_size, + SymInt rank, + ScalarType dtype, + Device device + ) -> (Tensor, Tensor) + """ +) + + +@torch.library.impl(lib, "all_gather_copy_in", "Meta") +def all_gather_copy_in_meta( + all_gather_inputs: List[torch.Tensor], + inp_split_sizes: List[int], + all_gather_input_numel: int, + world_size: int, + rank: int, + dtype: torch.dtype, + device: torch.device, +) -> Tuple[torch.Tensor, torch.Tensor]: + all_gather_output = torch.empty( + (all_gather_input_numel * world_size,), dtype=dtype, device="meta" + ) + all_gather_input = all_gather_output.narrow( + 0, all_gather_input_numel * rank, all_gather_input_numel + ) + return all_gather_input, all_gather_output + + +@torch.library.impl(lib, "all_gather_copy_in", "CUDA") +def all_gather_copy_in_cuda( + all_gather_inputs: List[torch.Tensor], + inp_split_sizes: List[int], + all_gather_input_numel: int, + world_size: int, + rank: int, + dtype: torch.dtype, + device: torch.device, +) -> Tuple[torch.Tensor, torch.Tensor]: + all_gather_output = torch.empty( + (all_gather_input_numel * world_size,), dtype=dtype, device=device + ) + all_gather_input = all_gather_output.narrow( + 0, all_gather_input_numel * rank, all_gather_input_numel + ) + foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) + with torch.no_grad(): + torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs) + return all_gather_input, all_gather_output + + +lib.define( + "split_with_sizes_copy(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()" +) + + +@torch.library.impl(lib, "split_with_sizes_copy", "Meta") +@torch.library.impl(lib, "split_with_sizes_copy", "CUDA") +def split_with_sizes_copy( + all_gather_output: torch.Tensor, + all_gather_input_split_sizes: List[int], + dim: int, + out: List[torch.Tensor], +) -> None: + torch.split_with_sizes_copy( + all_gather_output, all_gather_input_split_sizes, dim=dim, out=out + ) + + +lib.define( + "chunk_cat(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> ()" +) + + +@torch.library.impl(lib, "chunk_cat", "Meta") +@torch.library.impl(lib, "chunk_cat", "CUDA") +def chunk_cat( + tensors: List[torch.Tensor], + dim: int, + num_chunks: int, + out: torch.Tensor, +) -> None: + torch._chunk_cat(tensors, dim, num_chunks, out=out) + + @torch.no_grad() def foreach_all_gather( fsdp_params: List[FSDPParam], @@ -53,14 +145,15 @@ def foreach_all_gather( all_gather_inputs = [t for ts in param_all_gather_inputs for t in ts] inp_split_sizes = [t.numel() for t in all_gather_inputs] all_gather_input_numel = sum(inp_split_sizes) - all_gather_output = torch.empty( - (all_gather_input_numel * world_size,), dtype=dtype, device=device - ) - all_gather_input = all_gather_output.narrow( - 0, all_gather_input_numel * rank, all_gather_input_numel + all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in( + all_gather_inputs, + inp_split_sizes, + all_gather_input_numel, + world_size, + rank, + dtype, + device, ) - foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) - torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs) del param_all_gather_inputs all_gather_stream.wait_stream(all_gather_copy_in_stream) with torch.cuda.stream(all_gather_stream): @@ -124,7 +217,7 @@ def foreach_all_gather_copy_out( out = [t.view(world_size, -1).view(torch.uint8) for t in gen] else: out = [t.view(world_size, -1) for t in gen] - torch.split_with_sizes_copy( + torch.ops.fsdp.split_with_sizes_copy( all_gather_output, all_gather_input_split_sizes, dim=1, out=out ) @@ -259,7 +352,7 @@ def foreach_reduce_scatter_copy_in( world_size: int, ) -> None: reduce_scatter_input = reduce_scatter_input.view(world_size, -1) - torch._chunk_cat( + torch.ops.fsdp.chunk_cat( unsharded_grads, dim=0, num_chunks=world_size, out=reduce_scatter_input )