[Traceable FSDP2] Use custom ops for AllGather copy-in / copy-out and ReduceScatter copy-in (pytorch#127856)

Making these operations into custom ops helps Inductor identify them and enforce the FSDP communication op ordering.

Pull Request resolved: pytorch#127856
Approved by: https://github.com/awgu
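
For readers new to the pattern: the ops below are registered through torch.library so that, under torch.compile, each copy-in / copy-out shows up in the traced graph as a single opaque fsdp.* node rather than a sequence of aten ops, which is what lets Inductor recognize the FSDP collectives and keep them ordered. A minimal sketch of the same registration pattern on a toy op follows; the "mylib" namespace, the "toy_copy_in" name, and its signature are made up for illustration and are not part of this PR.

from typing import List

import torch

# Illustrative sketch only: a toy custom op registered the same way as the
# fsdp ops in this commit ("mylib" / "toy_copy_in" are made-up names).
_lib = torch.library.Library("mylib", "FRAGMENT")

_lib.define(
    "toy_copy_in(Tensor[] inputs, SymInt numel, ScalarType dtype, Device device) -> Tensor"
)


@torch.library.impl(_lib, "toy_copy_in", "Meta")
def _toy_copy_in_meta(
    inputs: List[torch.Tensor], numel: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    # Meta kernel: shapes/dtypes only, so the op can be traced without real data.
    return torch.empty((numel,), dtype=dtype, device="meta")


@torch.library.impl(_lib, "toy_copy_in", "CUDA")
def _toy_copy_in_cuda(
    inputs: List[torch.Tensor], numel: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    # Eager kernel: allocate one flat buffer and copy every input shard into it.
    out = torch.empty((numel,), dtype=dtype, device=device)
    with torch.no_grad():
        torch._foreach_copy_(torch.split(out, [t.numel() for t in inputs]), inputs)
    return out

# Callers (and the compiler) now see a single node: torch.ops.mylib.toy_copy_in(...)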
yf225 authored and pytorchmergebot committed Jun 11, 2024
1 parent adb6991 commit 70a1e85
Showing 1 changed file with 102 additions and 9 deletions.

torch/distributed/_composable/fsdp/_fsdp_collectives.py
@@ -26,6 +26,98 @@ class AllGatherResult(NamedTuple):
all_gather_input_split_sizes: List[int]


lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901

lib.define(
"""
all_gather_copy_in(
Tensor[] all_gather_inputs,
SymInt[] inp_split_sizes,
SymInt all_gather_input_numel,
SymInt world_size,
SymInt rank,
ScalarType dtype,
Device device
) -> (Tensor, Tensor)
"""
)


@torch.library.impl(lib, "all_gather_copy_in", "Meta")
def all_gather_copy_in_meta(
all_gather_inputs: List[torch.Tensor],
inp_split_sizes: List[int],
all_gather_input_numel: int,
world_size: int,
rank: int,
dtype: torch.dtype,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
all_gather_output = torch.empty(
(all_gather_input_numel * world_size,), dtype=dtype, device="meta"
)
all_gather_input = all_gather_output.narrow(
0, all_gather_input_numel * rank, all_gather_input_numel
)
return all_gather_input, all_gather_output


@torch.library.impl(lib, "all_gather_copy_in", "CUDA")
def all_gather_copy_in_cuda(
all_gather_inputs: List[torch.Tensor],
inp_split_sizes: List[int],
all_gather_input_numel: int,
world_size: int,
rank: int,
dtype: torch.dtype,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
all_gather_output = torch.empty(
(all_gather_input_numel * world_size,), dtype=dtype, device=device
)
all_gather_input = all_gather_output.narrow(
0, all_gather_input_numel * rank, all_gather_input_numel
)
foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
with torch.no_grad():
torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs)
return all_gather_input, all_gather_output


lib.define(
"split_with_sizes_copy(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()"
)


@torch.library.impl(lib, "split_with_sizes_copy", "Meta")
@torch.library.impl(lib, "split_with_sizes_copy", "CUDA")
def split_with_sizes_copy(
all_gather_output: torch.Tensor,
all_gather_input_split_sizes: List[int],
dim: int,
out: List[torch.Tensor],
) -> None:
torch.split_with_sizes_copy(
all_gather_output, all_gather_input_split_sizes, dim=dim, out=out
)


lib.define(
"chunk_cat(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> ()"
)


@torch.library.impl(lib, "chunk_cat", "Meta")
@torch.library.impl(lib, "chunk_cat", "CUDA")
def chunk_cat(
tensors: List[torch.Tensor],
dim: int,
num_chunks: int,
out: torch.Tensor,
) -> None:
torch._chunk_cat(tensors, dim, num_chunks, out=out)


@torch.no_grad()
def foreach_all_gather(
fsdp_params: List[FSDPParam],
@@ -53,14 +145,15 @@ def foreach_all_gather(
all_gather_inputs = [t for ts in param_all_gather_inputs for t in ts]
inp_split_sizes = [t.numel() for t in all_gather_inputs]
all_gather_input_numel = sum(inp_split_sizes)
-    all_gather_output = torch.empty(
-        (all_gather_input_numel * world_size,), dtype=dtype, device=device
-    )
-    all_gather_input = all_gather_output.narrow(
-        0, all_gather_input_numel * rank, all_gather_input_numel
+    all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(
+        all_gather_inputs,
+        inp_split_sizes,
+        all_gather_input_numel,
+        world_size,
+        rank,
+        dtype,
+        device,
    )
-    foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
-    torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs)
del param_all_gather_inputs
all_gather_stream.wait_stream(all_gather_copy_in_stream)
with torch.cuda.stream(all_gather_stream):
@@ -124,7 +217,7 @@ def foreach_all_gather_copy_out(
out = [t.view(world_size, -1).view(torch.uint8) for t in gen]
else:
out = [t.view(world_size, -1) for t in gen]
-    torch.split_with_sizes_copy(
+    torch.ops.fsdp.split_with_sizes_copy(
all_gather_output, all_gather_input_split_sizes, dim=1, out=out
)

@@ -259,7 +352,7 @@ def foreach_reduce_scatter_copy_in(
world_size: int,
) -> None:
reduce_scatter_input = reduce_scatter_input.view(world_size, -1)
-    torch._chunk_cat(
+    torch.ops.fsdp.chunk_cat(
unsharded_grads, dim=0, num_chunks=world_size, out=reduce_scatter_input
)

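A usage sketch (not part of this commit) of how the new copy-in op can be invoked once the module above is imported. The shard sizes, world_size, and rank are made up for illustration; in FSDP2 they come from the parameter group and process group. Under torch.compile the whole copy-in appears as this single torch.ops.fsdp.all_gather_copy_in node, which Inductor can then keep ordered relative to the all-gather collective that consumes its output.

import torch
import torch.distributed._composable.fsdp._fsdp_collectives  # registers torch.ops.fsdp.*

# Made-up shard sizes and process-group values, purely for illustration.
inputs = [torch.randn(4, device="cuda"), torch.randn(6, device="cuda")]
split_sizes = [t.numel() for t in inputs]
numel = sum(split_sizes)
world_size, rank = 2, 0

all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(
    inputs, split_sizes, numel, world_size, rank, torch.float32, torch.device("cuda")
)

# all_gather_output is the flat (numel * world_size,) buffer handed to all_gather;
# all_gather_input is this rank's slice of it, now holding the copied-in shards.
assert all_gather_output.numel() == numel * world_size
assert all_gather_input.numel() == numel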