Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the missing view operations from sequence parallel(async). #6750

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
129 changes: 70 additions & 59 deletions deepspeed/sequence/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,71 @@
from deepspeed.utils import groups


def _generate_layout_params(scatter_idx, batch_dim_idx, seq_world_size, input):
"""
This function generates the parameters required for `permute` and `reshape` operations,
which are used to process data before and after `all2all` communication.
"""
if batch_dim_idx == 0:
if scatter_idx < 2:
bs, global_seq_len, num_local_head, head_dim = input.shape
pre_all2all_inp_shape = [bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim]
pre_all2all_permute_idx = (1, 0, 2, 3, 4)

post_all2all_permute_idx = (1, 2, 0, 3, 4)
post_all2all_res_shape = [bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim]
else:
bs, local_seq_len, num_total_head, head_dim = input.shape
assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
pre_all2all_inp_shape = [bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim]
pre_all2all_permute_idx = (2, 0, 1, 3, 4)

post_all2all_permute_idx = (1, 0, 2, 3, 4)
post_all2all_res_shape = [bs, seq_world_size * local_seq_len, num_total_head // seq_world_size, head_dim]
else:
if scatter_idx < 2:
global_seq_len, bs, num_local_head, head_dim = input.shape
pre_all2all_inp_shape = [seq_world_size, global_seq_len // seq_world_size, bs, num_local_head, head_dim]
pre_all2all_permute_idx = None

post_all2all_permute_idx = (1, 2, 0, 3, 4)
post_all2all_res_shape = [bs, seq_world_size * global_seq_len, num_local_head // seq_world_size, head_dim]
else:
local_seq_len, bs, num_total_head, head_dim = input.shape
assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
pre_all2all_inp_shape = [local_seq_len, bs, seq_world_size, num_total_head // seq_world_size, head_dim]
pre_all2all_permute_idx = (2, 0, 1, 3, 4)
post_all2all_permute_idx = None
post_all2all_res_shape = [local_seq_len * seq_world_size, bs, num_total_head // seq_world_size, head_dim]

return pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape


def post_all2all(permute_idx, res_shape):
"""
Post-processing function for `all2all` communication.
"""

def post_func(input):
if permute_idx is not None:
input = input.permute(permute_idx).contiguous()
output = input.reshape(res_shape).contiguous()

return output

return post_func


def pre_all2all_fun(permute_idx, inp_shape, input):
"""
Pre-processing function for `all2all` communication.
"""
input_t = input.reshape(inp_shape).contiguous()
if permute_idx is not None:
input_t = input_t.permute(permute_idx).contiguous()
return input_t


def _rotate_half(x):
"""
change sign so the last dimension becomes [-odd, +even]
Expand Down Expand Up @@ -43,32 +108,6 @@ def apply_rotary_pos_emb(t, freqs_cos, freqs_sin):
return res


def post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_head, head_dim):

def post_func(input):
if batch_dim_idx == 0:
# b, s, n, h
if scatter_idx < 2:
output = input.permute(1, 2, 0, 3, 4).contiguous()
output = output.reshape(bs, seq_len // seq_world_size, seq_world_size * num_head,
head_dim).contiguous()
else:
output = input.permute(1, 0, 2, 3, 4).contiguous()
output = output.reshape(bs, seq_world_size * seq_len, num_head // seq_world_size,
head_dim).contiguous()
else:
# s, b, n, h
if scatter_idx < 2:
output = input.permute(1, 2, 0, 3, 4).contiguous()
output = output.reshape(seq_len // seq_world_size, bs, seq_world_size * num_head,
head_dim).contiguous()
else:
output = input.reshape(seq_len * seq_world_size, bs, num_head // seq_world_size, head_dim).contiguous()
return output

return post_func


def uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group):
seq_world_size = dist.get_world_size(group)
inp_shape = list(input.shape)
Expand Down Expand Up @@ -195,39 +234,12 @@ def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, asyn
assert async_op == False, "uneven head sp does not support async op"
return uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group)

if batch_dim_idx == 0:
# b, s, n, h
if scatter_idx < 2:
bs, global_seq_len, num_local_head, head_dim = input.shape
input_t = input.reshape([bs, seq_world_size, global_seq_len // seq_world_size, num_local_head,
head_dim]).contiguous()
input_t = input_t.permute(1, 0, 2, 3, 4).contiguous()
else:
bs, local_seq_len, num_total_head, head_dim = input.shape
assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
input_t = input.reshape([bs, local_seq_len, seq_world_size, num_total_head // seq_world_size,
head_dim]).contiguous()
input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()
else:
# s, b, n, h
if scatter_idx < 2:
global_seq_len, bs, num_local_head, head_dim = input.shape
input_t = input.reshape([seq_world_size, global_seq_len // seq_world_size, bs, num_local_head,
head_dim]).contiguous()
else:
local_seq_len, bs, num_total_head, head_dim = input.shape
assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
input_t = input.reshape([local_seq_len, bs, seq_world_size, num_total_head // seq_world_size,
head_dim]).contiguous()
input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()
pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape = _generate_layout_params(
scatter_idx, batch_dim_idx, seq_world_size, input)

if scatter_idx < 2:
post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, global_seq_len, num_local_head,
head_dim)
else:
post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, local_seq_len, num_total_head,
head_dim)
input_t = pre_all2all_fun(pre_all2all_permute_idx, pre_all2all_inp_shape, input)

post_all2all_fun = post_all2all(post_all2all_permute_idx, post_all2all_res_shape)
output = torch.empty_like(input_t)
work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op)

Expand All @@ -236,7 +248,7 @@ def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, asyn
handle[type + '_work'] = work
handle[type + '_grad'] = output
handle[type + '_post_all2all_func'] = post_all2all_fun
return output
return output.view(post_all2all_res_shape)

res = post_all2all_fun(output)
return res
Expand Down Expand Up @@ -271,7 +283,6 @@ def forward(ctx: Any,
assert ctx.stream != None
res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)
get_accelerator().current_stream().wait_stream(ctx.stream)
del ctx.stream.activation_buffer_list
# The computation of d o_weight can overlap with the communication of d o_input

elif not is_fwd and type in ('q', 'k'):
Expand Down