diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index 3dfe149057e9..ccc43fdf7164 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -184,7 +184,7 @@ def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, asyn # we only need num_heads once num_heads = input.shape[2] - if get_num_kv_heads() is not None or num_heads % seq_world_size != 0: + if get_num_kv_heads() is not None or (num_heads % seq_world_size != 0 and not scatter_idx < 2): # Assuming here that the number of heads for q is consistent with kv # If not, additional logic is required for cases like GQA if get_num_kv_heads() is None: