[LayerNorm] Support DropPath
tridao committed Oct 10, 2022
1 parent 747f905 commit fee1286
Showing 11 changed files with 311 additions and 228 deletions.
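For orientation before the per-file diffs: the rowscale parameter threaded through this commit is a per-row multiplier on the input x0, which is the hook that lets DropPath (stochastic depth) reuse the fused dropout-add-LayerNorm kernel. The following single-threaded C++ reference sketches the forward semantics implied by the backward-kernel changes below; every name in it is illustrative rather than repo code, and the exact scaling convention is an assumption.

#include <cmath>
#include <cstddef>
#include <vector>

// Reference (assumed) semantics: out = LayerNorm(dropout(rowscale[r] * x0) + x1),
// normalized per row. rowscale[r] == 0 drops an entire row, i.e. DropPath.
void dropout_add_ln_rowscale_ref(
    const std::vector<float> &x0,        // rows x cols, row-major input
    const std::vector<float> &x1,        // residual, same shape
    const std::vector<float> &rowscale,  // one multiplier per row
    const std::vector<float> &gamma, const std::vector<float> &beta,
    const std::vector<unsigned char> &dmask,  // dropout mask, 1 = keep
    float dropout_keep_p, float epsilon,
    std::size_t rows, std::size_t cols, std::vector<float> &out) {
    out.resize(rows * cols);
    std::vector<float> x(cols);
    for (std::size_t r = 0; r < rows; ++r) {
        for (std::size_t c = 0; c < cols; ++c) {
            std::size_t i = r * cols + c;
            // Row scaling happens before dropout; kept values are rescaled by 1/keep_p.
            float kept = dmask[i] ? x0[i] * rowscale[r] / dropout_keep_p : 0.f;
            x[c] = kept + x1[i];
        }
        float mu = 0.f;
        for (float v : x) mu += v;
        mu /= cols;
        float var = 0.f;
        for (float v : x) var += (v - mu) * (v - mu);
        float rs = 1.f / std::sqrt(var / cols + epsilon);
        for (std::size_t c = 0; c < cols; ++c)
            out[r * cols + c] = (x[c] - mu) * rs * gamma[c] + beta[c];
    }
}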
1 change: 1 addition & 0 deletions csrc/layer_norm/ln.h
@@ -62,6 +62,7 @@ struct ParamsBase {
void *mu;
void *rs;
void *gamma;
void *rowscale;

float dropout_keep_p;
float dropout_scale;
36 changes: 33 additions & 3 deletions csrc/layer_norm/ln_api.cpp
@@ -83,6 +83,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
c10::optional<const at::Tensor> &x1_, // Residual: BxSxhidden_size
const at::Tensor &gamma, // hidden_size
const at::Tensor &beta, // hidden_size
c10::optional<const at::Tensor> &rowscale_, // BxS
const float dropout_p,
const float epsilon,
c10::optional<at::Generator> gen_,
@@ -107,16 +108,24 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
auto sizes = x0.sizes();
TORCH_CHECK(sizes.size() == 2);

const int rows = sizes[0];
const int cols = sizes[1];
auto hidden_size = gamma.numel();

if (x1_.has_value()) {
auto x1 = x1_.value();
TORCH_CHECK(x1.is_cuda())
TORCH_CHECK(x1.is_contiguous());
TORCH_CHECK(x1.sizes() == sizes);
}

const int rows = sizes[0];
const int cols = sizes[1];
auto hidden_size = gamma.numel();
if (rowscale_.has_value()) {
auto rowscale = rowscale_.value();
TORCH_CHECK(rowscale.is_cuda())
TORCH_CHECK(rowscale.is_contiguous());
TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
TORCH_CHECK(rowscale.scalar_type() == itype);
}

TORCH_CHECK(gamma.sizes() == beta.sizes());
TORCH_CHECK(hidden_size == cols);
Expand All @@ -142,6 +151,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
TORCH_CHECK(dropout_p < 1.f);
launch_params.params.dropout_keep_p = 1.f - dropout_p;
launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;

auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
@@ -203,6 +213,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
const at::Tensor &mu, // BxS, FP32!
const at::Tensor &rsigma, // BxS, FP32!
const at::Tensor &gamma, // hidden_size
c10::optional<const at::Tensor> &rowscale_, // BxS
const float dropout_p,
const bool has_residual
) {
@@ -243,6 +254,14 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
TORCH_CHECK(dmask.sizes() == sizes);
}

if (rowscale_.has_value()) {
auto rowscale = rowscale_.value();
TORCH_CHECK(rowscale.is_cuda())
TORCH_CHECK(rowscale.is_contiguous());
TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
TORCH_CHECK(rowscale.scalar_type() == itype);
}

auto hidden_size = gamma.numel();

TORCH_CHECK(mu.numel() == rows);
@@ -264,6 +283,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
TORCH_CHECK(dropout_p < 1.f);
launch_params.params.dropout_keep_p = 1.f - dropout_p;
launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;

auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size);

@@ -311,6 +331,7 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz, //
const at::Tensor &mu, // BxS, FP32!
const at::Tensor &rsigma, // BxS, FP32!
const at::Tensor &gamma, // hidden_size
c10::optional<const at::Tensor> &rowscale_, // BxS
const float dropout_p,
const bool has_residual
) {
@@ -355,6 +376,14 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz, //
TORCH_CHECK(dmask.sizes() == sizes);
}

if (rowscale_.has_value()) {
auto rowscale = rowscale_.value();
TORCH_CHECK(rowscale.is_cuda())
TORCH_CHECK(rowscale.is_contiguous());
TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
TORCH_CHECK(rowscale.scalar_type() == itype);
}

auto hidden_size = gamma.numel();

TORCH_CHECK(mu.numel() == rows);
@@ -376,6 +405,7 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz, //
TORCH_CHECK(dropout_p < 1.f);
launch_params.params.dropout_keep_p = 1.f - dropout_p;
launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;

// TODO: how to set template param for launcher
auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size);
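One plausible caller-side use of the new optional argument (a sketch, not code from this commit): DropPath keeps each row with probability keep_p, so a per-row Bernoulli draw can serve as rowscale, while passing c10::nullopt preserves the old behavior. make_droppath_rowscale is an invented helper, and folding the 1/keep_p rescaling into rowscale is an assumption this diff does not confirm.

#include <torch/torch.h>

// Per-row DropPath multipliers: 0 for dropped rows, 1/keep_p for kept rows
// (hypothetical convention; the kernel only sees an arbitrary per-row scale).
torch::Tensor make_droppath_rowscale(int64_t rows, double keep_p,
                                     const torch::TensorOptions &opts) {
    auto keep = torch::bernoulli(torch::full({rows}, keep_p, opts));
    return keep / keep_p;
}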
9 changes: 6 additions & 3 deletions csrc/layer_norm/ln_bwd_kernels.cuh
@@ -2,7 +2,7 @@

namespace layer_norm {

template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual>
template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Has_rowscale>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_bwd_kernel(layer_norm::BwdParams params) {

@@ -17,6 +17,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };

using input_t = typename Ktraits::input_t;
using compute_t = typename Ktraits::compute_t;
using index_t = typename Ktraits::index_t;
using mask_t = typename Ktraits::mask_t;
@@ -73,6 +74,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
const compute_t rowscale_val = Has_rowscale ? compute_t(static_cast<const input_t *>(params.rowscale)[row]) : 1.0f;
Rvec x[LDGS];
Mvec dmask[LDGS];
Ovec dz[LDGS];
@@ -129,10 +131,11 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
compute_t dx_tmp_res = Prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
if (Has_residual) { dx1[it].data.elt[jt] = dx_tmp_res; }
compute_t dx0_tmp_res = Has_rowscale ? dx_tmp_res * rowscale_val : dx_tmp_res;
if (Is_dropout) {
dx0[it].data.elt[jt] = dmask[it].data.elt[jt] ? dx_tmp_res * params.dropout_scale : 0.f;
dx0[it].data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * params.dropout_scale : 0.f;
} else {
dx0[it].data.elt[jt] = dx_tmp_res;
dx0[it].data.elt[jt] = dx0_tmp_res;
}
}
if (Has_residual) { dx1[it].store_to(params.dx1, idx); }
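The new dx0_tmp_res line above is the chain-rule step for the row scaling: rowscale multiplies x0 elementwise in the forward pass, so its gradient contribution is a plain multiply applied before the dropout mask. A scalar sketch with illustrative names:

// dx_res: upstream gradient (plus the prenorm residual term, if any) for one element.
// Mirrors the kernel's dx0 computation; not repo code.
float dx0_of(float dx_res, float rowscale_val, bool kept,
             float dropout_scale, bool has_rowscale, bool is_dropout) {
    float dx0 = has_rowscale ? dx_res * rowscale_val : dx_res;
    if (is_dropout) return kept ? dx0 * dropout_scale : 0.f;
    return dx0;
}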
109 changes: 56 additions & 53 deletions csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu
@@ -2,6 +2,7 @@
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "ln_bwd_kernels.cuh"
#include "static_switch.h"

using namespace layer_norm;

@@ -35,59 +36,61 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params
>;
bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
bool has_residual = launch_params.params.dx1 != nullptr;
auto kernel = prenorm
? (is_dropout
? (has_residual ? &ln_bwd_kernel<Kernel_traits, true, true, true> : &ln_bwd_kernel<Kernel_traits, true, true, false>)
: (has_residual ? &ln_bwd_kernel<Kernel_traits, true, false, true> : &ln_bwd_kernel<Kernel_traits, true, false, false>))
: (is_dropout
? (has_residual ? &ln_bwd_kernel<Kernel_traits, false, true, true> : &ln_bwd_kernel<Kernel_traits, false, true, false>)
: (has_residual ? &ln_bwd_kernel<Kernel_traits, false, false, true> : &ln_bwd_kernel<Kernel_traits, false, false, false>));

if( configure_params ) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES);
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if(Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::reduce_t)
* 2;
}
return;
}

if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;

if( Kernel_traits::CTAS_PER_ROW == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
} else {
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
}

using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
weight_t,
input_t,
residual_t,
output_t,
compute_t,
index_t,
32 * 32, // THREADS_PER_CTA
BYTES_PER_LDG_FINAL>;

auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
bool has_rowscale = launch_params.params.rowscale != nullptr;
BOOL_SWITCH(prenorm, PrenormConst, [&] {
BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
BOOL_SWITCH(has_residual, HasResidualConst, [&] {
BOOL_SWITCH(has_rowscale, HasRowscaleConst, [&] {
auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, HasRowscaleConst>;
if( configure_params ) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES);
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if(Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::reduce_t)
* 2;
}
return;
}

if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;

if( Kernel_traits::CTAS_PER_ROW == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
} else {
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
}

using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
weight_t,
input_t,
residual_t,
output_t,
compute_t,
index_t,
32 * 32, // THREADS_PER_CTA
BYTES_PER_LDG_FINAL>;

auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
});
});
});
});
}
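The rewritten launcher dispatches through BOOL_SWITCH from the newly included static_switch.h, whose definition is not part of this diff. The usual shape of such a macro, sketched here as an assumption rather than the repo's exact file, converts a runtime bool into a constexpr name so that each of the 2^4 = 16 flag combinations instantiates its own specialized kernel:

// Typical static-switch pattern (assumed definition): the lambda body is
// compiled twice, once per constexpr value, and the matching branch is run.
#define BOOL_SWITCH(COND, CONST_NAME, ...)       \
    [&] {                                        \
        if (COND) {                              \
            constexpr bool CONST_NAME = true;    \
            return __VA_ARGS__();                \
        } else {                                 \
            constexpr bool CONST_NAME = false;   \
            return __VA_ARGS__();                \
        }                                        \
    }()

Nesting four of these switches replaces the hand-written ternary tree this commit deletes, which selected among only eight specializations and would not have scaled cleanly to the new Has_rowscale flag.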

// Create backward launch function and register. Macro signature: