Merge branch 'main' into jax/test/multiprocess_encoder
phu0ngng authored Jan 11, 2025
2 parents e092780 + 7b861e7 commit a7202ff
Showing 6 changed files with 491 additions and 468 deletions.
26 changes: 10 additions & 16 deletions tests/pytorch/fused_attn/run_fused_attn_with_cp.py
@@ -163,12 +163,10 @@ def run_dpa_with_cp(
                 torch.tensor([q_input_shape[0]], dtype=torch.int32),
             ]
         ).cuda()
-        if kernel_backend == "FlashAttention":
-            cu_seqlens_q = cu_seqlens_q_padded[:-1]
-        else:
-            cu_seqlens_q = torch.cat(
-                [torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0, dtype=torch.int32)]
-            ).cuda()
+        cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
+        if kernel_backend == "FusedAttention":
+            cu_seqlens_q[1:-1] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()
+            cu_seqlens_q[-1] = cu_seqlens_q[-2]
         cu_seqlens_kv = cu_seqlens_q
         cu_seqlens_kv_padded = cu_seqlens_q_padded
     else:
@@ -204,10 +202,8 @@ def run_dpa_with_cp(
             core_attention_bias=bias,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_kv=cu_seqlens_kv,
-            cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1],
-            cu_seqlens_kv_padded=(
-                None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
-            ),
+            cu_seqlens_q_padded=cu_seqlens_q_padded,
+            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
         )
         if fp8_mha:
             dout_fp8 = Float8Tensor.to_float8(dout, fp8_dtype=tex.DType.kFloat8E5M2)
@@ -276,10 +272,8 @@ def run_dpa_with_cp(
             core_attention_bias=bias_,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_kv=cu_seqlens_kv,
-            cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1],
-            cu_seqlens_kv_padded=(
-                None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
-            ),
+            cu_seqlens_q_padded=cu_seqlens_q_padded,
+            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
         )
         if fp8_mha:
             dout_fp8_ = Float8Tensor.to_float8(dout_, fp8_dtype=tex.DType.kFloat8E5M2)
@@ -311,7 +305,7 @@ def run_dpa_with_cp(
         dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]]
         dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]]
         dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_]
-        cu_seqlens_q_padded = cu_seqlens_q_padded[:-1] // world_size
+        cu_seqlens_q_padded = cu_seqlens_q_padded // world_size
         cu_seqlens_q = get_cu_seqlens_on_cp_rank(
             cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
         )
@@ -327,7 +321,7 @@ def run_dpa_with_cp(
             ).item()
             == 0
         )
-        cu_seqlens_kv_padded = cu_seqlens_kv_padded[:-1] // world_size
+        cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size
         cu_seqlens_kv = get_cu_seqlens_on_cp_rank(
             cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True
         )
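The key change in this file is how cu_seqlens_q is built for the THD format: instead of dropping the final entry of cu_seqlens_q_padded for the FlashAttention path and recomputing a fresh cumulative sum for FusedAttention, the test now clones the full padded tensor and, for FusedAttention, overwrites the interior entries with the unpadded cumulative lengths and duplicates the final real-token offset. The padded tensors are then passed to core_attn whole rather than truncated with [:-1]. Below is a minimal CPU-only sketch of that construction with made-up sequence lengths (the .cuda() calls from the test are dropped so it runs anywhere):

import torch

# Hypothetical per-sequence lengths; the padded lengths are what the THD buffer uses.
seqlens_q = torch.tensor([13, 7, 24], dtype=torch.int32)
seqlens_q_padded = torch.tensor([16, 8, 24], dtype=torch.int32)
total_tokens = 64  # stands in for q_input_shape[0]; includes trailing padding

# Padded cumulative offsets end with the full buffer size, as in the test.
cu_seqlens_q_padded = torch.cat(
    [
        torch.zeros([1], dtype=torch.int32),
        seqlens_q_padded.cumsum(0, dtype=torch.int32),
        torch.tensor([total_tokens], dtype=torch.int32),
    ]
)

# New logic from the diff (FusedAttention branch): start from the padded offsets,
# overwrite the interior entries with the unpadded cumulative lengths, and clamp
# the final entry so it does not point past the last real token.
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
cu_seqlens_q[1:-1] = seqlens_q.cumsum(0, dtype=torch.int32)
cu_seqlens_q[-1] = cu_seqlens_q[-2]

print(cu_seqlens_q_padded.tolist())  # [0, 16, 24, 48, 64]
print(cu_seqlens_q.tolist())         # [0, 13, 20, 44, 44]

The same shape convention carries through the later change: cu_seqlens_q_padded // world_size is now applied to the full tensor (no [:-1]) before it is handed to get_cu_seqlens_on_cp_rank.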
10 changes: 2 additions & 8 deletions tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -121,22 +121,14 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")
if dtype == "fp8" and get_device_compute_capability() < (9, 0):
pytest.skip("FP8 attention is only supported on sm90+!")
if qkv_format == "thd" and get_cudnn_version() >= (9, 6, 0):
pytest.skip("THD format is not supported for cuDNN 9.6+!")

config = model_configs_fused_attn[model]
if qkv_format == "thd" and config.num_heads != config.num_gqa_groups:
pytest.skip("THD format does not support QGA/MQA yet!")
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd" and cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if qkv_format == "thd" and "a2a" in cp_comm_type:
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a":
pytest.skip(
"Sliding window attention only can be supported with the implementation of QKVO A2A!"
)
if dtype == "fp8" and cp_comm_type == "all_gather":
pytest.skip(
"CP implementation with KV all-gather does not support FP8 + context parallelism yet!"
@@ -147,6 +139,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("FP8 attention cannot work with bias yet!")
if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("FP8 attention cannot work with sliding window yet!")
if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias":
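For orientation, the window_size checks in this test follow the usual flash-attention convention: (-1, 0) is plain causal attention and (-1, -1) is full bidirectional attention, so any other pair is treated as sliding-window attention. Below is a small, hypothetical helper (not part of the test file; is_sliding_window and should_skip_swa are illustrative names only) that mirrors the newly added p2p gate:

def is_sliding_window(window_size):
    # Anything other than "causal, no window" (-1, 0) or
    # "bidirectional, no window" (-1, -1) counts as sliding-window attention.
    return window_size != (-1, 0) and window_size != (-1, -1)

def should_skip_swa(cp_comm_type, window_size):
    # Hypothetical condensation of the skip added in this commit: the KV P2P (ring)
    # implementation, including the hierarchical a2a+p2p variant, is gated off.
    return "p2p" in cp_comm_type and is_sliding_window(window_size)

print(should_skip_swa("p2p", (512, 0)))      # True  -> test skipped
print(should_skip_swa("a2a+p2p", (512, 0)))  # True  -> test skipped
print(should_skip_swa("a2a", (512, 0)))      # False -> test runs
print(should_skip_swa("p2p", (-1, 0)))       # False -> causal, no sliding window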
17 changes: 9 additions & 8 deletions transformer_engine/common/fused_attn/thd_utils.h
@@ -69,7 +69,7 @@ struct ReadLseFunctor {
 
 template <typename lse_dtype, bool lse_packed, typename Functor>
 __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch,
-                               int num_heads, int total_tokens) {
+                               int num_heads, int lse_seqlen, int second_half_lse_seqlen) {
   extern __shared__ int cu_seqlens_s[];
   for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
     cu_seqlens_s[i] = cu_seqlens[i] / 2;
@@ -85,15 +85,15 @@ __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens,
     for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) {
       size_t idx, half_idx;
       if constexpr (lse_packed) {
-        idx = head_id * total_tokens + token_id + cu_seqlens_s[seq_id + 1];
-        half_idx = head_id * total_tokens / 2 + token_id;
+        idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1];
+        half_idx = head_id * second_half_lse_seqlen + token_id;
       } else {
         size_t row = static_cast<size_t>(seq_id) * num_heads + head_id;
         int col = token_id - cu_seqlens_s[seq_id];
         int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
 
-        idx = row * total_tokens + col + seq_len;
-        half_idx = row * total_tokens / 2 + col;
+        idx = row * lse_seqlen + col + seq_len;
+        half_idx = row * second_half_lse_seqlen + col;
       }
 
       Functor::run(lse, half_lse, idx, half_idx);
@@ -108,7 +108,8 @@ __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens,
 template <typename dtype, int only_second_half, int tile_size, bool lse_packed>
 __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse,
                                           float *lse_per_step, int *cu_seqlens, int batch,
-                                          int num_heads, int dim_per_head, int lse_seqlen) {
+                                          int num_heads, int dim_per_head, int lse_seqlen,
+                                          int lse_per_step_seqlen) {
   extern __shared__ int cu_seqlens_s[];
   for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
     cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1);
@@ -128,13 +129,13 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float
 
       if constexpr (lse_packed) {
         idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
-        idx_per_step = head_id * lse_seqlen / (only_second_half + 1) + token_id;
+        idx_per_step = head_id * lse_per_step_seqlen + token_id;
       } else {
         size_t row = static_cast<size_t>(seq_id) * num_heads + head_id;
         int col = token_id - cu_seqlens_s[seq_id];
         int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
         idx = row * lse_seqlen + col + seq_len * only_second_half;
-        idx_per_step = row * lse_seqlen / (only_second_half + 1) + col;
+        idx_per_step = row * lse_per_step_seqlen + col;
       }
       float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]);
 
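The thd_utils.h change replaces the built-in assumption that the second-half (per-step) LSE buffer is exactly the full LSE length divided by 2 (or by only_second_half + 1) with explicit second_half_lse_seqlen and lse_per_step_seqlen parameters, so the two buffers can be sized independently. Below is a rough Python transcription of the packed branch of thd_lse_kernel's index math with made-up sizes; it is a sanity-check sketch, not the kernel itself:

num_heads = 2
lse_seqlen = 16              # token capacity of the full LSE tensor
second_half_lse_seqlen = 8   # token capacity of the half-sequence LSE tensor
cu_seqlens = [0, 6, 16]      # cumulative sequence offsets for a 2-sequence batch
cu_seqlens_s = [c // 2 for c in cu_seqlens]  # halved offsets, as in shared memory

def packed_lse_indices(head_id, token_id, seq_id):
    # Index into the full LSE: skip over the first half of this sequence.
    idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1]
    # Index into the half LSE. The old code derived the stride as
    # head_id * total_tokens / 2; the new second_half_lse_seqlen parameter lets
    # the half buffer have a length that is not tied to the full buffer.
    half_idx = head_id * second_half_lse_seqlen + token_id
    return idx, half_idx

print(packed_lse_indices(head_id=1, token_id=2, seq_id=0))  # (21, 10)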