diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index cc2b54160..6dd30f01e 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -294,11 +294,16 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
                      softmax_scale,
                      is_causal);
 
+    // number of times random will be generated per thread, to offset philox counter in thc random
+    // state
+    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
+    int64_t counter_offset = params.b * params.h * 32;
+    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    // Forward kernel will populate memory with the seed and offset.
+    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+
     if (p_dropout > 0.0) {
-        // number of times random will be generated per thread, to offset philox counter in thc random
-        // state
-        // We use a custom RNG that increases the offset by batch_size * nheads * 32.
-        int64_t counter_offset = params.b * params.h * 32;
         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
             gen_, at::cuda::detail::getDefaultCUDAGenerator());
         // See Note [Acquire lock when using random generators]
@@ -315,7 +320,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
         if (out_.has_value()) { out_.value().copy_(out); }
     }
 
-    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
+    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
 }
 
 std::vector<at::Tensor>
@@ -448,11 +453,16 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                      softmax_scale,
                      is_causal);
 
+    // number of times random will be generated per thread, to offset philox counter in thc random
+    // state
+    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
+    int64_t counter_offset = params.b * params.h * 32;
+    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    // Forward kernel will populate memory with the seed and offset.
+    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+
     if (p_dropout > 0.0) {
-        // number of times random will be generated per thread, to offset philox counter in thc random
-        // state
-        // We use a custom RNG that increases the offset by batch_size * nheads * 32.
-        int64_t counter_offset = params.b * params.h * 32;
         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
             gen_, at::cuda::detail::getDefaultCUDAGenerator());
         // See Note [Acquire lock when using random generators]
@@ -469,7 +479,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
         if (out_.has_value()) { out_.value().copy_(out); }
     }
 
-    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
+    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
 }
 
 void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
@@ -507,7 +517,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
         const float p_dropout, // probability to drop
         const float softmax_scale,
         const bool is_causal,
-        c10::optional<at::Generator> gen_) {
+        c10::optional<at::Generator> gen_,
+        c10::optional<at::Tensor> &rng_state) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
@@ -669,10 +680,15 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
     int64_t counter_offset = params.b * params.h * 32;
 
-    if (is_dropout) {
+    if ( rng_state.has_value() ) {
+        params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
+    } else if( is_dropout ) {
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         params.philox_args = gen->philox_cuda_state(counter_offset);
+        auto seeds = at::cuda::philox::unpack(params.philox_args);
+        params.rng_state[0] = std::get<0>(seeds);
+        params.rng_state[1] = std::get<1>(seeds);
     }
 
     launch(params, stream, /*configure=*/false);
@@ -709,7 +725,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                const float softmax_scale,
                const bool zero_tensors,
                const bool is_causal,
-               c10::optional<at::Generator> gen_
+               c10::optional<at::Generator> gen_,
+               c10::optional<at::Tensor> &rng_state
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
@@ -881,10 +898,15 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
     int64_t counter_offset = params.b * params.h * 32;
 
-    if (is_dropout) {
+    if ( rng_state.has_value() ) {
+        params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
+    } else if( is_dropout ) {
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         params.philox_args = gen->philox_cuda_state(counter_offset);
+        auto seeds = at::cuda::philox::unpack(params.philox_args);
+        params.rng_state[0] = std::get<0>(seeds);
+        params.rng_state[1] = std::get<1>(seeds);
     }
 
     launch(params, stream, /*configure=*/false);
diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h
index cb0a57dff..e65d7d536 100644
--- a/csrc/flash_attn/src/flash.h
+++ b/csrc/flash_attn/src/flash.h
@@ -91,6 +91,9 @@ struct Flash_fwd_params : public Qkv_params {
     // Random state.
     at::PhiloxCudaState philox_args;
 
+    // Pointer to the RNG seed (idx 0) and offset (idx 1).
+    uint64_t * rng_state;
+
     bool is_bf16;
     bool is_causal;
 };
diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h
index 322551b23..f7d0965b1 100644
--- a/csrc/flash_attn/src/flash_bwd_kernel.h
+++ b/csrc/flash_attn/src/flash_bwd_kernel.h
@@ -755,9 +755,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         copy(smem_thr_copy_KV, tdPsV, tdPrV_copy_view);
     }
 
-    auto seeds = at::cuda::philox::unpack(params.philox_args);
-    unsigned long long seed = std::get<0>(seeds);
-    unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
+    auto seed = params.rng_state[0];
+    auto offset = params.rng_state[1] + (bidb * params.h + bidh) * 32 + tidx % 32;
 
     clear(acc_dv);
     clear(acc_dk);
@@ -1301,9 +1300,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
     #pragma unroll
     for (int mi = 0; mi < size(dP_sum); ++mi) { dP_sum(mi) = sdPsum(get<0>(taccScS_row(mi))); }
 
-    auto seeds = at::cuda::philox::unpack(params.philox_args);
-    unsigned long long seed = std::get<0>(seeds);
-    unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
+    auto seed = params.rng_state[0];
+    auto offset = params.rng_state[1] + (bidb * params.h + bidh) * 32 + tidx % 32;
 
     clear(acc_dq);
 
diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h
index 2eba4ef12..6e7364776 100644
--- a/csrc/flash_attn/src/flash_fwd_kernel.h
+++ b/csrc/flash_attn/src/flash_fwd_kernel.h
@@ -130,6 +130,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
     // The thread index.
     const int tidx = threadIdx.x;
+    // The global block index.
+    const int block_id = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;
 
     constexpr int kBlockM = Kernel_traits::kBlockM;
     constexpr int kBlockN = Kernel_traits::kBlockN;
@@ -308,6 +310,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     unsigned long long seed = std::get<0>(seeds);
     unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
 
+    // Save seed and offset for backward.
+    if (block_id == 0 && tidx == 0) {
+        params.rng_state[0] = seed;
+        params.rng_state[1] = std::get<1>(seeds);
+    }
+
     clear(acc_o);
 
     // For performance reason, we separate out two kinds of iterations:
diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index 19796a07c..aade2695c 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -39,45 +39,46 @@ def _get_block_size(device, head_dim, is_dropout, is_causal):
 def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
-    out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.fwd(
+    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
         q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None
     )
-    return out, q, k, v, out_padded, softmax_lse, S_dmask
+    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
 
 
 def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                dropout_p, softmax_scale, causal, return_softmax):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
-    out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.varlen_fwd(
+    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
         q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
         softmax_scale, False, causal, return_softmax, None
     )
     # if out.isnan().any() or softmax_lse.isnan().any():
     #     breakpoint()
-    return out, q, k, v, out_padded, softmax_lse, S_dmask
+    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
 
 
 def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
-                         dropout_p, softmax_scale, causal):
+                         dropout_p, softmax_scale, causal, rng_state=None):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
-        dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, None
+        dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p,
+        softmax_scale, causal, None, rng_state
     )
     return dq, dk, dv, softmax_d
 
 
 def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
                                 cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                                dropout_p, softmax_scale, causal):
+                                dropout_p, softmax_scale, causal, rng_state=None):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
         dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None
+        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None, rng_state
     )
     # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
     #     breakpoint()
@@ -88,11 +89,9 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax):
-        # Save rng_state because the backward pass will regenerate the dropout mask
-        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
+        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
             qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], dropout_p, softmax_scale,
             causal=causal, return_softmax=return_softmax and dropout_p > 0
         )
@@ -105,18 +104,13 @@ def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax):
     @staticmethod
     def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
-        if rng_state is not None:
-            cur_rng_state = torch.cuda.get_rng_state()
-            torch.cuda.set_rng_state(rng_state)
         qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
         dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
         _flash_attn_backward(
             dout, q, k, v, out, softmax_lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
-            ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.dropout_p, ctx.softmax_scale, ctx.causal, rng_state=rng_state
         )
         dqkv = dqkv[..., :dout.shape[-1]]  # We could have padded the head dimension
-        if rng_state is not None:
-            torch.cuda.set_rng_state(cur_rng_state)
         return dqkv, None, None, None, None
 
 
@@ -124,11 +118,9 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal,
                 return_softmax):
-        # Save rng_state because the backward pass will regenerate the dropout mask
-        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
+        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
             dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
         )
@@ -142,19 +134,14 @@ def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal,
     @staticmethod
     def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
-        if rng_state is not None:
-            cur_rng_state = torch.cuda.get_rng_state()
-            torch.cuda.set_rng_state(rng_state)
         qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
         dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
         _flash_attn_varlen_backward(
             dout, q, k, v, out, softmax_lse, dqkv[:, 0], dqkv[:, 1], dqkv[:, 2],
             cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
-            ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.dropout_p, ctx.softmax_scale, ctx.causal, rng_state=rng_state
         )
         dqkv = dqkv[..., :dout.shape[-1]]  # We could have padded the head dimension
-        if rng_state is not None:
-            torch.cuda.set_rng_state(cur_rng_state)
         return dqkv, None, None, None, None, None, None
 
 
@@ -162,11 +149,9 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax):
-        # Save rng_state because the backward pass will regenerate the dropout mask
-        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
+        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
             q, kv[:, :, 0], kv[:, :, 1], dropout_p, softmax_scale, causal=causal,
             return_softmax=return_softmax and dropout_p > 0
         )
@@ -179,20 +164,16 @@ def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax):
     @staticmethod
     def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
-        if rng_state is not None:
-            cur_rng_state = torch.cuda.get_rng_state()
-            torch.cuda.set_rng_state(rng_state)
         dq = torch.empty_like(q)
         kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
         dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
         _flash_attn_backward(
             dout, q, k, v, out, softmax_lse,
-            dq, dkv[:, :, 0], dkv[:, :, 1], ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            dq, dkv[:, :, 0], dkv[:, :, 1], ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            rng_state=rng_state
         )
         dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., :dout.shape[-1]]
-        if rng_state is not None:
-            torch.cuda.set_rng_state(cur_rng_state)
         return dq, dkv, None, None, None, None
 
 
@@ -201,11 +182,9 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                 dropout_p, softmax_scale, causal, return_softmax):
-        # Save rng_state because the backward pass will regenerate the dropout mask
-        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
+        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
             q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
             dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
         )
@@ -221,21 +200,16 @@ def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
     @staticmethod
     def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
-        if rng_state is not None:
-            cur_rng_state = torch.cuda.get_rng_state()
-            torch.cuda.set_rng_state(rng_state)
         dq = torch.empty_like(q)
         kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
         dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
         _flash_attn_varlen_backward(
             dout, q, k, v, out, softmax_lse, dq, dkv[:, 0], dkv[:, 1],
             cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k,
-            ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.dropout_p, ctx.softmax_scale, ctx.causal, rng_state=rng_state
         )
         dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., :dout.shape[-1]]
-        if rng_state is not None:
-            torch.cuda.set_rng_state(cur_rng_state)
         return dq, dkv, None, None, None, None, None, None, None, None
 
 
@@ -243,11 +217,9 @@ class FlashAttnFunc(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax):
-        # Save rng_state because the backward pass will regenerate the dropout mask
-        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
+        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
             q, k, v, dropout_p, softmax_scale, causal=causal,
             return_softmax=return_softmax and dropout_p > 0
         )
@@ -260,19 +232,15 @@ def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax):
     @staticmethod
     def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
-        if rng_state is not None:
-            cur_rng_state = torch.cuda.get_rng_state()
-            torch.cuda.set_rng_state(rng_state)
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
         _flash_attn_backward(
             dout, q, k, v, out, softmax_lse,
-            dq, dk, dv, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            dq, dk, dv, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            rng_state=rng_state
         )
         dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., :dout.shape[-1]]
         dv = dv[..., :dout.shape[-1]]
-        if rng_state is not None:
-            torch.cuda.set_rng_state(cur_rng_state)
         return dq, dk, dv, None, None, None, None, None, None, None, None
 
 
@@ -281,11 +249,9 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
     @staticmethod
     def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                 dropout_p, softmax_scale, causal, return_softmax):
-        # Save rng_state because the backward pass will regenerate the dropout mask
-        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
+        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
             q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
             softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
         )
@@ -301,19 +267,15 @@ def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
     @staticmethod
     def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
-        if rng_state is not None:
-            cur_rng_state = torch.cuda.get_rng_state()
-            torch.cuda.set_rng_state(rng_state)
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
         _flash_attn_varlen_backward(
             dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            rng_state=rng_state
         )
         dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., :dout.shape[-1]]
         dv = dv[..., :dout.shape[-1]]
-        if rng_state is not None:
-            torch.cuda.set_rng_state(cur_rng_state)
         return dq, dk, dv, None, None, None, None, None, None, None, None
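
Taken together, the diff moves dropout reproducibility from the Python wrappers into the kernels: the forward kernel records the Philox seed and offset it actually used in the two-element rng_state tensor, the autograd functions save that tensor alongside the other activations, and the backward entry points read it back instead of drawing a fresh Philox state, so the global CUDA RNG state no longer has to be snapshotted and restored around the backward call. The sketch below is not part of the diff; it is a minimal, hypothetical usage example (shapes and dropout probability made up for illustration) assuming flash-attn is installed with CUDA support.

# Hypothetical usage sketch (not from this PR): the public API is unchanged, but with
# dropout enabled the backward pass now replays the mask from the rng_state tensor
# returned by the forward kernel rather than via torch.cuda.get/set_rng_state().
import torch
from flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim) in fp16, as the kernels require
q, k, v = [
    torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3)
]
out = flash_attn_func(q, k, v, dropout_p=0.1, causal=True)
out.sum().backward()  # dropout mask is regenerated from the saved seed/offset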