[ViT] Fix extra norm_0, use new LN order in Block
tridao committed Jan 16, 2023
1 parent ff34123 commit ef085cf
Showing 2 changed files with 46 additions and 34 deletions.
78 changes: 45 additions & 33 deletions flash_attn/models/vit.py
@@ -9,6 +9,8 @@
import torch.nn.functional as F
from torch.nn.init import trunc_normal_

from torchvision.ops import StochasticDepth

from einops import rearrange

from timm.models.helpers import named_apply
@@ -41,15 +43,18 @@ def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense):
return mlp_cls


def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path,
norm_layer, act_layer, use_flash_attn, fused_bias_fc, fused_dense_gelu_dense,
fused_dropout_add_ln, layer_idx=None, n_layer=None, last_layer_subset=False):
def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
drop_path1, drop_path2, norm_layer, act_layer, use_flash_attn, fused_bias_fc,
fused_dense_gelu_dense, fused_dropout_add_ln, layer_idx=None, n_layer=None,
last_layer_subset=False):
mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn, fused_bias_fc,
cross_attn=(last_layer_subset and layer_idx == n_layer - 1))
mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense)
# TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer,
prenorm=True, resid_dropout=drop_rate, drop_path=drop_path,
fused_dropout_add_ln=fused_dropout_add_ln)
prenorm=True, resid_dropout1=drop_rate, resid_dropout2=drop_rate,
drop_path1=drop_path1, drop_path2=drop_path2,
fused_dropout_add_ln=fused_dropout_add_ln, residual_in_fp32=True)
return block


@@ -143,32 +148,32 @@ def __init__(
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
self.pos_drop = nn.Dropout(p=drop_rate)

dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule

# We change the order of residual and layer norm:
# We change the order of dropout, residual and layer norm:
# Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
# Attn / MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add) and
# the main branch (output of LN). The model definition is unchanged, but the mapping of the
# nn.LayerNorm weights are changed.
# Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
# the main branch (output of MLP). The model definition is unchanged, but the mapping of the
# nn.Dropout probabilities are changed.
# This is for performance reason: we can fuse dropout + add + layer_norm.
# self.norm_0 is the first layer norm in the model, while self.norm
# (in the pretrained weight) is the final layer norm.
self.norm_0 = norm_layer(embed_dim)

self.fused_dropout_add_ln = fused_dropout_add_ln
if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
raise ImportError('dropout_add_layer_norm is not installed')

self.blocks = nn.ModuleList([create_block(
embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path=dpr[i],
embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
drop_path1=dpr[i-1] if i > 0 else 0., drop_path2=dpr[i],
norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
fused_bias_fc=fused_bias_fc, fused_dense_gelu_dense=fused_dense_gelu_dense,
fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth,
last_layer_subset=(global_pool == 'token')
) for i in range(depth)])

self.dropout = nn.Dropout(p=drop_rate)
self.drop_path = StochasticDepth(p=dpr[-1], mode='row')
self.norm = norm_layer(embed_dim)

self.fused_dropout_add_ln = fused_dropout_add_ln
if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
raise ImportError('dropout_add_layer_norm is not installed')

# Classifier Head
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

@@ -210,18 +215,8 @@ def forward_features(self, x, all_tokens=True):
cls token.
"""
x = self.patch_embed(x)
x = self._pos_embed(x)
# TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
if not self.fused_dropout_add_ln:
residual = self.pos_drop(x).float()
hidden_states = self.norm_0(residual.to(dtype=self.norm_0.weight.dtype))
else:
hidden_states, residual = dropout_add_layer_norm(
x, None, self.norm_0.weight, self.norm_0.bias,
self.pos_drop.p if self.training else 0.0, self.norm_0.eps, prenorm=True,
residual_in_fp32=True
)
hidden_states = self.norm_0(residual.to(dtype=self.norm_0.weight.dtype))
hidden_states = self._pos_embed(x)
residual = None
if self.global_pool != 'token' or all_tokens:
for block in self.blocks:
hidden_states, residual = block(hidden_states, residual)
@@ -232,8 +227,25 @@ def forward_features(self, x, all_tokens=True):
# where the query is the 1st token and the key/value is the whole sequence.
hidden_states_1st = rearrange(hidden_states[:, 0], 'b d -> b 1 d')
residual_1st = rearrange(residual[:, 0], 'b d -> b 1 d')
hidden_states, _ = self.blocks[-1](hidden_states_1st, residual_1st,
mixer_kwargs={'x_kv': hidden_states})
hidden_states, residual = self.blocks[-1](hidden_states_1st, residual_1st,
mixer_kwargs={'x_kv': hidden_states})
if not self.fused_dropout_add_ln:
residual = self.drop_path(self.dropout(hidden_states)) + residual
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
else:
if self.drop_path.p == 0 or not self.training:
rowscale = None
else:
rowscale = self.drop_path(torch.ones(
hidden_states.shape[:-1], device=hidden_states.device,
dtype=hidden_states.dtype)
)
# Set prenorm=False here since we don't need the residual
hidden_states = dropout_add_layer_norm(
hidden_states, residual, self.norm.weight, self.norm.bias,
self.dropout.p if self.training else 0.0, self.norm.eps, rowscale=rowscale,
prenorm=False, residual_in_fp32=True
)
return hidden_states

def forward_head(self, x, pre_logits: bool = False):
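
As a reading aid (not part of the commit): the tail of `forward_features` above has an unfused branch and a fused branch that compute the same thing. Because the Dropout -> Add that used to close each block now opens the next one, each block's `drop_path1` takes the previous layer's stochastic-depth rate (`dpr[i-1]`), and the last rate (`dpr[-1]`) moves out of the blocks into this final step. The sketch below illustrates the rowscale trick used in the fused branch: applying `StochasticDepth` to a tensor of ones yields a per-row scale that `dropout_add_layer_norm` can fold into its fused dropout + add + layer_norm. Sizes and probabilities are made up, and this is only the unfused equivalent, not the fused kernel.

```python
import torch
import torch.nn as nn
from torchvision.ops import StochasticDepth

torch.manual_seed(0)
batch, seqlen, dim = 4, 16, 32            # toy sizes, not taken from the model
hidden_states = torch.randn(batch, seqlen, dim)
residual = torch.randn(batch, seqlen, dim)

dropout = nn.Dropout(p=0.1)
drop_path = StochasticDepth(p=0.5, mode='row')
norm = nn.LayerNorm(dim)

# drop_path applied to a tensor of ones gives one scale per (batch, seq) position:
# 0.0 where the sample is dropped, 1 / (1 - p) = 2.0 where it is kept.
rowscale = drop_path(torch.ones(hidden_states.shape[:-1],
                                device=hidden_states.device, dtype=hidden_states.dtype))
print(rowscale[:, 0])  # one value per batch element, e.g. tensor([2., 0., 2., 2.])

# Unfused equivalent of dropout_add_layer_norm(..., rowscale=rowscale, prenorm=False):
# dropout, scale by rowscale, add the residual, LayerNorm, and keep only the LN output.
out = norm(dropout(hidden_states) * rowscale.unsqueeze(-1) + residual)
print(out.shape)  # torch.Size([4, 16, 32])
```
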
2 changes: 1 addition & 1 deletion flash_attn/modules/block.py
@@ -94,7 +94,7 @@ def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
Args:
hidden_states: the sequence to the encoder layer (required).
residual: if postnorm, residual=None, If prenorm, hidden_states = LayerNorm(residual)
residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
"""
if self.prenorm:
if not self.fused_dropout_add_ln:
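
To make the updated docstring concrete: in prenorm mode, `Block` now runs Dropout -> DropPath -> Add -> LN -> Attn/MLP for each sublayer and returns both the main branch (the MLP output) and the residual branch (the output of the last Add), which `forward_features` threads through `self.blocks`. Below is a minimal sketch of that data flow; `PrenormBlockSketch` and its `nn.Linear` stand-ins for the mixer and MLP are illustrative only, not the actual `flash_attn` Block implementation.

```python
import torch
import torch.nn as nn
from torchvision.ops import StochasticDepth

class PrenormBlockSketch(nn.Module):
    """Illustrative stand-in for the prenorm path of Block (not the real code)."""
    def __init__(self, dim, resid_dropout=0.1, drop_path1=0.0, drop_path2=0.1):
        super().__init__()
        self.dropout1 = nn.Dropout(resid_dropout)
        self.drop_path1 = StochasticDepth(drop_path1, mode='row')
        self.norm1 = nn.LayerNorm(dim)
        self.mixer = nn.Linear(dim, dim)   # stands in for attention
        self.dropout2 = nn.Dropout(resid_dropout)
        self.drop_path2 = StochasticDepth(drop_path2, mode='row')
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Linear(dim, dim)     # stands in for the MLP

    def forward(self, hidden_states, residual=None):
        # Sublayer 1: Dropout -> DropPath -> Add -> LN -> mixer
        dropped = self.drop_path1(self.dropout1(hidden_states))
        residual = dropped if residual is None else residual + dropped
        residual = residual.float()        # residual_in_fp32: keep the residual stream in fp32
        hidden_states = self.mixer(self.norm1(residual.to(self.norm1.weight.dtype)))
        # Sublayer 2: Dropout -> DropPath -> Add -> LN -> MLP
        dropped = self.drop_path2(self.dropout2(hidden_states))
        residual = residual + dropped.float()
        hidden_states = self.mlp(self.norm2(residual.to(self.norm2.weight.dtype)))
        # Main branch (MLP output) and residual branch (output of the last Add)
        return hidden_states, residual

x = torch.randn(2, 16, 64)
block = PrenormBlockSketch(64)
hidden_states, residual = block(x)                         # first block: residual starts as None
hidden_states, residual = block(hidden_states, residual)   # later blocks reuse the carried pair
```
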
