Add the missing view operations from sequence parallel(async). (#6750)
FYI @loadams 

A view operation present in the original version went missing in some later updates:
https://github.com/microsoft/DeepSpeed/blob/17ed7c77c58611a923a6c8d2a3d21d359cd046e8/deepspeed/sequence/layer.py#L56

Add the missing view operation. The shape required for the view cannot easily be obtained inside the current function, so the layout-parameter code is refactored into a separate helper (`_generate_layout_params`).
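
For illustration only (not part of the commit): a shape-only sketch of the new flow, assuming the helpers added here can be imported from deepspeed.sequence.layer. dist.all_to_all_single is stood in for by a plain copy, since that collective does not change the tensor shape.

import torch
from deepspeed.sequence.layer import _generate_layout_params, post_all2all, pre_all2all_fun

seq_world_size = 2                                 # hypothetical sequence-parallel group size
x = torch.randn(4, 8, 6, 16)                       # [bs, global_seq_len, num_local_head, head_dim], batch_dim_idx == 0

pre_idx, pre_shape, post_idx, post_shape = _generate_layout_params(0, 0, seq_world_size, x)

input_t = pre_all2all_fun(pre_idx, pre_shape, x)   # reshape (+ optional permute) before the all2all
output = input_t.clone()                           # stand-in for dist.all_to_all_single (shape-preserving)

# Async path: the added view gives the caller a correctly shaped tensor immediately ...
early = output.view(post_shape)
print(early.shape)                                 # torch.Size([4, 4, 12, 16])

# ... while the stored post_all2all function performs the real permute + reshape once the
# communication has completed.
res = post_all2all(post_idx, post_shape)(output)
print(res.shape)                                   # torch.Size([4, 4, 12, 16])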

---------

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Masahiro Tanaka <[email protected]>
3 people authored Jan 21, 2025
1 parent 7f3d669 commit bc76b04
Showing 1 changed file with 70 additions and 59 deletions.
129 changes: 70 additions & 59 deletions deepspeed/sequence/layer.py
@@ -16,6 +16,71 @@
from deepspeed.utils import groups


def _generate_layout_params(scatter_idx, batch_dim_idx, seq_world_size, input):
"""
This function generates the parameters required for `permute` and `reshape` operations,
which are used to process data before and after `all2all` communication.
"""
if batch_dim_idx == 0:
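        # b, s, n, h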
        if scatter_idx < 2:
            bs, global_seq_len, num_local_head, head_dim = input.shape
            pre_all2all_inp_shape = [bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim]
            pre_all2all_permute_idx = (1, 0, 2, 3, 4)

            post_all2all_permute_idx = (1, 2, 0, 3, 4)
            post_all2all_res_shape = [bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim]
        else:
            bs, local_seq_len, num_total_head, head_dim = input.shape
            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
            pre_all2all_inp_shape = [bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim]
            pre_all2all_permute_idx = (2, 0, 1, 3, 4)

            post_all2all_permute_idx = (1, 0, 2, 3, 4)
            post_all2all_res_shape = [bs, seq_world_size * local_seq_len, num_total_head // seq_world_size, head_dim]
    else:
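        # s, b, n, h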
        if scatter_idx < 2:
            global_seq_len, bs, num_local_head, head_dim = input.shape
            pre_all2all_inp_shape = [seq_world_size, global_seq_len // seq_world_size, bs, num_local_head, head_dim]
            pre_all2all_permute_idx = None

            post_all2all_permute_idx = (1, 2, 0, 3, 4)
            post_all2all_res_shape = [bs, seq_world_size * global_seq_len, num_local_head // seq_world_size, head_dim]
        else:
            local_seq_len, bs, num_total_head, head_dim = input.shape
            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
            pre_all2all_inp_shape = [local_seq_len, bs, seq_world_size, num_total_head // seq_world_size, head_dim]
            pre_all2all_permute_idx = (2, 0, 1, 3, 4)
            post_all2all_permute_idx = None
            post_all2all_res_shape = [local_seq_len * seq_world_size, bs, num_total_head // seq_world_size, head_dim]

    return pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape


def post_all2all(permute_idx, res_shape):
"""
Post-processing function for `all2all` communication.
"""

def post_func(input):
if permute_idx is not None:
input = input.permute(permute_idx).contiguous()
output = input.reshape(res_shape).contiguous()

return output

return post_func


def pre_all2all_fun(permute_idx, inp_shape, input):
"""
Pre-processing function for `all2all` communication.
"""
input_t = input.reshape(inp_shape).contiguous()
if permute_idx is not None:
input_t = input_t.permute(permute_idx).contiguous()
return input_t


def _rotate_half(x):
"""
change sign so the last dimension becomes [-odd, +even]
@@ -43,32 +108,6 @@ def apply_rotary_pos_emb(t, freqs_cos, freqs_sin):
    return res


def post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_head, head_dim):

    def post_func(input):
        if batch_dim_idx == 0:
            # b, s, n, h
            if scatter_idx < 2:
                output = input.permute(1, 2, 0, 3, 4).contiguous()
                output = output.reshape(bs, seq_len // seq_world_size, seq_world_size * num_head,
                                        head_dim).contiguous()
            else:
                output = input.permute(1, 0, 2, 3, 4).contiguous()
                output = output.reshape(bs, seq_world_size * seq_len, num_head // seq_world_size,
                                        head_dim).contiguous()
        else:
            # s, b, n, h
            if scatter_idx < 2:
                output = input.permute(1, 2, 0, 3, 4).contiguous()
                output = output.reshape(seq_len // seq_world_size, bs, seq_world_size * num_head,
                                        head_dim).contiguous()
            else:
                output = input.reshape(seq_len * seq_world_size, bs, num_head // seq_world_size, head_dim).contiguous()
        return output

    return post_func


def uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group):
    seq_world_size = dist.get_world_size(group)
    inp_shape = list(input.shape)
Expand Down Expand Up @@ -195,39 +234,12 @@ def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, asyn
        assert async_op == False, "uneven head sp does not support async op"
        return uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group)

    if batch_dim_idx == 0:
        # b, s, n, h
        if scatter_idx < 2:
            bs, global_seq_len, num_local_head, head_dim = input.shape
            input_t = input.reshape([bs, seq_world_size, global_seq_len // seq_world_size, num_local_head,
                                     head_dim]).contiguous()
            input_t = input_t.permute(1, 0, 2, 3, 4).contiguous()
        else:
            bs, local_seq_len, num_total_head, head_dim = input.shape
            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
            input_t = input.reshape([bs, local_seq_len, seq_world_size, num_total_head // seq_world_size,
                                     head_dim]).contiguous()
            input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()
    else:
        # s, b, n, h
        if scatter_idx < 2:
            global_seq_len, bs, num_local_head, head_dim = input.shape
            input_t = input.reshape([seq_world_size, global_seq_len // seq_world_size, bs, num_local_head,
                                     head_dim]).contiguous()
        else:
            local_seq_len, bs, num_total_head, head_dim = input.shape
            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
            input_t = input.reshape([local_seq_len, bs, seq_world_size, num_total_head // seq_world_size,
                                     head_dim]).contiguous()
            input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()
    pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape = _generate_layout_params(
        scatter_idx, batch_dim_idx, seq_world_size, input)

    if scatter_idx < 2:
        post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, global_seq_len, num_local_head,
                                        head_dim)
    else:
        post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, local_seq_len, num_total_head,
                                        head_dim)
    input_t = pre_all2all_fun(pre_all2all_permute_idx, pre_all2all_inp_shape, input)

    post_all2all_fun = post_all2all(post_all2all_permute_idx, post_all2all_res_shape)
    output = torch.empty_like(input_t)
    work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op)

@@ -236,7 +248,7 @@ def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, asyn
            handle[type + '_work'] = work
            handle[type + '_grad'] = output
            handle[type + '_post_all2all_func'] = post_all2all_fun
            return output
            return output.view(post_all2all_res_shape)

    res = post_all2all_fun(output)
    return res
@@ -271,7 +283,6 @@ def forward(ctx: Any,
            assert ctx.stream != None
            res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)
            get_accelerator().current_stream().wait_stream(ctx.stream)
            del ctx.stream.activation_buffer_list
            # The computation of d o_weight can overlap with the communication of d o_input

        elif not is_fwd and type in ('q', 'k'):
