diff --git a/csrc/layer_norm/ln.h b/csrc/layer_norm/ln.h
index 668f1e283..098b50b03 100644
--- a/csrc/layer_norm/ln.h
+++ b/csrc/layer_norm/ln.h
@@ -59,7 +59,7 @@ struct ParamsBase {

     // Common data pointers.
     void *x0;
-    void *x1;
+    void *residual;
     void *x;
     void *dmask;
     void *mu;
@@ -117,7 +117,7 @@ struct BwdParams : public ParamsBase {
         , dgamma_part(nullptr)
         , dcolscale_part(nullptr)
         , dx0(nullptr)
-        , dx1(nullptr)
+        , dresidual(nullptr)
         , dbeta(nullptr)
         , dgamma(nullptr)
         , dcolscale(nullptr)
@@ -136,7 +136,7 @@ struct BwdParams : public ParamsBase {

     // Output: Dgrad.
     void *dx0;
-    void *dx1;
+    void *dresidual;
     // Output: Wgrad.
     void *dbeta;
     void *dgamma;
diff --git a/csrc/layer_norm/ln_api.cpp b/csrc/layer_norm/ln_api.cpp
index f911016ef..15ad0d96a 100644
--- a/csrc/layer_norm/ln_api.cpp
+++ b/csrc/layer_norm/ln_api.cpp
@@ -81,7 +81,7 @@ layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype ityp
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input: BxSxhidden_size
-                                           c10::optional<const at::Tensor> &x1_,      // Residual: BxSxhidden_size
+                                           c10::optional<const at::Tensor> &residual_,      // Residual: BxSxhidden_size
                                            const at::Tensor &gamma,   // hidden_size
                                            c10::optional<const at::Tensor> &beta_,   // hidden_size
                                            c10::optional<const at::Tensor> &rowscale_,      // BxS
@@ -97,8 +97,8 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
                                            bool is_rms_norm=false
 ) {
     auto itype = x0.scalar_type();
-    auto rtype = x1_.has_value()
-        ? x1_.value().scalar_type()
+    auto rtype = residual_.has_value()
+        ? residual_.value().scalar_type()
         : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
     auto wtype = gamma.scalar_type();
     auto otype = itype;
@@ -129,11 +129,11 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
         TORCH_CHECK(gamma.sizes() == beta.sizes());
     }

-    if (x1_.has_value()) {
-        auto x1 = x1_.value();
-        TORCH_CHECK(x1.is_cuda())
-        TORCH_CHECK(x1.is_contiguous());
-        TORCH_CHECK(x1.sizes() == sizes);
+    if (residual_.has_value()) {
+        auto residual = residual_.value();
+        TORCH_CHECK(residual.is_cuda())
+        TORCH_CHECK(residual.is_contiguous());
+        TORCH_CHECK(residual.sizes() == sizes);
     }

     if (rowscale_.has_value()) {
@@ -178,7 +178,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:

     auto opts = x0.options();

-    bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
+    bool save_x = residual_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
     at::Tensor x;
     if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
     at::Tensor dmask;
@@ -194,7 +194,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
     launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
     TORCH_CHECK(dropout_p < 1.f);
     launch_params.params.dropout_keep_p = 1.f - dropout_p;
-    launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
+    launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr;
     launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
     launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
     launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
@@ -383,8 +383,8 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
     auto opts = x.options();

     auto dx0 = torch::empty(x0_sizes, opts.dtype(itype));
-    at::Tensor dx1;
-    if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
+    at::Tensor dresidual;
+    if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); }
     auto dgamma = torch::empty_like(gamma);
     auto dbeta = torch::empty_like(gamma);
     at::Tensor dcolscale;
@@ -397,7 +397,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
     launch_params.props = at::cuda::getCurrentDeviceProperties();
     TORCH_CHECK(dropout_p < 1.f);
     launch_params.params.dropout_keep_p = 1.f - dropout_p;
-    launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
+    launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;
     launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
     launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
     launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
@@ -450,7 +450,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd

     launcher(launch_params, false);

-    std::vector<at::Tensor> result = { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
+    std::vector<at::Tensor> result = { dx0, dresidual, dgamma, dbeta, dgamma_part, dbeta_part };
     if (colscale_.has_value()) {
         result.push_back(dcolscale);
         result.push_back(dcolscale_part);
@@ -462,7 +462,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.doc() = "CUDA DropoutAddLayerNorm";
     m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel",
-          py::arg("x0"), py::arg("x1"), py::arg("gamma"), py::arg("beta"),
+          py::arg("x0"), py::arg("residual"), py::arg("gamma"), py::arg("beta"),
           py::arg("rowscale_"), py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"),
           py::arg("dropout_p"), py::arg("epsilon"), py::arg("rowscale_const"), py::arg("z_numrows"),
           py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
diff --git a/csrc/layer_norm/ln_bwd_kernels.cuh b/csrc/layer_norm/ln_bwd_kernels.cuh
index a728f1d2c..c7261d218 100644
--- a/csrc/layer_norm/ln_bwd_kernels.cuh
+++ b/csrc/layer_norm/ln_bwd_kernels.cuh
@@ -37,7 +37,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {

     extern __shared__ char smem_[];

-    const bool has_residual = params.dx1 != nullptr;
+    const bool has_residual = params.dresidual != nullptr;
     const bool prenorm = params.dx != nullptr;

     const index_t tidx = threadIdx.x;
@@ -164,7 +164,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
         for( int it = 0; it < LDGS; it++ ) {
             if (Is_even_cols || (it < num_valid_ldgs)) {
                 Ivec dx0;
-                Rvec dx1;
+                Rvec dresidual;
                 Ivec x0;
                 if (Has_colscale && save_dx0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
                 #pragma unroll
@@ -178,7 +178,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
                     } else {
                         dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f;
                     }
-                    if (has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
+                    if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; }
                     if (save_dx0) {
                         compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
                         if (Is_dropout) {
@@ -199,7 +199,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
                         }
                     }
                 }
-                if (has_residual) { dx1.store_to(params.dx1, idx_x); }
+                if (has_residual) { dresidual.store_to(params.dresidual, idx_x); }
                 if (save_dx0) { dx0.store_to(params.dx0, !Has_subset ? idx_x : idx_x0); }
                 idx_x += Ktraits::VEC_COLS_PER_LDG;
                 idx_x0 += Ktraits::VEC_COLS_PER_LDG;
diff --git a/csrc/layer_norm/ln_fwd_kernels.cuh b/csrc/layer_norm/ln_fwd_kernels.cuh
index 03ca60564..f6bccb8c2 100644
--- a/csrc/layer_norm/ln_fwd_kernels.cuh
+++ b/csrc/layer_norm/ln_fwd_kernels.cuh
@@ -46,7 +46,7 @@ void ln_fwd_kernel(FwdParams params) {
     using Stats = typename Ktraits::Stats;
     using stats_t = typename Stats::stats_t;

-    const bool has_residual = params.x1 != nullptr;
+    const bool has_residual = params.residual != nullptr;
     const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same<input_t, residual_t>::value);

     extern __shared__ char smem_[];
@@ -111,11 +111,11 @@ void ln_fwd_kernel(FwdParams params) {
         for( int it = 0; it < LDGS; it++ ) {
             if (Is_even_cols || (it < num_valid_ldgs)) {
                 Ivec x0;
-                Rvec x1;
+                Rvec residual;
                 Rvec x;
                 Mvec dmask;
                 if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
-                if (has_residual) { x1.load_from(params.x1, idx_x); }
+                if (has_residual) { residual.load_from(params.residual, idx_x); }
                 #pragma unroll
                 for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                     // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
@@ -127,9 +127,9 @@ void ln_fwd_kernel(FwdParams params) {
                         compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
                         x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
                         if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
-                        x_ij = has_residual ? x0_ij + compute_t(x1.data.elt[jt]) : x0_ij;
+                        x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij;
                     } else {
-                        x_ij = has_residual ? compute_t(x1.data.elt[jt]) : 0.f;
+                        x_ij = has_residual ? compute_t(residual.data.elt[jt]) : 0.f;
                     }
                     if (save_x) { x.data.elt[jt] = x_ij; }
                     xf[it * NUM_ELTS + jt] = x_ij;
diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index e37261c69..b85981d78 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -292,7 +292,7 @@ def forward(self, input_ids, position_ids=None, inference_params=None):
             residual = (dropped + residual) if residual is not None else dropped
             hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
         else:
-            # Set prenorm=False here since we don't need to the residual
+            # Set prenorm=False here since we don't need the residual
             hidden_states = dropout_add_layer_norm(
                 hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
                 self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
@@ -359,7 +359,7 @@ def load_state_dict(self, state_dict, strict=True):
         # Previous: Attn / MLP -> Dropout -> Add -> LN
         # Current: Dropout -> Add -> LN -> Attn / MLP
         if 'transformer.ln_0.weight' in state_dict:
-            n_layers = self.config.num_hidden_layers
+            n_layers = len(self.transformer.layers)
             ln_weight = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.weight')
             ln_bias = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.bias')
             state_dict['transformer.ln_f.weight'] = ln_weight
diff --git a/flash_attn/ops/layer_norm.py b/flash_attn/ops/layer_norm.py
index 8449bf365..8da28e0a2 100644
--- a/flash_attn/ops/layer_norm.py
+++ b/flash_attn/ops/layer_norm.py
@@ -7,20 +7,20 @@
 import dropout_layer_norm


-def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
-                                    residual_in_fp32=False, is_rms_norm=False):
+def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p,
+                                    epsilon, residual_in_fp32=False, is_rms_norm=False):
     """ Assume that arguments are contiguous
     """
     hidden_size = gamma.numel()
     x0mat = x0.view((-1, hidden_size))
-    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
+    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
     rowscale = rowscale.view(-1) if rowscale is not None else None
     zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
-        x0mat, x1mat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
+        x0mat, residualmat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
         1.0, 0, None, residual_in_fp32, is_rms_norm
     )
     # dmask is None if dropout_p == 0.0
-    # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
+    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
     return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


@@ -28,7 +28,7 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
                                      dropout_p, has_residual, is_rms_norm=False):
     """ Assume that arguments are contiguous
     dx == None means that it was a post-norm architecture
-    (x = drop(x0) + x1 was not returned in the fwd).
+    (x = drop(x0) + residual was not returned in the fwd).
     x0 must not be None if we have colscale.
""" hidden_size = gamma.numel() @@ -39,34 +39,34 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro rowscale = rowscale.view(-1) if rowscale is not None else None if colscale is not None: assert x0 is not None, 'x0 is required to compute the gradient of colscale' - dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( + dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None, dropout_p, 1.0, 0, has_residual, is_rms_norm ) - # dx1mat is None if not has_residual + # dresidualmat is None if not has_residual if colscale is None: - return dx0mat, dx1mat, dgamma, dbeta + return dx0mat, dresidualmat, dgamma, dbeta else: dcolscale = rest[0] - return dx0mat, dx1mat, dgamma, dbeta, dcolscale + return dx0mat, dresidualmat, dgamma, dbeta, dcolscale -def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_subset, out_subset, - dropout_p, epsilon, rowscale_const, out_numrows, - residual_in_fp32=False, is_rms_norm=False): +def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset, + out_subset, dropout_p, epsilon, rowscale_const, + out_numrows, residual_in_fp32=False, is_rms_norm=False): """ Assume that arguments are contiguous """ hidden_size = gamma.numel() x0mat = x0.view((-1, hidden_size)) - x1mat = x1.view((-1, hidden_size)) if x1 is not None else None + residualmat = residual.view((-1, hidden_size)) if residual is not None else None x0_subset = x0_subset.view(-1) if x0_subset is not None else None out_subset = out_subset.view(-1) if out_subset is not None else None zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( - x0mat, x1mat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon, + x0mat, residualmat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon, rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm ) # dmask is None if dropout_p == 0.0 - # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype + # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma @@ -75,7 +75,7 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga x0_numrows, has_residual, is_rms_norm=False): """ Assume that arguments are contiguous dx == None means that it was a post-norm architecture - (x = drop(x0) + x1 was not returned in the fwd). + (x = drop(x0) + residual was not returned in the fwd). x0 must not be None if we have colscale. 
""" hidden_size = gamma.numel() @@ -87,30 +87,30 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga out_subset = out_subset.view(-1) if out_subset is not None else None if colscale is not None: assert x0 is not None, 'x0 is required to compute the gradient of colscale' - dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( + dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm ) - # dx1mat is None if not has_residual + # dresidualmat is None if not has_residual if colscale is None: - return dx0mat, dx1mat, dgamma, dbeta + return dx0mat, dresidualmat, dgamma, dbeta else: dcolscale = rest[0] - return dx0mat, dx1mat, dgamma, dbeta, dcolscale + return dx0mat, dresidualmat, dgamma, dbeta, dcolscale class DropoutAddLayerNormFn(torch.autograd.Function): @staticmethod - def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, + def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False): x0 = x0.contiguous() - x1 = x1.contiguous() if x1 is not None else None + residual = residual.contiguous() if residual is not None else None gamma = gamma.contiguous() beta = beta.contiguous() if beta is not None else None rowscale = rowscale.contiguous() if rowscale is not None else None colscale = colscale.contiguous() if colscale is not None else None zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( - x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, + x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32, is_rms_norm ) # Only need to save x0 if we need to compute gradient wrt colscale @@ -118,7 +118,7 @@ def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, ctx.save_for_backward(xmat.view(x0.shape), x0, dmask, gamma, mu, rsigma, rowscale, colscale) ctx.prenorm = prenorm ctx.dropout_p = dropout_p - ctx.has_residual = x1 is not None + ctx.has_residual = residual is not None ctx.is_rms_norm = is_rms_norm ctx.has_beta = beta is not None if not return_dmask: @@ -140,29 +140,29 @@ def backward(ctx, dz, *args): # x0 is None if colscale is None dropout_p = ctx.dropout_p has_residual = ctx.has_residual - dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward( + dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward( dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual, ctx.is_rms_norm ) dx0 = dx0mat.view(x.shape) - dx1 = dx1mat.view(x.shape) if dx1mat is not None else None + dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None dcolscale = rest[0] if colscale is not None else None - return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None, None, - None, None, None, None) + return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None, + None, None, None, None, None) class DropoutAddLayerNormSubsetFn(torch.autograd.Function): @staticmethod - def forward(ctx, x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon, + def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon, rowscale_const, out_numrows, residual_in_fp32=False, prenorm=False, is_rms_norm=False, 
                 return_dmask=False):
         x0 = x0.contiguous()
-        x1 = x1.contiguous() if x1 is not None else None
+        residual = residual.contiguous() if residual is not None else None
         gamma = gamma.contiguous()
         beta = beta.contiguous() if beta is not None else None
         colscale = colscale.contiguous() if colscale is not None else None
         zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
-            x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
+            x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
             rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
         )
         # Only need to save x0 if we need to compute gradient wrt colscale
@@ -174,7 +174,7 @@ def forward(ctx, x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p
         ctx.dropout_p = dropout_p
         ctx.rowscale_const = rowscale_const
         ctx.x0_numrows = x0.shape[:-1].numel()
-        ctx.has_residual = x1 is not None
+        ctx.has_residual = residual is not None
         ctx.is_rms_norm = is_rms_norm
         ctx.has_beta = beta is not None
         z_shape = (-1, *x0.shape[1:])
@@ -197,42 +197,42 @@ def backward(ctx, dz, *args):
         # x0 is None if colscale is None
         dropout_p = ctx.dropout_p
         has_residual = ctx.has_residual
-        dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
+        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
             dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
             ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm
         )
         dx0 = dx0mat.view(-1, *x.shape[1:])
-        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
+        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
         dcolscale = rest[0] if colscale is not None else None
-        return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None, None,
-                None, None, None, None, None, None, None)
+        return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None,
+                None, None, None, None, None, None, None, None)


 def layer_norm(x, weight, bias, epsilon):
     return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)


-def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
-                           prenorm=False, residual_in_fp32=False,
+def dropout_add_layer_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
+                           layerscale=None, prenorm=False, residual_in_fp32=False,
                            return_dropout_mask=False):
-    """residual_in_fp32 only has an effect if x1 is None.
-    Otherwise residual dtype is x1.dtype.
+    """residual_in_fp32 only has an effect if residual is None.
+    Otherwise residual dtype is residual.dtype.
""" return DropoutAddLayerNormSubsetFn.apply( - x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon, + x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon, rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask ) @@ -254,7 +254,7 @@ def reset_parameters(self): init.ones_(self.weight) init.zeros_(self.bias) - def forward(self, x0, x1=None): - return dropout_add_layer_norm(x0, x1, self.weight, self.bias, + def forward(self, x0, residual=None): + return dropout_add_layer_norm(x0, residual, self.weight, self.bias, self.p if self.training else 0.0, self.epsilon, prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32) diff --git a/flash_attn/ops/rms_norm.py b/flash_attn/ops/rms_norm.py index a20ecfa92..17ba2939f 100644 --- a/flash_attn/ops/rms_norm.py +++ b/flash_attn/ops/rms_norm.py @@ -12,26 +12,27 @@ def rms_norm(x, weight, epsilon): False, True) -def dropout_add_rms_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None, - prenorm=False, residual_in_fp32=False, return_dropout_mask=False): - """residual_in_fp32 only has an effect if x1 is None. - Otherwise residual dtype is x1.dtype. +def dropout_add_rms_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None, + layerscale=None, prenorm=False, residual_in_fp32=False, + return_dropout_mask=False): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. """ return DropoutAddLayerNormFn.apply( - x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm, + x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm, True, return_dropout_mask ) -def dropout_add_rms_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, layerscale=None, +def dropout_add_rms_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None, x0_subset=None, out_subset=None, rowscale_const=1.0, out_numrows=0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False): - """residual_in_fp32 only has an effect if x1 is None. - Otherwise residual dtype is x1.dtype. + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. """ return DropoutAddLayerNormSubsetFn.apply( - x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon, + x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon, rowscale_const, out_numrows, residual_in_fp32, prenorm, True, return_dropout_mask ) @@ -52,7 +53,7 @@ def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32 def reset_parameters(self): init.ones_(self.weight) - def forward(self, x0, x1=None): - return dropout_add_rms_norm(x0, x1, self.weight, None, + def forward(self, x0, residual=None): + return dropout_add_rms_norm(x0, residual, self.weight, None, self.p if self.training else 0.0, self.epsilon, prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)