Skip to content

Commit

Permalink
Enable CUDA graphs (Dao-AILab#386)
Browse files Browse the repository at this point in the history
* Add RNG state to kernel launch params

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Save seed and offset for backward

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Single thread write to global mem

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* compute_dq_dk_dv_1colblock get seed and offset from launch params

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* compute_dq_dk_dv_1rowblock get seed and offset from launch params

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Change forward c++ APIs to save RNG state for backward

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Change backward c++ APIs to set RNG state for bprop launcher

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Bug fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Python side API changes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Bug fix; only save seeds instead of full offset

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Account for 3D grid size

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
  • Loading branch information
ksivaman authored Jul 27, 2023
1 parent 4c98d0b commit a03f6f8
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 82 deletions.
50 changes: 36 additions & 14 deletions csrc/flash_attn/flash_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,11 +294,16 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
softmax_scale,
is_causal);

// number of times random will be generated per thread, to offset philox counter in thc random
// state
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
int64_t counter_offset = params.b * params.h * 32;
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
// Forward kernel will populate memory with the seed and offset.
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

if (p_dropout > 0.0) {
// number of times random will be generated per thread, to offset philox counter in thc random
// state
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
int64_t counter_offset = params.b * params.h * 32;
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
// See Note [Acquire lock when using random generators]
Expand All @@ -315,7 +320,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
if (out_.has_value()) { out_.value().copy_(out); }
}

return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
}

std::vector<at::Tensor>
Expand Down Expand Up @@ -448,11 +453,16 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
softmax_scale,
is_causal);

// number of times random will be generated per thread, to offset philox counter in thc random
// state
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
int64_t counter_offset = params.b * params.h * 32;
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
// Forward kernel will populate memory with the seed and offset.
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

if (p_dropout > 0.0) {
// number of times random will be generated per thread, to offset philox counter in thc random
// state
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
int64_t counter_offset = params.b * params.h * 32;
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
// See Note [Acquire lock when using random generators]
Expand All @@ -469,7 +479,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
if (out_.has_value()) { out_.value().copy_(out); }
}

return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
}

void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
Expand Down Expand Up @@ -507,7 +517,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
c10::optional<at::Generator> gen_) {
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state) {
auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
Expand Down Expand Up @@ -669,10 +680,15 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
int64_t counter_offset = params.b * params.h * 32;

if (is_dropout) {
if ( rng_state.has_value() ) {
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
} else if( is_dropout ) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
auto seeds = at::cuda::philox::unpack(params.philox_args);
params.rng_state[0] = std::get<0>(seeds);
params.rng_state[1] = std::get<1>(seeds);
}

launch(params, stream, /*configure=*/false);
Expand Down Expand Up @@ -709,7 +725,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
c10::optional<at::Generator> gen_
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
Expand Down Expand Up @@ -881,10 +898,15 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
int64_t counter_offset = params.b * params.h * 32;

if (is_dropout) {
if ( rng_state.has_value() ) {
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
} else if( is_dropout ) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
auto seeds = at::cuda::philox::unpack(params.philox_args);
params.rng_state[0] = std::get<0>(seeds);
params.rng_state[1] = std::get<1>(seeds);
}

launch(params, stream, /*configure=*/false);
Expand Down
3 changes: 3 additions & 0 deletions csrc/flash_attn/src/flash.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ struct Flash_fwd_params : public Qkv_params {
// Random state.
at::PhiloxCudaState philox_args;

// Pointer to the RNG seed (idx 0) and offset (idx 1).
uint64_t * rng_state;

bool is_bf16;
bool is_causal;
};
Expand Down
10 changes: 4 additions & 6 deletions csrc/flash_attn/src/flash_bwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -755,9 +755,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
copy(smem_thr_copy_KV, tdPsV, tdPrV_copy_view);
}

auto seeds = at::cuda::philox::unpack(params.philox_args);
unsigned long long seed = std::get<0>(seeds);
unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
auto seed = params.rng_state[0];
auto offset = params.rng_state[1] + (bidb * params.h + bidh) * 32 + tidx % 32;

clear(acc_dv);
clear(acc_dk);
Expand Down Expand Up @@ -1301,9 +1300,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
#pragma unroll
for (int mi = 0; mi < size(dP_sum); ++mi) { dP_sum(mi) = sdPsum(get<0>(taccScS_row(mi))); }

auto seeds = at::cuda::philox::unpack(params.philox_args);
unsigned long long seed = std::get<0>(seeds);
unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
auto seed = params.rng_state[0];
auto offset = params.rng_state[1] + (bidb * params.h + bidh) * 32 + tidx % 32;

clear(acc_dq);

Expand Down
8 changes: 8 additions & 0 deletions csrc/flash_attn/src/flash_fwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

// The thread index.
const int tidx = threadIdx.x;
// The global block index.
const int block_id = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;

constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
Expand Down Expand Up @@ -308,6 +310,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
unsigned long long seed = std::get<0>(seeds);
unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;

// Save seed and offset for backward.
if (block_id == 0 && tidx == 0) {
params.rng_state[0] = seed;
params.rng_state[1] = std::get<1>(seeds);
}

clear(acc_o);

// For performance reason, we separate out two kinds of iterations:
Expand Down
Loading

0 comments on commit a03f6f8

Please sign in to comment.