diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py
index 4fa2cc988a19..588462ffcc6f 100644
--- a/deepspeed/sequence/fpdt_layer.py
+++ b/deepspeed/sequence/fpdt_layer.py
@@ -19,6 +19,12 @@
 _flash_attn_forward = None
 _flash_attn_backward = None
 
+if get_accelerator().device_name() == 'xpu':
+    FlashAttentionBuilder = get_accelerator().get_op_builder("FlashAttentionBuilder")
+    flash_attn_builder = FlashAttentionBuilder().load()
+    flash_attn_func_fwd = flash_attn_builder.flash_attn_fwd
+    flash_attn_func_bwd = flash_attn_builder.flash_attn_bwd
+
 from einops import rearrange
 from .layer import single_all_to_all, apply_rotary_pos_emb
 
@@ -231,27 +237,41 @@ def forward(ctx: Any,
             for k_i in range(len(global_k)):
                 causal_chunk = i == k_i
 
-                if flash_attn_version >= version.parse("2.6.0"):
-                    block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(global_q[i],
-                                                                                 global_k[k_i],
-                                                                                 global_v[k_i],
-                                                                                 ctx.dropout_p,
-                                                                                 ctx.softmax_scale,
-                                                                                 causal=causal_chunk,
-                                                                                 window_size=ctx.window_size,
-                                                                                 softcap=0.0,
-                                                                                 alibi_slopes=ctx.alibi_slopes,
-                                                                                 return_softmax=False)
-                else:
-                    block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(global_q[i],
-                                                                                 global_k[k_i],
-                                                                                 global_v[k_i],
-                                                                                 ctx.dropout_p,
-                                                                                 ctx.softmax_scale,
-                                                                                 causal=causal_chunk,
-                                                                                 window_size=ctx.window_size,
-                                                                                 alibi_slopes=ctx.alibi_slopes,
-                                                                                 return_softmax=False)
+                if get_accelerator().device_name() == 'xpu':
+                    # transpose input from (b, l, n, d) to (b, n, l, d)
+                    block_out, block_lse, _, _ = flash_attn_func_fwd(
+                        global_q[i].transpose(1, 2),
+                        global_k[k_i].transpose(1, 2),
+                        global_v[k_i].transpose(1, 2),
+                        bias=None,
+                        dropout_p=ctx.dropout_p,
+                        is_causual=causal_chunk,
+                        softmax_scale=ctx.softmax_scale
+                    )
+                    # transpose output back from (b, n, l, d) to (b, l, n, d)
+                    block_out = block_out.transpose(1, 2)
+                else:
+                    if flash_attn_version >= version.parse("2.6.0"):
+                        block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(global_q[i],
+                                                                                     global_k[k_i],
+                                                                                     global_v[k_i],
+                                                                                     ctx.dropout_p,
+                                                                                     ctx.softmax_scale,
+                                                                                     causal=causal_chunk,
+                                                                                     window_size=ctx.window_size,
+                                                                                     softcap=0.0,
+                                                                                     alibi_slopes=ctx.alibi_slopes,
+                                                                                     return_softmax=False)
+                    else:
+                        block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(global_q[i],
+                                                                                     global_k[k_i],
+                                                                                     global_v[k_i],
+                                                                                     ctx.dropout_p,
+                                                                                     ctx.softmax_scale,
+                                                                                     causal=causal_chunk,
+                                                                                     window_size=ctx.window_size,
+                                                                                     alibi_slopes=ctx.alibi_slopes,
+                                                                                     return_softmax=False)
 
                 global_o[i], global_lse[i] = update_out_and_lse(global_o[i], global_lse[i], block_out, block_lse)
 
@@ -356,41 +376,61 @@ def backward(ctx, grad_output):
                 dk_this = torch.zeros(global_k[0].shape, dtype=dtype, device=device)
                 dv_this = torch.zeros(global_v[0].shape, dtype=dtype, device=device)
 
-                if flash_attn_version >= version.parse("2.6.0"):
-                    _flash_attn_backward(d_out,
-                                         q_chunk,
-                                         k_chunk,
-                                         v_chunk,
-                                         attn_output_chunk,
-                                         lse_chunk,
-                                         dq_this,
-                                         dk_this,
-                                         dv_this,
-                                         dropout_p,
-                                         softmax_scale,
-                                         causal_chunk,
-                                         window_size,
-                                         softcap=0.0,
-                                         alibi_slopes=alibi_slopes,
-                                         deterministic=False,
-                                         rng_state=None)
-                else:
-                    _flash_attn_backward(d_out,
-                                         q_chunk,
-                                         k_chunk,
-                                         v_chunk,
-                                         attn_output_chunk,
-                                         lse_chunk,
-                                         dq_this,
-                                         dk_this,
-                                         dv_this,
-                                         dropout_p,
-                                         softmax_scale,
-                                         causal_chunk,
-                                         window_size,
-                                         alibi_slopes=alibi_slopes,
-                                         deterministic=False,
-                                         rng_state=None)
+                if get_accelerator().device_name() == 'xpu':
+                    # transpose input from (b, l, n, d) to (b, n, l, d)
+                    dq_this, dk_this, dv_this, _ = flash_attn_func_bwd(
+                        attn_output_chunk.transpose(1, 2),
+                        d_out.transpose(1, 2),
+                        q_chunk.transpose(1, 2),
+                        k_chunk.transpose(1, 2),
+                        v_chunk.transpose(1, 2),
+                        None,
+                        lse_chunk,
+                        None,
+                        None,
+                        dropout_p,
+                        False,
+                        causal_chunk,
+                        softmax_scale
+                    )
+                    # transpose gradients back from (b, n, l, d) to (b, l, n, d)
+                    dq_this, dk_this, dv_this = [x.transpose(1, 2) for x in [dq_this, dk_this, dv_this]]
+                else:
+                    if flash_attn_version >= version.parse("2.6.0"):
+                        _flash_attn_backward(d_out,
+                                             q_chunk,
+                                             k_chunk,
+                                             v_chunk,
+                                             attn_output_chunk,
+                                             lse_chunk,
+                                             dq_this,
+                                             dk_this,
+                                             dv_this,
+                                             dropout_p,
+                                             softmax_scale,
+                                             causal_chunk,
+                                             window_size,
+                                             softcap=0.0,
+                                             alibi_slopes=alibi_slopes,
+                                             deterministic=False,
+                                             rng_state=None)
+                    else:
+                        _flash_attn_backward(d_out,
+                                             q_chunk,
+                                             k_chunk,
+                                             v_chunk,
+                                             attn_output_chunk,
+                                             lse_chunk,
+                                             dq_this,
+                                             dk_this,
+                                             dv_this,
+                                             dropout_p,
+                                             softmax_scale,
+                                             causal_chunk,
+                                             window_size,
+                                             alibi_slopes=alibi_slopes,
+                                             deterministic=False,
+                                             rng_state=None)
 
                 dq[q_i].add_(dq_this.to(torch.float))
                 dk[i].add_(dk_this.to(torch.float))
@@ -629,29 +669,43 @@ def forward(ctx: Any,
             for k_i in range(len(global_k)):
                 causal_chunk = i == k_i
                 with get_accelerator().stream(compute_stream):
-                    if flash_attn_version >= version.parse("2.6.0"):
-                        block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
-                            global_q[q_compute_chunk_idx].get_gpu_chunk(),
-                            global_k[kv_compute_chunk_idx].get_gpu_chunk(),
-                            global_v[kv_compute_chunk_idx].get_gpu_chunk(),
-                            ctx.dropout_p,
-                            ctx.softmax_scale,
-                            causal=causal_chunk,
-                            window_size=ctx.window_size,
-                            softcap=0.0,
-                            alibi_slopes=ctx.alibi_slopes,
-                            return_softmax=False)
-                    else:
-                        block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
-                            global_q[q_compute_chunk_idx].get_gpu_chunk(),
-                            global_k[kv_compute_chunk_idx].get_gpu_chunk(),
-                            global_v[kv_compute_chunk_idx].get_gpu_chunk(),
-                            ctx.dropout_p,
-                            ctx.softmax_scale,
-                            causal=causal_chunk,
-                            window_size=ctx.window_size,
-                            alibi_slopes=ctx.alibi_slopes,
-                            return_softmax=False)
+                    if get_accelerator().device_name() == 'xpu':
+                        # transpose input from (b, l, n, d) to (b, n, l, d)
+                        block_out, block_lse, _, _ = flash_attn_func_fwd(
+                            global_q[q_compute_chunk_idx].get_gpu_chunk().transpose(1, 2),
+                            global_k[kv_compute_chunk_idx].get_gpu_chunk().transpose(1, 2),
+                            global_v[kv_compute_chunk_idx].get_gpu_chunk().transpose(1, 2),
+                            bias=None,
+                            dropout_p=ctx.dropout_p,
+                            is_causual=causal_chunk,
+                            softmax_scale=ctx.softmax_scale
+                        )
+                        # transpose output back from (b, n, l, d) to (b, l, n, d)
+                        block_out = block_out.transpose(1, 2)
+                    else:
+                        if flash_attn_version >= version.parse("2.6.0"):
+                            block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
+                                global_q[q_compute_chunk_idx].get_gpu_chunk(),
+                                global_k[kv_compute_chunk_idx].get_gpu_chunk(),
+                                global_v[kv_compute_chunk_idx].get_gpu_chunk(),
+                                ctx.dropout_p,
+                                ctx.softmax_scale,
+                                causal=causal_chunk,
+                                window_size=ctx.window_size,
+                                softcap=0.0,
+                                alibi_slopes=ctx.alibi_slopes,
+                                return_softmax=False)
+                        else:
+                            block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
+                                global_q[q_compute_chunk_idx].get_gpu_chunk(),
+                                global_k[kv_compute_chunk_idx].get_gpu_chunk(),
+                                global_v[kv_compute_chunk_idx].get_gpu_chunk(),
+                                ctx.dropout_p,
+                                ctx.softmax_scale,
+                                causal=causal_chunk,
+                                window_size=ctx.window_size,
+                                alibi_slopes=ctx.alibi_slopes,
+                                return_softmax=False)
 
                     cur_attn_output, cur_attn_lse = update_out_and_lse(cur_attn_output, cur_attn_lse, block_out, block_lse)
 
@@ -795,46 +849,66 @@ def backward(ctx, grad_output):
 
                 causal_chunk = q_i == i
 
-                dq_this = torch.zeros(global_q[0].chunk_shape, dtype=dtype, device=device)
-                dk_this = torch.zeros(global_k[0].chunk_shape, dtype=dtype, device=device)
-                dv_this = torch.zeros(global_v[0].chunk_shape, dtype=dtype, device=device)
-
+                if get_accelerator().device_name() != 'xpu':
+                    dq_this = torch.zeros(global_q[0].chunk_shape, dtype=dtype, device=device)
+                    dk_this = torch.zeros(global_k[0].chunk_shape, dtype=dtype, device=device)
+                    dv_this = torch.zeros(global_v[0].chunk_shape, dtype=dtype, device=device)
                 with get_accelerator().stream(compute_stream):
-                    if flash_attn_version >= version.parse("2.6.0"):
-                        _flash_attn_backward(grad_global_attn_output[q_compute_chunk_idx].get_gpu_chunk(),
-                                             global_q[q_compute_chunk_idx].get_gpu_chunk(),
-                                             global_k[kv_compute_chunk_idx].get_gpu_chunk(),
-                                             global_v[kv_compute_chunk_idx].get_gpu_chunk(),
-                                             attn_output[q_compute_chunk_idx].get_gpu_chunk(),
-                                             lse[q_compute_chunk_idx].get_gpu_chunk(),
-                                             dq_this,
-                                             dk_this,
-                                             dv_this,
-                                             dropout_p,
-                                             softmax_scale,
-                                             causal_chunk,
-                                             window_size,
-                                             softcap=0.0,
-                                             alibi_slopes=alibi_slopes,
-                                             deterministic=False,
-                                             rng_state=None)
-                    else:
-                        _flash_attn_backward(grad_global_attn_output[q_compute_chunk_idx].get_gpu_chunk(),
-                                             global_q[q_compute_chunk_idx].get_gpu_chunk(),
-                                             global_k[kv_compute_chunk_idx].get_gpu_chunk(),
-                                             global_v[kv_compute_chunk_idx].get_gpu_chunk(),
-                                             attn_output[q_compute_chunk_idx].get_gpu_chunk(),
-                                             lse[q_compute_chunk_idx].get_gpu_chunk(),
-                                             dq_this,
-                                             dk_this,
-                                             dv_this,
-                                             dropout_p,
-                                             softmax_scale,
-                                             causal_chunk,
-                                             window_size,
-                                             alibi_slopes=alibi_slopes,
-                                             deterministic=False,
-                                             rng_state=None)
+                    if get_accelerator().device_name() == 'xpu':
+                        # transpose input from (b, l, n, d) to (b, n, l, d)
+                        dq_this, dk_this, dv_this, _ = flash_attn_func_bwd(
+                            attn_output[q_compute_chunk_idx].get_gpu_chunk().transpose(1, 2),
+                            grad_global_attn_output[q_compute_chunk_idx].get_gpu_chunk().transpose(1, 2),
+                            global_q[q_compute_chunk_idx].get_gpu_chunk().transpose(1, 2),
+                            global_k[kv_compute_chunk_idx].get_gpu_chunk().transpose(1, 2),
+                            global_v[kv_compute_chunk_idx].get_gpu_chunk().transpose(1, 2),
+                            None,
+                            lse[q_compute_chunk_idx].get_gpu_chunk(),
+                            None,
+                            None,
+                            dropout_p,
+                            False,
+                            causal_chunk,
+                            softmax_scale
+                        )
+                        # transpose gradients back from (b, n, l, d) to (b, l, n, d)
+                        dq_this, dk_this, dv_this = [x.transpose(1, 2) for x in [dq_this, dk_this, dv_this]]
+                    else:
+                        if flash_attn_version >= version.parse("2.6.0"):
+                            _flash_attn_backward(grad_global_attn_output[q_compute_chunk_idx].get_gpu_chunk(),
+                                                 global_q[q_compute_chunk_idx].get_gpu_chunk(),
+                                                 global_k[kv_compute_chunk_idx].get_gpu_chunk(),
+                                                 global_v[kv_compute_chunk_idx].get_gpu_chunk(),
+                                                 attn_output[q_compute_chunk_idx].get_gpu_chunk(),
+                                                 lse[q_compute_chunk_idx].get_gpu_chunk(),
+                                                 dq_this,
+                                                 dk_this,
+                                                 dv_this,
+                                                 dropout_p,
+                                                 softmax_scale,
+                                                 causal_chunk,
+                                                 window_size,
+                                                 softcap=0.0,
+                                                 alibi_slopes=alibi_slopes,
+                                                 deterministic=False,
+                                                 rng_state=None)
+                        else:
+                            _flash_attn_backward(grad_global_attn_output[q_compute_chunk_idx].get_gpu_chunk(),
+                                                 global_q[q_compute_chunk_idx].get_gpu_chunk(),
+                                                 global_k[kv_compute_chunk_idx].get_gpu_chunk(),
+                                                 global_v[kv_compute_chunk_idx].get_gpu_chunk(),
+                                                 attn_output[q_compute_chunk_idx].get_gpu_chunk(),
+                                                 lse[q_compute_chunk_idx].get_gpu_chunk(),
+                                                 dq_this,
+                                                 dk_this,
+                                                 dv_this,
+                                                 dropout_p,
+                                                 softmax_scale,
+                                                 causal_chunk,
+                                                 window_size,
+                                                 alibi_slopes=alibi_slopes,
+                                                 deterministic=False,
+                                                 rng_state=None)
 
                 if i != (len(global_k) - 1):
                     if q_i != (len(global_q) - 1):
diff --git a/op_builder/xpu/flash_attn.py b/op_builder/xpu/flash_attn.py
index c8c2674d5d27..6718bc701e3e 100644
--- a/op_builder/xpu/flash_attn.py
+++ b/op_builder/xpu/flash_attn.py
@@ -26,6 +26,25 @@ def flash_attn_func_v2(self, q, k, v, dropout_p, softmax_scale, is_causal):
                 "Please install pytorch and intel_extension_for_pytorch to include scaled dot product attention.")
 
+    def flash_attn_fwd(self, q, k, v, bias=None, dropout_p=0.0, is_causual=False, softmax_scale=None):
+        try:
+            import torch
+            import intel_extension_for_pytorch  # noqa
+            return torch.xpu.IpexSDP_forward(q, k, v, bias, dropout_p, is_causual, softmax_scale)
+        except ImportError:
+            raise ImportError(
+                "Please install pytorch and intel_extension_for_pytorch to include scaled dot product attention.")
+
+    def flash_attn_bwd(self, out, out_grad, q, k, v, bias, logsumexp, seed, offset, dropout_p, is_bais_grad, is_causal, softmax_scale):
+        try:
+            import torch
+            import intel_extension_for_pytorch  # noqa
+            return torch.xpu.IpexSDP_backward(out, out_grad, q, k, v, bias, logsumexp, seed, offset, dropout_p, is_bais_grad, is_causal, softmax_scale)
+        except ImportError:
+            raise ImportError(
+                "Please install pytorch and intel_extension_for_pytorch to include scaled dot product attention.")
+
+
 class FlashAttentionBuilder(SYCLOpBuilder):
     BUILD_VAR = "DS_BUILD_FlashAttention"
     NAME = "flash_attn"
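For reference, the layout handling that the XPU branches above rely on can be illustrated with a small, self-contained sketch. It uses torch.nn.functional.scaled_dot_product_attention as a stand-in for the IPEX flash_attn_fwd kernel (an assumption for illustration only; the real kernel additionally returns the log-sum-exp tensor consumed by update_out_and_lse) and shows the transpose from the FPDT layout (batch, seqlen, heads, head_dim) into the kernel layout (batch, heads, seqlen, head_dim) and back.

# Minimal sketch of the (b, l, n, d) <-> (b, n, l, d) layout convention used by the
# XPU branch. F.scaled_dot_product_attention stands in for the IPEX kernel; this is
# an assumption for illustration, not the DeepSpeed/IPEX implementation.
import torch
import torch.nn.functional as F


def attn_blnd(q, k, v, is_causal=False, softmax_scale=None):
    # FPDT keeps activations as (batch, seqlen, heads, head_dim); the kernel expects
    # (batch, heads, seqlen, head_dim), so transpose on the way in ...
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal, scale=softmax_scale)
    # ... and transpose the output back to (batch, seqlen, heads, head_dim).
    return out.transpose(1, 2)


if __name__ == "__main__":
    b, l, n, d = 2, 16, 4, 32
    q, k, v = (torch.randn(b, l, n, d) for _ in range(3))
    out = attn_blnd(q, k, v, is_causal=True)
    assert out.shape == (b, l, n, d)

The backward path follows the same convention: gradients come back from the kernel in (batch, heads, seqlen, head_dim) and are transposed before being accumulated into dq, dk, and dv.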