
Commit

Merge remote-tracking branch 'upstream/main'
ischlag committed Jul 30, 2024
2 parents 793bdf3 + 2e48d66 commit ddc8fa1
Showing 2 changed files with 76 additions and 27 deletions.
74 changes: 54 additions & 20 deletions src/nanotron/parallel/tensor_parallel/functional.py
@@ -387,8 +387,7 @@ def backward(ctx, grad_output):
         group = ctx.group
         use_bias = ctx.use_bias
 
-        handle_0: Optional[dist.Work] = None
-        handle_1: Optional[dist.Work] = None
+        handle: Optional[dist.Work] = None
 
         # TODO @thomasw21: gather along another dimension
         sharded_batch_size, *rest_size = grad_output.shape
@@ -412,31 +411,69 @@ def backward(ctx, grad_output):
             # https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
             grad_output = grad_output.contiguous()
 
-            handle_0 = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True)
-
-        grad_tensor = grad_output.matmul(weight)
+            handle = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True)
 
-        # wait for the first all_gather to finish before starting the second all_gather
-        if handle_0 is not None:
-            handle_0.wait()
-
-        # TODO @thomasw21: gather along another dimension
-        sharded_batch_size, *rest_size = grad_tensor.shape
+        # total_grad_output: [b, s, h_out]
+        # weight: [h_out, h_in/n]
+        # total_grad_tensor: [b, s, h_in/n]
+        # grad_output: [b/n, s, h_out]
+        sharded_batch_size, *rest_size_grad_output = grad_output.shape
+        rest_size_grad_tensor = rest_size_grad_output[:-1] + [weight.shape[1]]
 
         if group.size() == 1:
-            total_grad_tensor = grad_tensor
+            total_grad_tensor = grad_output.matmul(weight)
         else:
             unsharded_batch_size = sharded_batch_size * group.size()
 
             total_grad_tensor = torch.empty(
                 unsharded_batch_size,
-                *rest_size,
-                device=grad_tensor.device,
-                dtype=grad_tensor.dtype,
+                *rest_size_grad_tensor,
+                device=grad_output.device,
+                dtype=grad_output.dtype,
                 requires_grad=False,
             )
-
-            handle_1 = dist.all_gather_into_tensor(total_grad_tensor, grad_tensor, group=group, async_op=True)
+            before_shard_grad_tensor, same_device_shard_grad_tensor, after_shard_grad_tensor = torch.split(
+                total_grad_tensor,
+                split_size_or_sections=[
+                    sharded_batch_size * dist.get_rank(group),
+                    sharded_batch_size,
+                    sharded_batch_size * (group.size() - dist.get_rank(group) - 1),
+                ],
+                dim=0,
+            )
+            # compute local shard
+            torch.mm(
+                input=grad_output.view(-1, grad_output.shape[-1]),
+                mat2=weight,
+                out=same_device_shard_grad_tensor.view(-1, weight.shape[1]),
+            )
+
+            if handle is not None:
+                handle.wait()
+
+            before_shard_grad_output, _, after_shard_grad_output = torch.split(
+                total_grad_output,
+                split_size_or_sections=[
+                    sharded_batch_size * dist.get_rank(group),
+                    sharded_batch_size,
+                    sharded_batch_size * (group.size() - dist.get_rank(group) - 1),
+                ],
+                dim=0,
+            )
+
+            # before shard compute
+            if before_shard_grad_tensor.numel() > 0:
+                torch.mm(
+                    input=before_shard_grad_output.view(-1, before_shard_grad_output.shape[-1]),
+                    mat2=weight,
+                    out=before_shard_grad_tensor.view(-1, weight.shape[1]),
+                )
+            # after shard compute
+            if after_shard_grad_tensor.numel() > 0:
+                torch.mm(
+                    input=after_shard_grad_output.view(-1, after_shard_grad_output.shape[-1]),
+                    mat2=weight,
+                    out=after_shard_grad_tensor.view(-1, weight.shape[1]),
+                )
 
         # Convert the tensor shapes to 2D for execution compatibility
         tensor = tensor.contiguous()
@@ -454,9 +491,6 @@ def backward(ctx, grad_output):
         grad_weight = total_grad_output.t().matmul(tensor)
         grad_bias = total_grad_output.sum(dim=0) if use_bias else None
 
-        if handle_1 is not None:
-            handle_1.wait()
-
         return total_grad_tensor, grad_weight, grad_bias, None, None


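For reference, a minimal standalone sketch of the overlap pattern this hunk introduces: launch the all_gather of grad_output asynchronously, compute the shard of the input gradient that only needs local data while the collective is in flight, then fill in the remaining shards once it completes. The helper name overlapped_grad_input and the assumption of an already-initialized process group are illustrative only, not part of nanotron's API; shapes follow the comments in the diff (grad_output [b/n, s, h_out], weight [h_out, h_in/n]).

# Illustrative sketch only (hypothetical helper, not part of this diff): overlap the
# async all_gather of grad_output with the matmul for the locally available shard,
# then compute the remaining shards after wait(). Assumes torch.distributed is initialized.
import torch
import torch.distributed as dist


def overlapped_grad_input(grad_output: torch.Tensor, weight: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    # grad_output: [b/n, s, h_out], sharded on the batch dim; weight: [h_out, h_in/n]
    grad_output = grad_output.contiguous()
    rank, world_size = dist.get_rank(group), group.size()
    sharded_batch, *rest = grad_output.shape

    total_grad_output = torch.empty(
        sharded_batch * world_size, *rest, device=grad_output.device, dtype=grad_output.dtype
    )
    total_grad_input = torch.empty(
        sharded_batch * world_size, *rest[:-1], weight.shape[1], device=grad_output.device, dtype=grad_output.dtype
    )

    # Start the collective, then do useful work while it runs.
    handle = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True)

    splits = [sharded_batch * rank, sharded_batch, sharded_batch * (world_size - rank - 1)]
    before, local, after = torch.split(total_grad_input, splits, dim=0)
    # The slice owned by this rank only needs the local grad_output.
    torch.mm(grad_output.view(-1, grad_output.shape[-1]), weight, out=local.view(-1, weight.shape[1]))

    handle.wait()

    # The other slices need the gathered grad_output.
    before_go, _, after_go = torch.split(total_grad_output, splits, dim=0)
    for dst, src in ((before, before_go), (after, after_go)):
        if dst.numel() > 0:
            torch.mm(src.view(-1, src.shape[-1]), weight, out=dst.view(-1, weight.shape[1]))
    return total_grad_input

Compared with the previous version, which ran a second all_gather on grad_tensor, the rewritten backward builds total_grad_tensor directly from the gathered grad_output and hides the local matmul behind the single collective.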
29 changes: 22 additions & 7 deletions tests/test_tensor_parallel.py
@@ -208,14 +208,19 @@ def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelL
     random_input = torch.randn(batch_size, in_features, device="cuda")
     # synchronize random_input across tp
     dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=parallel_context.tp_pg)
 
+    random_input.requires_grad = True
     # Row linear receives as input sharded input
-    random_sharded_input = random_input[
-        :,
-        dist.get_rank(parallel_context.tp_pg)
-        * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1)
-        * in_features_per_rank,
-    ]
+    random_sharded_input = (
+        random_input[
+            :,
+            dist.get_rank(parallel_context.tp_pg)
+            * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1)
+            * in_features_per_rank,
+        ]
+        .detach()
+        .clone()
+    )
+    random_sharded_input.requires_grad = True
 
     # Test that we get the same output after forward pass
     # TODO @kunhao: We may want to have our custom error type
@@ -261,6 +266,16 @@ def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelL
     else:
         assert row_linear.bias is None
 
+    torch.testing.assert_close(
+        random_sharded_input.grad,
+        random_input.grad[
+            :,
+            dist.get_rank(parallel_context.tp_pg)
+            * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1)
+            * in_features_per_rank,
+        ],
+    )
+
     parallel_context.destroy()


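As a side note, a hypothetical mini-example (not from the repository) of why the test detaches and clones the shard before setting requires_grad: slicing a tensor that requires grad produces a non-leaf view whose .grad is never populated, while a detached clone is a leaf whose .grad can be compared against the matching slice of the full input's gradient, which is what the new torch.testing.assert_close check relies on.

import torch

full = torch.randn(4, 8, requires_grad=True)
sliced_view = full[:, :4]                                     # non-leaf view: autograd will not fill .grad
sliced_leaf = full[:, :4].detach().clone().requires_grad_()   # leaf with its own .grad

(sliced_view.sum() + sliced_leaf.sum()).backward()
print(sliced_view.grad)                           # None (PyTorch also warns about non-leaf .grad access)
print(sliced_leaf.grad)                           # tensor of ones, shape (4, 4)
print(full.grad[:, :4].equal(torch.ones(4, 4)))   # True: the same gradient lands on the full input's slice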
