diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 42081a64..e2ee3a29 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -345,7 +345,7 @@ def backward(ctx, grad_output): raise ValueError(f"Got unexpected mode: {tp_mode}.") -class _ColumnLinearContextParallelNoAsync(torch.autograd.Function): +class _ColumnLinearNoAsyncCommunicationReduceScatterMode(torch.autograd.Function): """ Column linear with memory_buffer for the allgather, context parallel enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and @@ -416,6 +416,9 @@ def backward(ctx, grad_output: torch.Tensor): if group.size() == 1: sub_grad_input = grad_input else: + # Seems that `reduce_scatter` needs contiguous tensors: https://github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305 + # We set grad_input to be contiguous in case it isn't already. + grad_input = grad_input.contiguous() sub_grad_input = torch.empty( input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False ) @@ -439,12 +442,12 @@ def column_linear( if tp_mode is TensorParallelLinearMode.ALL_REDUCE: input = differentiable_identity(input, group=group) - elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - return _ColumnLinearContextParallelNoAsync.apply(input, weight, bias, group, tp_recompute_allgather) - else: - raise ValueError(f"Got unexpected mode: {tp_mode}.") - - return F.linear(input, weight, bias) + return F.linear(input, weight, bias) + if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply( + input, weight, bias, group, tp_recompute_allgather + ) + raise ValueError(f"Got unexpected mode: {tp_mode}.") class _RowLinearAsyncCommunication(torch.autograd.Function):