[ViT] Support timm checkpoint, add tests
tridao committed Jan 16, 2023
1 parent 2ec7d3f commit 780e8ee
Showing 5 changed files with 96 additions and 11 deletions.
33 changes: 29 additions & 4 deletions flash_attn/models/vit.py
@@ -1,9 +1,12 @@
# Copyright (c) 2022, Tri Dao.
# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
import math
import re
from functools import partial
from copy import deepcopy

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -218,17 +221,16 @@ def forward_features(self, x, all_tokens=True):
hidden_states = self._pos_embed(x)
residual = None
if self.global_pool != 'token' or all_tokens:
# if True:
for block in self.blocks:
hidden_states, residual = block(hidden_states, residual)
else:
for block in self.blocks[:-1]:
hidden_states, residual = block(hidden_states, residual)
# For the last layer, we only want the 1st token of the output. So we do cross-attention
# where the query is the 1st token and the key/value is the whole sequence.
hidden_states_1st = rearrange(hidden_states[:, 0], 'b d -> b 1 d')
residual_1st = rearrange(residual[:, 0], 'b d -> b 1 d')
hidden_states, residual = self.blocks[-1](hidden_states_1st, residual_1st,
mixer_kwargs={'x_kv': hidden_states})
hidden_states, residual = self.blocks[-1](hidden_states, residual,
mixer_subset=slice(0, 1))
if not self.fused_dropout_add_ln:
residual = self.drop_path(self.dropout(hidden_states)) + residual
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
@@ -258,6 +260,29 @@ def forward(self, x):
x = self.forward_head(x)
return x

    def load_state_dict(self, state_dict, strict=True):
        patch_embed_weight = state_dict['patch_embed.proj.weight']
        if patch_embed_weight.dim() == 4:
            # convert from Conv2d to Linear
            state_dict['patch_embed.proj.weight'] = rearrange(patch_embed_weight,
                                                              'o c h w -> o (c h w)')
        def key_mapping_attn(key):
            key = re.sub(r'^blocks.(\d+).attn.qkv.', r'blocks.\1.mixer.Wqkv.', key)
            key = re.sub(r'^blocks.(\d+).attn.proj.', r'blocks.\1.mixer.out_proj.', key)
            return key
        state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
        n_layer = len(self.blocks)
        # Convert from Wqkv to Wq and Wkv for cross attention (last layer)
        if (self.blocks[-1].mixer.cross_attn
                and f'blocks.{n_layer - 1}.mixer.Wqkv.weight' in state_dict):
            Wqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.weight')
            bqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.bias')
            state_dict[f'blocks.{n_layer - 1}.mixer.Wq.weight'] = Wqkv[:self.embed_dim]
            state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.weight'] = Wqkv[self.embed_dim:]
            state_dict[f'blocks.{n_layer - 1}.mixer.Wq.bias'] = bqkv[:self.embed_dim]
            state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.bias'] = bqkv[self.embed_dim:]
        return super().load_state_dict(state_dict, strict=strict)


def init_weights_vit_timm(module: nn.Module, name: str = ''):
""" ViT weight initialization, original timm impl (for reproducibility) """
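
The overridden `load_state_dict` above means a stock timm checkpoint can be loaded directly: the patch-embedding Conv2d weight is flattened into the Linear layout, the `attn.qkv` / `attn.proj` keys are remapped to `mixer.Wqkv` / `mixer.out_proj`, and the last block's fused `Wqkv` is split into `Wq` / `Wkv` when that block uses cross-attention. A minimal usage sketch (constructor kwargs are illustrative, mirroring the new test below):

```python
# Sketch: load timm's pretrained ViT-B/16 into the flash_attn ViT.
# Assumes timm is installed and a CUDA device is available; kwargs are illustrative.
import torch
from timm.models.vision_transformer import vit_base_patch16_224 as timm_vit_base_patch16_224
from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224

model_timm = timm_vit_base_patch16_224(pretrained=True)
model = flash_vit_base_patch16_224(use_flash_attn=True, fused_bias_fc=True)
model = model.to(device='cuda', dtype=torch.float16)

# The override handles the Conv2d -> Linear reshape and the attn -> mixer key renames,
# then defers to the standard nn.Module.load_state_dict.
model.load_state_dict(model_timm.state_dict())
model.eval()
```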
13 changes: 10 additions & 3 deletions flash_attn/modules/block.py
@@ -89,12 +89,15 @@ def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
p._shared_params = True

def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
mixer_kwargs=None):
mixer_subset=None, mixer_kwargs=None):
r"""Pass the input through the encoder layer.
Args:
hidden_states: the sequence to the encoder layer (required).
residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
mixer_subset: for cross-attention only. If not None, will take a subset of x
before applying the query projection. Useful for e.g., ViT where we only care
about the CLS token in the last layer.
"""
if self.prenorm:
if not self.fused_dropout_add_ln:
@@ -116,8 +119,12 @@ def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
self.dropout1.p if self.training else 0.0, self.norm1.eps,
rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32
)
hidden_states = self.mixer(hidden_states,
**(mixer_kwargs if mixer_kwargs is not None else {}))
if mixer_kwargs is None:
mixer_kwargs = {}
mixer_kwargs['mixer_subset'] = mixer_subset
hidden_states = self.mixer(hidden_states, **mixer_kwargs)
if mixer_subset is not None:
residual = residual[:, mixer_subset]
if not isinstance(self.mlp, nn.Identity):
if not self.fused_dropout_add_ln:
dropped = self.drop_path2(self.dropout2(hidden_states))
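
To make the new `mixer_subset` path concrete: in the pre-norm branch the subset is handed to the mixer (so only those positions act as queries), and the residual is sliced to the same subset so the subsequent add and MLP operate on matching shapes. A shape-only sketch, with a toy mixer standing in for the real `MHA` module:

```python
# Shape-only sketch of Block.forward's mixer_subset handling (prenorm path).
# toy_mixer is a stand-in for the real MHA cross-attention mixer.
import torch
import torch.nn as nn

batch, seqlen, dim = 2, 197, 64
hidden_states = torch.randn(batch, seqlen, dim)
residual = torch.randn(batch, seqlen, dim)
norm1 = nn.LayerNorm(dim)

def toy_mixer(x, mixer_subset=None):
    # Queries come only from the subset; keys/values would still see all of x.
    return x if mixer_subset is None else x[:, mixer_subset]

mixer_subset = slice(0, 1)            # keep only the CLS token
residual = hidden_states + residual   # simplified dropout-add
hidden_states = norm1(residual)
hidden_states = toy_mixer(hidden_states, mixer_subset=mixer_subset)
residual = residual[:, mixer_subset]  # slice the residual to match
print(hidden_states.shape, residual.shape)  # both torch.Size([2, 1, 64])
```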
10 changes: 7 additions & 3 deletions flash_attn/modules/mha.py
@@ -420,7 +420,7 @@ def _update_kv_cache(self, kv, inference_params):
return _update_kv_cache(kv, inference_params, self.layer_idx)

def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
inference_params=None, **kwargs):
mixer_subset=None, inference_params=None, **kwargs):
"""
Arguments:
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
@@ -433,6 +433,9 @@ def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
max_seqlen: int. Maximum sequence length in the batch.
key_padding_mask: boolean mask, True means to keep, False means to mask out.
(batch, seqlen). Only applicable when not using FlashAttention.
mixer_subset: for cross-attention only. If not None, will take a subset of x
before applying the query projection. Useful for e.g., ViT where we only care
about the CLS token in the last layer.
inference_params: for generation. Adapted from Megatron-LM (and Apex)
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
"""
@@ -454,6 +457,7 @@ def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
if not self.cross_attn:
assert x_kv is None and mixer_subset is None
if not self.return_residual:
qkv = self.Wqkv(x)
else:
@@ -491,14 +495,14 @@ def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
context = rearrange(context, 'b h d -> b 1 h d')
else:
if not self.return_residual:
q = self.Wq(x)
q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
kv = self.Wkv(x_kv if x_kv is not None else x)
else:
if x_kv is not None:
kv, x_kv = self.Wkv(x_kv)
else:
kv, x = self.Wkv(x)
q = self.Wq(x)
q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
kv = rearrange(kv, '... (two h d) -> ... two h d', two=2, d=self.head_dim)
if self.dwconv:
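
The payoff of `mixer_subset` in `MHA.forward` is that the query projection and the attention itself are computed only for the selected positions, while keys and values still cover the whole sequence. A rough sketch of that arithmetic in plain PyTorch (the real module uses its `Wq`/`Wkv` projections and, optionally, FlashAttention kernels; sizes here are illustrative):

```python
# Sketch: cross-attention where only a subset of positions (the CLS token) act as queries.
import math
import torch
import torch.nn as nn

batch, seqlen, dim, nheads = 2, 197, 64, 4
head_dim = dim // nheads
x = torch.randn(batch, seqlen, dim)

Wq, Wkv = nn.Linear(dim, dim), nn.Linear(dim, 2 * dim)

mixer_subset = slice(0, 1)
q = Wq(x[:, mixer_subset])      # (batch, 1, dim) rather than (batch, seqlen, dim)
k, v = Wkv(x).chunk(2, dim=-1)  # keys/values over the full sequence

def split_heads(t):  # (b, s, dim) -> (b, nheads, s, head_dim)
    return t.view(t.shape[0], t.shape[1], nheads, head_dim).transpose(1, 2)

q, k, v = split_heads(q), split_heads(k), split_heads(v)
attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(head_dim), dim=-1)
out = (attn @ v).transpose(1, 2).reshape(batch, 1, dim)  # one output token per example
print(out.shape)  # torch.Size([2, 1, 64])
```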
2 changes: 1 addition & 1 deletion tests/models/test_opt.py
@@ -26,7 +26,7 @@ def test_opt_state_dict(model_name):
@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"])
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_optimized(model_name):
"""Check that our implementation of OPT (without any optimizations enabled) matches the
"""Check that our implementation of OPT (without all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
49 changes: 49 additions & 0 deletions tests/models/test_vit.py
@@ -0,0 +1,49 @@
import re

import torch
import pytest

from timm.models.vision_transformer import vit_base_patch16_224

from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224


@pytest.mark.parametrize('fused_dense_gelu_dense', [False, True])
# @pytest.mark.parametrize('fused_dense_gelu_dense', [False])
@pytest.mark.parametrize('optimized', [False, True])
# @pytest.mark.parametrize('optimized', [True])
def test_vit(optimized, fused_dense_gelu_dense):
"""Check that our implementation of ViT matches the timm's implementation:
the output of our forward pass in fp16 should be around the same as
timm' forward pass in fp16, when compared to timm's forward pass in fp32.
"""
dtype = torch.float16
device = 'cuda'

kwargs = {}
if optimized:
kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
kwargs['fused_dense_gelu_dense'] = fused_dense_gelu_dense
model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)

model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
model_timm = vit_base_patch16_224(pretrained=True).to(device=device, dtype=dtype)

model.load_state_dict(model_ref.state_dict())

model.eval()
model_ref.eval()
model_timm.eval()

torch.manual_seed(0)
batch_size = 2
x = torch.randn(batch_size, 3, 224, 224, device=device, dtype=dtype)
out = model(x)
out_timm = model_timm(x)
out_ref = model_ref(x.float())

print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}')
print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}')
assert (out - out_ref).abs().max().item() < 3 * (out_timm - out_ref).abs().max().item()
