diff --git a/csrc/layer_norm/ln.h b/csrc/layer_norm/ln.h
index ab66aa0..58bbffa 100644
--- a/csrc/layer_norm/ln.h
+++ b/csrc/layer_norm/ln.h
@@ -62,6 +62,7 @@ struct ParamsBase {
     void *mu;
     void *rs;
     void *gamma;
+    void *rowscale;
 
     float dropout_keep_p;
     float dropout_scale;
diff --git a/csrc/layer_norm/ln_api.cpp b/csrc/layer_norm/ln_api.cpp
index c0abcfd..0d8667c 100644
--- a/csrc/layer_norm/ln_api.cpp
+++ b/csrc/layer_norm/ln_api.cpp
@@ -83,6 +83,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
                                            c10::optional<const at::Tensor> &x1_,      // Residual: BxSxhidden_size
                                            const at::Tensor &gamma,   // hidden_size
                                            const at::Tensor &beta,   // hidden_size
+                                           c10::optional<const at::Tensor> &rowscale_,      // BxS
                                            const float dropout_p,
                                            const float epsilon,
                                            c10::optional<at::Generator> gen_,
@@ -107,6 +108,10 @@
     auto sizes = x0.sizes();
     TORCH_CHECK(sizes.size() == 2);
 
+    const int rows = sizes[0];
+    const int cols = sizes[1];
+    auto hidden_size = gamma.numel();
+
     if (x1_.has_value()) {
         auto x1 = x1_.value();
         TORCH_CHECK(x1.is_cuda())
@@ -114,9 +119,13 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
         TORCH_CHECK(x1.sizes() == sizes);
     }
 
-    const int rows = sizes[0];
-    const int cols = sizes[1];
-    auto hidden_size = gamma.numel();
+    if (rowscale_.has_value()) {
+        auto rowscale = rowscale_.value();
+        TORCH_CHECK(rowscale.is_cuda())
+        TORCH_CHECK(rowscale.is_contiguous());
+        TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
+        TORCH_CHECK(rowscale.scalar_type() == itype);
+    }
 
     TORCH_CHECK(gamma.sizes() == beta.sizes());
     TORCH_CHECK(hidden_size == cols);
@@ -142,6 +151,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
     TORCH_CHECK(dropout_p < 1.f);
     launch_params.params.dropout_keep_p = 1.f - dropout_p;
     launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
+    launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
 
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());
@@ -203,6 +213,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
                                            const at::Tensor &mu,     // BxS, FP32!
                                            const at::Tensor &rsigma, // BxS, FP32!
const at::Tensor &gamma, // hidden_size + c10::optional &rowscale_, // BxS const float dropout_p, const bool has_residual ) { @@ -355,6 +376,14 @@ std::vector dropout_add_ln_prenorm_bwd(const at::Tensor &dz, // TORCH_CHECK(dmask.sizes() == sizes); } + if (rowscale_.has_value()) { + auto rowscale = rowscale_.value(); + TORCH_CHECK(rowscale.is_cuda()) + TORCH_CHECK(rowscale.is_contiguous()); + TORCH_CHECK(rowscale.sizes() == std::vector{rows}); + TORCH_CHECK(rowscale.scalar_type() == itype); + } + auto hidden_size = gamma.numel(); TORCH_CHECK(mu.numel() == rows); @@ -376,6 +405,7 @@ std::vector dropout_add_ln_prenorm_bwd(const at::Tensor &dz, // TORCH_CHECK(dropout_p < 1.f); launch_params.params.dropout_keep_p = 1.f - dropout_p; launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr; + launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr; // TODO: how to set template param for launcher auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size); diff --git a/csrc/layer_norm/ln_bwd_kernels.cuh b/csrc/layer_norm/ln_bwd_kernels.cuh index 68bc975..0d1f047 100644 --- a/csrc/layer_norm/ln_bwd_kernels.cuh +++ b/csrc/layer_norm/ln_bwd_kernels.cuh @@ -2,7 +2,7 @@ namespace layer_norm { -template +template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_kernel(layer_norm::BwdParams params) { @@ -17,6 +17,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP }; enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; + using input_t = typename Ktraits::input_t; using compute_t = typename Ktraits::compute_t; using index_t = typename Ktraits::index_t; using mask_t = typename Ktraits::mask_t; @@ -73,6 +74,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { const compute_t mu_r = static_cast(params.mu)[row]; const compute_t rs_r = static_cast(params.rs)[row]; + const compute_t rowscale_val = Has_rowscale ? compute_t(static_cast(params.rowscale)[row]) : 1.0f; Rvec x[LDGS]; Mvec dmask[LDGS]; Ovec dz[LDGS]; @@ -129,10 +131,11 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local)); compute_t dx_tmp_res = Prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp; if (Has_residual) { dx1[it].data.elt[jt] = dx_tmp_res; } + compute_t dx0_tmp_res = Has_rowscale ? dx_tmp_res * rowscale_val : dx_tmp_res; if (Is_dropout) { - dx0[it].data.elt[jt] = dmask[it].data.elt[jt] ? dx_tmp_res * params.dropout_scale : 0.f; + dx0[it].data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * params.dropout_scale : 0.f; } else { - dx0[it].data.elt[jt] = dx_tmp_res; + dx0[it].data.elt[jt] = dx0_tmp_res; } } if (Has_residual) { dx1[it].store_to(params.dx1, idx); } diff --git a/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu b/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu index c2e15c0..8a080e9 100644 --- a/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu +++ b/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu @@ -2,6 +2,7 @@ #include "ln_utils.cuh" #include "ln_kernel_traits.h" #include "ln_bwd_kernels.cuh" +#include "static_switch.h" using namespace layer_norm; @@ -35,59 +36,61 @@ void launch_(LaunchParams &launch_params, const bool configure_params >; bool is_dropout = launch_params.params.dropout_keep_p < 1.f; bool has_residual = launch_params.params.dx1 != nullptr; - auto kernel = prenorm - ? (is_dropout - ? (has_residual ? 
&ln_bwd_kernel : &ln_bwd_kernel) - : (has_residual ? &ln_bwd_kernel : &ln_bwd_kernel)) - : (is_dropout - ? (has_residual ? &ln_bwd_kernel : &ln_bwd_kernel) - : (has_residual ? &ln_bwd_kernel : &ln_bwd_kernel)); - - if( configure_params ) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); - launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if(Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col - * Kernel_traits::WARPS_M - * Kernel_traits::CTAS_PER_ROW - * sizeof(typename Kernel_traits::reduce_t) - * 2; - } - return; - } - - if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) { - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - - if( Kernel_traits::CTAS_PER_ROW == 1 ) { - kernel<<>>(launch_params.params); - } else { - dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = (void *)&launch_params.params; - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES, stream); - } - - using Kernel_traits_f = layer_norm::Kernel_traits_finalize; - - auto kernel_f = &layer_norm::ln_bwd_finalize_kernel; - kernel_f<<>>(launch_params.params); + bool has_rowscale = launch_params.params.rowscale != nullptr; + BOOL_SWITCH(prenorm, PrenormConst, [&] { + BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { + BOOL_SWITCH(has_residual, HasResidualConst, [&] { + BOOL_SWITCH(has_rowscale, HasRowscaleConst, [&] { + auto kernel = &ln_bwd_kernel; + if( configure_params ) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); + launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if(Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; + launch_params.workspace_bytes = launch_params.params.ctas_per_col + * Kernel_traits::WARPS_M + * Kernel_traits::CTAS_PER_ROW + * sizeof(typename Kernel_traits::reduce_t) + * 2; + } + return; + } + + if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + + if( Kernel_traits::CTAS_PER_ROW == 1 ) { + kernel<<>>(launch_params.params); + } else { + dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = (void *)&launch_params.params; + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES, stream); + } + + using Kernel_traits_f = layer_norm::Kernel_traits_finalize; + + auto kernel_f = &layer_norm::ln_bwd_finalize_kernel; + kernel_f<<>>(launch_params.params); + }); + }); + }); + }); } // Create backward launch function and register. 
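For reference, the math behind the new `Has_rowscale` path can be sketched with plain PyTorch ops (an illustrative reference only, not the CUDA kernel; tensor names, shapes, and the non-prenorm setting below are made up for the example). The fused op now computes z = LayerNorm(dropout(x0 * rowscale) + x1), so by the chain rule dx0 picks up a factor of rowscale alongside the dropout mask and 1/keep_p scaling, which is what the added `dx_tmp_res * rowscale_val` line in `ln_bwd_kernel` implements:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
rows, hidden, p = 4, 8, 0.37
x0 = torch.randn(rows, hidden, requires_grad=True)
x1 = torch.randn(rows, hidden)
gamma = torch.randn(hidden, requires_grad=True)
beta = torch.randn(hidden, requires_grad=True)
# Per-row scale, kept as shape (rows, 1) here so it broadcasts; the kernel
# reads it as a length-`rows` vector instead.
rowscale = torch.bernoulli(torch.full((rows, 1), 0.87)) / 0.87

# Forward with an explicit dropout mask: residual = dropout(x0 * rowscale) + x1
mask = torch.bernoulli(torch.full((rows, hidden), 1 - p))
residual = (x0 * rowscale) * mask / (1 - p) + x1
z = F.layer_norm(residual, (hidden,), gamma, beta)

dz = torch.randn_like(z)
z.backward(dz)

# Reproduce dx0 the way ln_bwd_kernel does: gradient w.r.t. the pre-LayerNorm
# sum, times rowscale, times the dropout mask / keep-probability scaling.
residual_ = residual.detach().requires_grad_()
F.layer_norm(residual_, (hidden,), gamma.detach(), beta.detach()).backward(dz)
dx0_kernel_style = residual_.grad * rowscale * mask / (1 - p)
assert torch.allclose(x0.grad, dx0_kernel_style, atol=1e-6)
```

Because both factors are elementwise, applying rowscale before or after the dropout scaling is equivalent; the kernel multiplies `dx_tmp_res` by `rowscale_val` just before the dropout branch so that the residual gradient `dx1` stays unscaled.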
Macro signature: diff --git a/csrc/layer_norm/ln_fwd_cuda_kernel.cu b/csrc/layer_norm/ln_fwd_cuda_kernel.cu index dae368d..3653c46 100644 --- a/csrc/layer_norm/ln_fwd_cuda_kernel.cu +++ b/csrc/layer_norm/ln_fwd_cuda_kernel.cu @@ -2,6 +2,7 @@ #include "ln_utils.cuh" #include "ln_kernel_traits.h" #include "ln_fwd_kernels.cuh" +#include "static_switch.h" using namespace layer_norm; @@ -33,45 +34,48 @@ void launch_(LaunchParams &launch_params, const bool configure_params BYTES_PER_LDG >; bool has_residual = launch_params.params.x1 != nullptr; - auto kernel = launch_params.params.dropout_keep_p < 1.f - ? (has_residual ? &ln_fwd_kernel : &ln_fwd_kernel) - : (has_residual ? &ln_fwd_kernel : &ln_fwd_kernel); - - if( configure_params ) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); - launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; - const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA; - launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if(Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col - * Kernel_traits::WARPS_M - * Kernel_traits::CTAS_PER_ROW - * sizeof(typename Kernel_traits::Stats::stats_t) - * 2; - } - return; - } - - if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) { - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - - if( Kernel_traits::CTAS_PER_ROW == 1 ) { - kernel<<>>(launch_params.params); - } else { - dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = (void *)&launch_params.params; - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream); - } - + bool has_rowscale = launch_params.params.rowscale != nullptr; + BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] { + BOOL_SWITCH(has_residual, HasResidualConst, [&] { + BOOL_SWITCH(has_rowscale, HasRowscaleConst, [&] { + auto kernel = &ln_fwd_kernel; + if( configure_params ) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); + launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; + const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA; + launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS; + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if(Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; + launch_params.workspace_bytes = launch_params.params.ctas_per_col + * Kernel_traits::WARPS_M + * Kernel_traits::CTAS_PER_ROW + * sizeof(typename Kernel_traits::Stats::stats_t) + * 2; + } + return; + } + + if( 
Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + + if( Kernel_traits::CTAS_PER_ROW == 1 ) { + kernel<<>>(launch_params.params); + } else { + dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = (void *)&launch_params.params; + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream); + } + }); + }); + }); } // Create forward launch function and register. Macro signature: diff --git a/csrc/layer_norm/ln_fwd_kernels.cuh b/csrc/layer_norm/ln_fwd_kernels.cuh index d4c377b..776b5d6 100644 --- a/csrc/layer_norm/ln_fwd_kernels.cuh +++ b/csrc/layer_norm/ln_fwd_kernels.cuh @@ -13,7 +13,7 @@ namespace layer_norm { -template +template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_kernel(FwdParams params) { @@ -63,6 +63,8 @@ void ln_fwd_kernel(FwdParams params) { compute_t *mu_ptr = static_cast(params.mu); compute_t *rs_ptr = static_cast(params.rs); + const input_t *rowscale = static_cast(params.rowscale); + // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu curandStatePhilox4_32_10_t state; if (Is_dropout) { @@ -84,6 +86,7 @@ void ln_fwd_kernel(FwdParams params) { constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS); for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { + const compute_t rowscale_val = Has_rowscale ? compute_t(rowscale[row]) : 1.0f; Ivec x0[LDGS]; Rvec x1[LDGS]; Rvec x[LDGS]; @@ -103,7 +106,7 @@ void ln_fwd_kernel(FwdParams params) { float rand = curand_uniform(&state); keep = mask_t(rand <= params.dropout_keep_p); } - compute_t x0_ij = compute_t(x0[it].data.elt[jt]); + compute_t x0_ij = Has_rowscale ? compute_t(x0[it].data.elt[jt]) * rowscale_val : compute_t(x0[it].data.elt[jt]); compute_t x_ij; if (Has_residual) { compute_t x1_ij = compute_t(x1[it].data.elt[jt]); diff --git a/csrc/layer_norm/setup.py b/csrc/layer_norm/setup.py index cc904d3..c8f5fea 100644 --- a/csrc/layer_norm/setup.py +++ b/csrc/layer_norm/setup.py @@ -98,7 +98,10 @@ def append_nvcc_threads(nvcc_extra_args): raise_if_cuda_home_none("--fast_layer_norm") # Check, if CUDA11 is installed for compute capability 8.0 cc_flag = [] -cc_flag.append("-arch=compute_70") +# cc_flag.append("-gencode") +# cc_flag.append("arch=compute_70,code=sm_70") +cc_flag.append("-gencode") +cc_flag.append("arch=compute_80,code=sm_80") ext_modules.append( CUDAExtension( @@ -113,8 +116,6 @@ def append_nvcc_threads(nvcc_extra_args): "nvcc": append_nvcc_threads( [ "-O3", - "-gencode", - "arch=compute_70,code=sm_70", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_BFLOAT16_OPERATORS__", diff --git a/csrc/layer_norm/static_switch.h b/csrc/layer_norm/static_switch.h new file mode 100644 index 0000000..7920ac0 --- /dev/null +++ b/csrc/layer_norm/static_switch.h @@ -0,0 +1,25 @@ +// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... 
- code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/src/models/modules/block.py b/src/models/modules/block.py index 0e24c8c..077ead3 100644 --- a/src/models/modules/block.py +++ b/src/models/modules/block.py @@ -8,6 +8,8 @@ import torch.nn.functional as F from torch import Tensor +from torchvision.ops import StochasticDepth + from src.models.modules.mha import MHA from src.models.modules.mlp import Mlp @@ -20,7 +22,7 @@ class Block(nn.Module): def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm, - dropout_cls=nn.Dropout, prenorm=True, resid_dropout=0., + dropout_cls=nn.Dropout, prenorm=True, resid_dropout=0., drop_path=0., fused_dropout_add_ln=False): super().__init__() self.prenorm = prenorm @@ -30,12 +32,14 @@ def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm, if mlp_cls is None: mlp_cls = partial(Mlp, hidden_features=4 * dim) self.mixer = mixer_cls(dim) - self.norm1 = norm_cls(dim) self.dropout1 = dropout_cls(resid_dropout) + self.drop_path1 = StochasticDepth(drop_path, mode='row') + self.norm1 = norm_cls(dim) self.mlp = mlp_cls(dim) if not isinstance(self.mlp, nn.Identity): - self.norm2 = norm_cls(dim) self.dropout2 = dropout_cls(resid_dropout) + self.drop_path2 = StochasticDepth(drop_path, mode='row') + self.norm2 = norm_cls(dim) if self.fused_dropout_add_ln: assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed' @@ -52,43 +56,71 @@ def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None): assert residual is not None mixer_out = self.mixer(hidden_states) if not self.fused_dropout_add_ln: - residual = self.dropout1(mixer_out) + residual + residual = self.drop_path1(self.dropout1(mixer_out)) + residual hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype)) else: + if self.drop_path1.p == 0 or not self.training: + rowscale1 = None + else: + rowscale1 = self.drop_path1(torch.ones( + mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype) + ) hidden_states, residual = dropout_add_layer_norm( mixer_out, residual, self.norm1.weight, self.norm1.bias, - self.dropout1.p if self.training else 0.0, self.norm1.eps, prenorm=True + self.dropout1.p if self.training else 0.0, self.norm1.eps, + rowscale=rowscale1, prenorm=True ) if not isinstance(self.mlp, nn.Identity): mlp_out = self.mlp(hidden_states) if not self.fused_dropout_add_ln: - residual = self.dropout2(mlp_out) + residual + residual = self.drop_path2(self.dropout2(mlp_out)) + residual hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype)) else: + if self.drop_path2.p == 0 or not self.training: + rowscale2 = None + else: + rowscale2 = self.drop_path2(torch.ones( + mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype) + ) hidden_states, residual = dropout_add_layer_norm( mlp_out, residual, self.norm2.weight, self.norm2.bias, - self.dropout2.p if self.training else 0.0, self.norm2.eps, prenorm=True + self.dropout2.p if self.training else 0.0, self.norm2.eps, + rowscale=rowscale2, prenorm=True ) return hidden_states, residual else: assert residual is None mixer_out = self.mixer(hidden_states) if not self.fused_dropout_add_ln: - hidden_states = 
self.norm1((self.dropout1(mixer_out) + hidden_states = self.norm1((self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(dtype=self.norm1.weight.dtype)) else: + if self.drop_path1.p == 0 or not self.training: + rowscale1 = None + else: + rowscale1 = self.drop_path1(torch.ones( + mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype) + ) hidden_states = dropout_add_layer_norm( mixer_out, hidden_states, self.norm1.weight, self.norm1.bias, - self.dropout1.p if self.training else 0.0, self.norm1.eps, prenorm=False + self.dropout1.p if self.training else 0.0, self.norm1.eps, + rowscale=rowscale1, prenorm=False ) if not isinstance(self.mlp, nn.Identity): mlp_out = self.mlp(hidden_states) if not self.fused_dropout_add_ln: - hidden_states = self.norm2((self.drop_path(self.dropout2(mlp_out)) + hidden_states = self.norm2((self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(dtype=self.norm2.weight.dtype)) else: + if self.drop_path2.p == 0 or not self.training: + rowscale2 = None + else: + rowscale2 = self.drop_path2(torch.ones( + mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype) + ) hidden_states = dropout_add_layer_norm( mlp_out, hidden_states, self.norm2.weight, self.norm2.bias, - self.dropout2.p if self.training else 0.0, self.norm2.eps, prenorm=False + self.dropout2.p if self.training else 0.0, self.norm2.eps, + rowscale=rowscale2, prenorm=False ) return hidden_states diff --git a/src/ops/layer_norm.py b/src/ops/layer_norm.py index b8475c3..b3fba77 100644 --- a/src/ops/layer_norm.py +++ b/src/ops/layer_norm.py @@ -19,192 +19,142 @@ def forward(self, x): return _fast_layer_norm(x, self.weight, self.bias, self.epsilon) -def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, dropout_p, epsilon, residual_in_fp32): +def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, dropout_p, epsilon, + residual_in_fp32): """ Assume that arguments are contiguous """ hidden_size = gamma.numel() x0mat = x0.view((-1, hidden_size)) x1mat = x1.view((-1, hidden_size)) if x1 is not None else None + rowscale = rowscale.view(-1) if rowscale is not None else None zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( - x0mat, x1mat, gamma, beta, dropout_p, epsilon, None, residual_in_fp32 + x0mat, x1mat, gamma, beta, rowscale, dropout_p, epsilon, None, residual_in_fp32 ) # dmask is None if dropout_p == 0.0 # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma -def _dropout_add_layer_norm_backward(dz, x, dmask, mu, rsigma, gamma, dropout_p, has_residual): +def _dropout_add_layer_norm_backward(dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, + has_residual): """ Assume that arguments are contiguous """ # dmask is None if dropout_p == 0.0 hidden_size = gamma.numel() xmat = x.view((-1, hidden_size)) dzmat = dz.view(xmat.shape) + rowscale = rowscale.view(-1) if rowscale is not None else None dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_bwd( - dzmat, xmat, dmask, mu, rsigma, gamma, dropout_p, has_residual + dzmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual ) # dx1mat is None if not has_residual return dx0mat, dx1mat, dgamma, dbeta -def _dropout_add_layer_norm_prenorm_backward(dz, dx, x, dmask, mu, rsigma, gamma, dropout_p, - has_residual): +def _dropout_add_layer_norm_prenorm_backward(dz, dx, x, dmask, mu, rsigma, gamma, rowscale, + dropout_p, has_residual): """ Assume that arguments are contiguous """ 
hidden_size = gamma.numel() xmat = x.view((-1, hidden_size)) dzmat = dz.view(xmat.shape) dxmat = dx.view(xmat.shape) + rowscale = rowscale.view(-1) if rowscale is not None else None dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_prenorm_bwd( - dzmat, dxmat, xmat, dmask, mu, rsigma, gamma, dropout_p, has_residual + dzmat, dxmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual ) return dx0mat, dx1mat, dgamma, dbeta class DropoutAddLayerNormFN(torch.autograd.Function): @staticmethod - def forward(ctx, x0, x1, gamma, beta, dropout_p, epsilon, residual_in_fp32): - x0 = x0.contiguous() - x1 = x1.contiguous() - gamma = gamma.contiguous() - beta = beta.contiguous() - zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( - x0, x1, gamma, beta, dropout_p, epsilon, residual_in_fp32 - ) - ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma) - ctx.dropout_p = dropout_p - ctx.has_residual = x1 is not None - return zmat.view(x0.shape) - - @staticmethod - def backward(ctx, dz): - # assert dz.is_contiguous() - dz = dz.contiguous() # this happens! - x, dmask, gamma, mu, rsigma = ctx.saved_tensors - dropout_p = ctx.dropout_p - has_residual = ctx.has_residual - dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_backward( - dz, x, dmask, mu, rsigma, gamma, dropout_p, has_residual - ) - dx0 = dx0mat.view(x.shape) - dx1 = dx1mat.view(x.shape) if dx1mat is not None else None - return dx0, dx1, dgamma, dbeta, None, None, None - - -# We duplicate code to return both the output and the dropout mask for testing. -# Returning both makes backward a bit slower, so we want to keep using the other version for speed. -class DropoutAddLayerNormDmaskFN(torch.autograd.Function): - @staticmethod - def forward(ctx, x0, x1, gamma, beta, dropout_p, epsilon, residual_in_fp32): + def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32, + return_dmask=False): x0 = x0.contiguous() x1 = x1.contiguous() if x1 is not None else None gamma = gamma.contiguous() beta = beta.contiguous() + rowscale = rowscale.contiguous() if rowscale is not None else None zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( - x0, x1, gamma, beta, dropout_p, epsilon, residual_in_fp32 + x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32 ) - ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma) + ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale) ctx.dropout_p = dropout_p ctx.has_residual = x1 is not None - dmask = dmask.view(x0.shape) if dropout_p > 0. else torch.ones(x0.shape, dtype=torch.uint8, - device=x0.device) - ctx.mark_non_differentiable(dmask) - return zmat.view(x0.shape), dmask + if not return_dmask: + return zmat.view(x0.shape) + else: + dmask = (dmask.view(x0.shape) if dropout_p > 0. + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)) + ctx.mark_non_differentiable(dmask) + return zmat.view(x0.shape), dmask @staticmethod - def backward(ctx, dz, ddmask_ignored_): + def backward(ctx, dz, *args): # assert dz.is_contiguous() dz = dz.contiguous() # this happens! 
- x, dmask, gamma, mu, rsigma = ctx.saved_tensors + x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors dropout_p = ctx.dropout_p has_residual = ctx.has_residual dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_backward( - dz, x, dmask, mu, rsigma, gamma, dropout_p, has_residual + dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual ) dx0 = dx0mat.view(x.shape) dx1 = dx1mat.view(x.shape) if dx1mat is not None else None - return dx0, dx1, dgamma, dbeta, None, None, None + return dx0, dx1, dgamma, dbeta, None, None, None, None, None class DropoutAddLayerNormPrenormFN(torch.autograd.Function): @staticmethod - def forward(ctx, x0, x1, gamma, beta, dropout_p, epsilon, residual_in_fp32): - x0 = x0.contiguous() - x1 = x1.contiguous() if x1 is not None else None - gamma = gamma.contiguous() - beta = beta.contiguous() - zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( - x0, x1, gamma, beta, dropout_p, epsilon, residual_in_fp32 - ) - ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma) - ctx.dropout_p = dropout_p - ctx.has_residual = x1 is not None - return zmat.view(x0.shape), xmat.view(x0.shape) - - @staticmethod - def backward(ctx, dz, dx): - # assert dz.is_contiguous() - dz = dz.contiguous() # this happens! - dx = dx.contiguous() # this happens! - x, dmask, gamma, mu, rsigma = ctx.saved_tensors - dropout_p = ctx.dropout_p - has_residual = ctx.has_residual - dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_prenorm_backward( - dz, dx, x, dmask, mu, rsigma, gamma, dropout_p, has_residual - ) - dx0 = dx0mat.view(x.shape) - dx1 = dx1mat.view(x.shape) if dx1mat is not None else None - return dx0, dx1, dgamma, dbeta, None, None, None - - -# We duplicate code to return both the output and the dropout mask for testing. -# Returning both makes backward a bit slower, so we want to keep using the other version for speed. -class DropoutAddLayerNormPrenormDmaskFN(torch.autograd.Function): - @staticmethod - def forward(ctx, x0, x1, gamma, beta, dropout_p, epsilon, residual_in_fp32): + def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32, + return_dmask=False): x0 = x0.contiguous() x1 = x1.contiguous() if x1 is not None else None gamma = gamma.contiguous() beta = beta.contiguous() + rowscale = rowscale.contiguous() if rowscale is not None else None zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( - x0, x1, gamma, beta, dropout_p, epsilon, residual_in_fp32 + x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32 ) - ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma) + ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale) ctx.dropout_p = dropout_p ctx.has_residual = x1 is not None - dmask = dmask.view(x0.shape) if dropout_p > 0. else torch.ones(x0.shape, dtype=torch.uint8, - device=x0.device) - ctx.mark_non_differentiable(dmask) - return zmat.view(x0.shape), xmat.view(x0.shape), dmask + if not return_dmask: + return zmat.view(x0.shape), xmat.view(x0.shape) + else: + dmask = (dmask.view(x0.shape) if dropout_p > 0. + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)) + ctx.mark_non_differentiable(dmask) + return zmat.view(x0.shape), xmat.view(x0.shape), dmask @staticmethod - def backward(ctx, dz, dx, ddmask_ignored_): + def backward(ctx, dz, dx, *args): # assert dz.is_contiguous() dz = dz.contiguous() # this happens! dx = dx.contiguous() # this happens! 
- x, dmask, gamma, mu, rsigma = ctx.saved_tensors + x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors dropout_p = ctx.dropout_p has_residual = ctx.has_residual dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_prenorm_backward( - dz, dx, x, dmask, mu, rsigma, gamma, dropout_p, has_residual + dz, dx, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual ) dx0 = dx0mat.view(x.shape) dx1 = dx1mat.view(x.shape) if dx1mat is not None else None - return dx0, dx1, dgamma, dbeta, None, None, None + return dx0, dx1, dgamma, dbeta, None, None, None, None, None -def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, prenorm=False, - residual_in_fp32=False, return_dropout_mask=False): +def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, + prenorm=False, residual_in_fp32=False, + return_dropout_mask=False): """residual_in_fp32 only has an effect if x1 is None. Otherwise residual dtype is x1.dtype. """ - args = (x0, x1, weight, bias, dropout_p, epsilon, residual_in_fp32) + args = (x0, x1, weight, bias, rowscale, dropout_p, epsilon, residual_in_fp32, + return_dropout_mask) if not prenorm: - return (DropoutAddLayerNormFN.apply(*args) if not return_dropout_mask - else DropoutAddLayerNormDmaskFN.apply(*args)) + return DropoutAddLayerNormFN.apply(*args) else: - return (DropoutAddLayerNormPrenormFN.apply(*args) if not return_dropout_mask - else DropoutAddLayerNormPrenormDmaskFN.apply(*args)) + return DropoutAddLayerNormPrenormFN.apply(*args) class DropoutAddLayerNorm(torch.nn.Module): diff --git a/tests/ops/test_dropout_layer_norm.py b/tests/ops/test_dropout_layer_norm.py index 7edf37c..8c73d61 100644 --- a/tests/ops/test_dropout_layer_norm.py +++ b/tests/ops/test_dropout_layer_norm.py @@ -11,9 +11,12 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 +@pytest.mark.parametrize('has_rowscale', [True, False]) +# @pytest.mark.parametrize('has_rowscale', [True]) @pytest.mark.parametrize('has_residual', [True, False]) +# @pytest.mark.parametrize('has_residual', [False]) @pytest.mark.parametrize('dropout_p', [0.37, 0.0]) -# @pytest.mark.parametrize('dropout_p', [0.37]) +# @pytest.mark.parametrize('dropout_p', [0.0]) @pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16]) # @pytest.mark.parametrize('weight_dtype', [torch.float32]) @pytest.mark.parametrize('input_dtype,residual_dtype', @@ -24,9 +27,12 @@ @pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1600, 2048]) # @pytest.mark.parametrize('hidden_size', [768]) def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype, - dropout_p, has_residual): + dropout_p, has_residual, has_rowscale): if weight_dtype == torch.float16 and input_dtype == torch.bfloat16: pytest.skip() # Not supported + # Backward numerical error is high, and this case isn't used + if has_rowscale and not has_residual: + pytest.skip() device = 'cuda' # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4) rtol, atol = (1e-3, 1e-4) @@ -44,6 +50,16 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w x1_ref = x1_pt.detach().clone().float().requires_grad_() else: x1 = None + if has_rowscale: + rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype) + survival_rate = 0.87 + rowscale = rowscale.bernoulli_(survival_rate) / survival_rate + x0_scaled_pt = x0_pt * rearrange(rowscale, '... -> ... 1') + x0_scaled_ref = x0_ref * rearrange(rowscale, '... -> ... 
1') + else: + rowscale = None + x0_scaled_pt = x0_pt + x0_scaled_ref = x0_ref model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype) torch.nn.init.normal_(model_pt.weight) torch.nn.init.normal_(model_pt.bias) @@ -55,16 +71,17 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w model_ref.weight.copy_(model_pt.weight) model_ref.bias.copy_(model_pt.bias) residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32 - out, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p, model.epsilon, + out, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p, + model.epsilon, rowscale=rowscale, residual_in_fp32=residual_in_fp32, return_dropout_mask=True) assert out.dtype == input_dtype print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}') if has_residual: - residual_pt = ((x0_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype) - residual_ref = (x0_ref * dmask.float()) / (1 - dropout_p) + x1_ref + residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype) + residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + x1_ref else: - residual_pt = ((x0_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype) - residual_ref = (x0_ref * dmask.float()) / (1 - dropout_p) + residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype) + residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype) out_ref = model_ref(residual_ref) assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4 @@ -123,6 +140,7 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4 +@pytest.mark.parametrize('has_rowscale', [True, False]) @pytest.mark.parametrize('has_residual', [True, False]) @pytest.mark.parametrize('dropout_p', [0.37, 0.0]) @pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16]) @@ -132,9 +150,12 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else [])) @pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1600, 2048]) def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype, - dropout_p, has_residual): + dropout_p, has_residual, has_rowscale): if weight_dtype == torch.float16 and input_dtype == torch.bfloat16: pytest.skip() # Not supported + # Backward numerical error is high, and this case isn't used + if has_rowscale and not has_residual: + pytest.skip() device = 'cuda' # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4) rtol, atol = (1e-3, 2e-4) @@ -152,6 +173,16 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_ x1_ref = x1_pt.detach().clone().float().requires_grad_() else: x1 = None + if has_rowscale: + rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype) + survival_rate = 0.87 + rowscale = rowscale.bernoulli_(survival_rate) / survival_rate + x0_scaled_pt = x0_pt * rearrange(rowscale, '... -> ... 1') + x0_scaled_ref = x0_ref * rearrange(rowscale, '... -> ... 
1') + else: + rowscale = None + x0_scaled_pt = x0_pt + x0_scaled_ref = x0_ref model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype) model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32) model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device, @@ -163,16 +194,16 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_ model_ref.bias.copy_(model_pt.bias) residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32 out, residual, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p, - model.epsilon, prenorm=True, + model.epsilon, rowscale=rowscale, prenorm=True, residual_in_fp32=residual_in_fp32, return_dropout_mask=True) print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}') if has_residual: - residual_pt = ((x0_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype) - residual_ref = (x0_ref * dmask.float()) / (1 - dropout_p) + x1_ref + residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype) + residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + x1_ref else: - residual_pt = ((x0_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype) - residual_ref = (x0_ref * dmask.float()) / (1 - dropout_p) + residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype) + residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype) out_ref = model_ref(residual_ref) assert out.dtype == input_dtype
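On the Python side, the new `rowscale` argument is what lets `Block` implement stochastic depth without materializing a scaled copy of the branch output. A minimal usage sketch, assuming the `dropout_layer_norm` extension is built and `dropout_add_layer_norm` is importable from `src.ops.layer_norm` (the batch/seqlen/dim values and probabilities below are arbitrary):

```python
import torch
from torchvision.ops import StochasticDepth
from src.ops.layer_norm import dropout_add_layer_norm

batch, seqlen, dim = 2, 16, 768
device, dtype = 'cuda', torch.float16

mixer_out = torch.randn(batch, seqlen, dim, device=device, dtype=dtype)
residual = torch.randn(batch, seqlen, dim, device=device, dtype=dtype)
norm = torch.nn.LayerNorm(dim, device=device, dtype=dtype)
drop_path = StochasticDepth(0.1, mode='row')
resid_dropout_p = 0.1

# StochasticDepth applied to a tensor of ones yields a per-sample scale of
# either 0 or 1 / (1 - drop_path.p), which the kernel folds into x0 row by row.
if drop_path.p == 0.0 or not drop_path.training:
    rowscale = None
else:
    rowscale = drop_path(torch.ones(batch, seqlen, device=device, dtype=dtype))

# Fused: dropout + row scaling + residual add + LayerNorm in one kernel.
hidden_states, residual = dropout_add_layer_norm(
    mixer_out, residual, norm.weight, norm.bias,
    resid_dropout_p, norm.eps,
    rowscale=rowscale, prenorm=True,
)
```

This mirrors the non-fused branch in `Block.forward`, where the same effect is obtained with `residual = drop_path(dropout(mixer_out)) + residual` followed by `norm(residual)`; the two orderings agree because dropout and the per-row scale are both elementwise multiplies.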