
Commit

Merge pull request #1 from AleHD/fix_tp_mem_cache
Minor restyling
AleHD authored Aug 2, 2024
2 parents 4c94b99 + 31c3c5a commit 0adb368
Showing 1 changed file with 10 additions and 7 deletions.
src/nanotron/parallel/tensor_parallel/functional.py (17 changes: 10 additions & 7 deletions)
@@ -345,7 +345,7 @@ def backward(ctx, grad_output):
             raise ValueError(f"Got unexpected mode: {tp_mode}.")
 
 
-class _ColumnLinearContextParallelNoAsync(torch.autograd.Function):
+class _ColumnLinearNoAsyncCommunicationReduceScatterMode(torch.autograd.Function):
     """
     Column linear with memory_buffer for the allgather, context parallel
     enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and
@@ -416,6 +416,9 @@ def backward(ctx, grad_output: torch.Tensor):
         if group.size() == 1:
             sub_grad_input = grad_input
         else:
+            # Seems that `reduce_scatter` need contiguous tensors: https://github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305
+            # We set grad_input to be contiguous in case it isn't already.
+            grad_input = grad_input.contiguous()
             sub_grad_input = torch.empty(
                 input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False
             )
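For context on the `.contiguous()` call added above: a gradient that arrives through a transposed or otherwise strided view shares storage with another tensor and is not densely packed, which is what the linked `reduce_scatter` wrapper expects. A minimal standalone PyTorch sketch of that pattern (illustrative shapes only, not nanotron code):

import torch

# A transposed view shares storage with the original tensor and is not
# laid out densely in memory.
grad_input = torch.randn(8, 4).t()      # shape (4, 8), non-contiguous view
assert not grad_input.is_contiguous()

# Collectives such as reduce_scatter operate on densely packed buffers,
# so the patch copies the data into a contiguous layout first.
grad_input = grad_input.contiguous()
assert grad_input.is_contiguous()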
@@ -439,12 +442,12 @@ def column_linear(
 
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
         input = differentiable_identity(input, group=group)
-    elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
-        return _ColumnLinearContextParallelNoAsync.apply(input, weight, bias, group, tp_recompute_allgather)
-    else:
-        raise ValueError(f"Got unexpected mode: {tp_mode}.")
-
-    return F.linear(input, weight, bias)
+        return F.linear(input, weight, bias)
+    if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
+        return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply(
+            input, weight, bias, group, tp_recompute_allgather
+        )
+    raise ValueError(f"Got unexpected mode: {tp_mode}.")
 
 
 class _RowLinearAsyncCommunication(torch.autograd.Function):
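For readers skimming the rename: _ColumnLinearNoAsyncCommunicationReduceScatterMode is the non-async REDUCE_SCATTER path dispatched from column_linear. Conceptually, the forward all-gathers the per-rank input shards and applies the local column shard of the weight, and the backward reduce-scatters the input gradient back to the shards, which is where the new contiguous() call sits. A single-process sketch of the forward (no torch.distributed; shapes and names are made up for illustration, this is not the nanotron implementation):

import torch
import torch.nn.functional as F

world_size = 4
shards = [torch.randn(2, 16) for _ in range(world_size)]  # per-rank input shards
weight = torch.randn(32, 16)                               # this rank's column shard of W

# Forward: "all-gather" the sharded input, then the usual linear.
total_input = torch.cat(shards, dim=0)
out = F.linear(total_input, weight)
print(out.shape)  # torch.Size([8, 32]): full sequence, local slice of the output features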
