[LayerNorm] Rename x1 -> residual
tridao committed Jan 19, 2023
1 parent f68d41e commit eb33e58
Showing 7 changed files with 89 additions and 88 deletions.
6 changes: 3 additions & 3 deletions csrc/layer_norm/ln.h
@@ -59,7 +59,7 @@ struct ParamsBase {

// Common data pointers.
void *x0;
-    void *x1;
+    void *residual;
void *x;
void *dmask;
void *mu;
@@ -117,7 +117,7 @@ struct BwdParams : public ParamsBase {
, dgamma_part(nullptr)
, dcolscale_part(nullptr)
, dx0(nullptr)
-        , dx1(nullptr)
+        , dresidual(nullptr)
, dbeta(nullptr)
, dgamma(nullptr)
, dcolscale(nullptr)
@@ -136,7 +136,7 @@ struct BwdParams : public ParamsBase {

// Output: Dgrad.
void *dx0;
-    void *dx1;
+    void *dresidual;
// Output: Wgrad.
void *dbeta;
void *dgamma;
30 changes: 15 additions & 15 deletions csrc/layer_norm/ln_api.cpp
@@ -81,7 +81,7 @@ layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype ityp
////////////////////////////////////////////////////////////////////////////////////////////////////

std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input: BxSxhidden_size
-                                           c10::optional<const at::Tensor> &x1_,      // Residual: BxSxhidden_size
+                                           c10::optional<const at::Tensor> &residual_,  // Residual: BxSxhidden_size
const at::Tensor &gamma, // hidden_size
c10::optional<const at::Tensor> &beta_, // hidden_size
c10::optional<const at::Tensor> &rowscale_, // BxS
@@ -97,8 +97,8 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
bool is_rms_norm=false
) {
auto itype = x0.scalar_type();
-    auto rtype = x1_.has_value()
-        ? x1_.value().scalar_type()
+    auto rtype = residual_.has_value()
+        ? residual_.value().scalar_type()
: (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
auto wtype = gamma.scalar_type();
auto otype = itype;
@@ -129,11 +129,11 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
TORCH_CHECK(gamma.sizes() == beta.sizes());
}

-    if (x1_.has_value()) {
-        auto x1 = x1_.value();
-        TORCH_CHECK(x1.is_cuda())
-        TORCH_CHECK(x1.is_contiguous());
-        TORCH_CHECK(x1.sizes() == sizes);
+    if (residual_.has_value()) {
+        auto residual = residual_.value();
+        TORCH_CHECK(residual.is_cuda())
+        TORCH_CHECK(residual.is_contiguous());
+        TORCH_CHECK(residual.sizes() == sizes);
}

if (rowscale_.has_value()) {
@@ -178,7 +178,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:

auto opts = x0.options();

-    bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
+    bool save_x = residual_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
at::Tensor x;
if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
at::Tensor dmask;
@@ -194,7 +194,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(dropout_p < 1.f);
launch_params.params.dropout_keep_p = 1.f - dropout_p;
-    launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
+    launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr;
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
@@ -383,8 +383,8 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
auto opts = x.options();

auto dx0 = torch::empty(x0_sizes, opts.dtype(itype));
-    at::Tensor dx1;
-    if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
+    at::Tensor dresidual;
+    if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); }
auto dgamma = torch::empty_like(gamma);
auto dbeta = torch::empty_like(gamma);
at::Tensor dcolscale;
@@ -397,7 +397,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
launch_params.props = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(dropout_p < 1.f);
launch_params.params.dropout_keep_p = 1.f - dropout_p;
-    launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
+    launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
@@ -450,7 +450,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd

launcher(launch_params, false);

-    std::vector<at::Tensor> result = { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
+    std::vector<at::Tensor> result = { dx0, dresidual, dgamma, dbeta, dgamma_part, dbeta_part };
if (colscale_.has_value()) {
result.push_back(dcolscale);
result.push_back(dcolscale_part);
@@ -462,7 +462,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "CUDA DropoutAddLayerNorm";
m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel",
py::arg("x0"), py::arg("x1"), py::arg("gamma"), py::arg("beta"),
py::arg("x0"), py::arg("residual"), py::arg("gamma"), py::arg("beta"),
py::arg("rowscale_"), py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"),
py::arg("dropout_p"), py::arg("epsilon"), py::arg("rowscale_const"), py::arg("z_numrows"),
py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
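
The binding above only swaps the argument name x1 for residual; the Python-side call pattern is unchanged. A minimal usage sketch, assuming the high-level dropout_add_layer_norm wrapper (as used in flash_attn/models/gpt.py) is importable from flash_attn.ops.layer_norm and keeps the (x0, residual, weight, bias, dropout_p, epsilon, ...) signature; shapes and hyperparameters are illustrative:

import torch
from flash_attn.ops.layer_norm import dropout_add_layer_norm  # assumed module path

batch, seqlen, hidden = 2, 128, 1024
x0 = torch.randn(batch, seqlen, hidden, device="cuda", dtype=torch.float16)
residual = torch.randn(batch, seqlen, hidden, device="cuda", dtype=torch.float32)
weight = torch.ones(hidden, device="cuda", dtype=torch.float16)
bias = torch.zeros(hidden, device="cuda", dtype=torch.float16)

# With prenorm=True the wrapper returns both the normalized output and the
# updated residual stream (the pre-LayerNorm sum) for the next block to consume.
out, new_residual = dropout_add_layer_norm(
    x0, residual, weight, bias, 0.1, 1e-5, prenorm=True
)
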
8 changes: 4 additions & 4 deletions csrc/layer_norm/ln_bwd_kernels.cuh
@@ -37,7 +37,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {

extern __shared__ char smem_[];

-    const bool has_residual = params.dx1 != nullptr;
+    const bool has_residual = params.dresidual != nullptr;
const bool prenorm = params.dx != nullptr;

const index_t tidx = threadIdx.x;
@@ -164,7 +164,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
Ivec dx0;
-            Rvec dx1;
+            Rvec dresidual;
Ivec x0;
if (Has_colscale && save_dx0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
#pragma unroll
@@ -178,7 +178,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
} else {
dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f;
}
-                if (has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
+                if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; }
if (save_dx0) {
compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
if (Is_dropout) {
@@ -199,7 +199,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
}
}
}
-            if (has_residual) { dx1.store_to(params.dx1, idx_x); }
+            if (has_residual) { dresidual.store_to(params.dresidual, idx_x); }
if (save_dx0) { dx0.store_to(params.dx0, !Has_subset ? idx_x : idx_x0); }
idx_x += Ktraits::VEC_COLS_PER_LDG;
idx_x0 += Ktraits::VEC_COLS_PER_LDG;
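
The backward hunks rename only the per-thread register vector and the output pointer (dx1 -> dresidual); the gradient flow itself is unchanged. Ignoring the rowscale/colscale/subset options, the residual path amounts to the sketch below (plain PyTorch-style reference, names illustrative):

def backward_residual_path(dsum, keep_mask, dropout_scale, has_residual):
    # dsum: gradient w.r.t. the pre-LayerNorm sum x = dropout(x0) + residual.
    # The residual branch receives it unchanged; the x0 branch receives it
    # routed back through the dropout keep-mask and scale (1 / keep_prob).
    dresidual = dsum if has_residual else None
    dx0 = dsum * keep_mask * dropout_scale
    return dx0, dresidual
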
10 changes: 5 additions & 5 deletions csrc/layer_norm/ln_fwd_kernels.cuh
@@ -46,7 +46,7 @@ void ln_fwd_kernel(FwdParams params) {
using Stats = typename Ktraits::Stats;
using stats_t = typename Stats::stats_t;

-    const bool has_residual = params.x1 != nullptr;
+    const bool has_residual = params.residual != nullptr;
const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same<input_t, residual_t>::value);

extern __shared__ char smem_[];
@@ -111,11 +111,11 @@ void ln_fwd_kernel(FwdParams params) {
for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
Ivec x0;
-            Rvec x1;
+            Rvec residual;
Rvec x;
Mvec dmask;
if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
-            if (has_residual) { x1.load_from(params.x1, idx_x); }
+            if (has_residual) { residual.load_from(params.residual, idx_x); }
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
// TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
Expand All @@ -127,9 +127,9 @@ void ln_fwd_kernel(FwdParams params) {
compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
-                    x_ij = has_residual ? x0_ij + compute_t(x1.data.elt[jt]) : x0_ij;
+                    x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij;
} else {
-                    x_ij = has_residual ? compute_t(x1.data.elt[jt]) : 0.f;
+                    x_ij = has_residual ? compute_t(residual.data.elt[jt]) : 0.f;
}
if (save_x) { x.data.elt[jt] = x_ij; }
xf[it * NUM_ELTS + jt] = x_ij;
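
As the forward hunks show, residual is added to the (dropped-out, optionally row/col-scaled) x0, and the sum x is what gets normalized and saved for the backward pass. An unfused reference of that math, omitting the rowscale/colscale/subset options (illustrative names, not the kernel API):

import torch
import torch.nn.functional as F

def dropout_add_ln_ref(x0, residual, weight, bias, dropout_p, eps):
    # Compute in fp32, mirroring the kernel's compute_t accumulation.
    dropped = F.dropout(x0.float(), p=dropout_p, training=True)
    x = dropped + residual.float() if residual is not None else dropped
    out = F.layer_norm(x, (x.shape[-1],), weight.float(), bias.float(), eps)
    # x is the saved pre-norm sum; the fused kernel returns it as the new residual.
    return out.to(x0.dtype), x
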
4 changes: 2 additions & 2 deletions flash_attn/models/gpt.py
@@ -292,7 +292,7 @@ def forward(self, input_ids, position_ids=None, inference_params=None):
residual = (dropped + residual) if residual is not None else dropped
hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
else:
-            # Set prenorm=False here since we don't need to the residual
+            # Set prenorm=False here since we don't need the residual
hidden_states = dropout_add_layer_norm(
hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
@@ -359,7 +359,7 @@ def load_state_dict(self, state_dict, strict=True):
# Previous: Attn / MLP -> Dropout -> Add -> LN
# Current: Dropout -> Add -> LN -> Attn / MLP
if 'transformer.ln_0.weight' in state_dict:
-            n_layers = self.config.num_hidden_layers
+            n_layers = len(self.transformer.layers)
ln_weight = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.weight')
ln_bias = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.bias')
state_dict['transformer.ln_f.weight'] = ln_weight