Add ViT model and configs
tridao committed Oct 10, 2022
1 parent fee1286 commit 1417543
Showing 6 changed files with 358 additions and 2 deletions.
73 changes: 73 additions & 0 deletions configs/experiment/imagenet/deit/base.yaml
@@ -0,0 +1,73 @@
# @package _global_
defaults:
  - override /trainer: default  # choose trainer from 'configs/trainer/'
  - override /model: null
  - override /datamodule: imagenet
  - override /optimizer: adamw
  - override /scheduler: null
  - override /callbacks: default
  - override /metrics: [acc, acctop5]
  - override /logger: wandb

seed: 1111

trainer:
  accelerator: gpu
  devices: 8
  num_nodes: 1
  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${.devices} * ${datamodule.batch_size} * ${.num_nodes}}}
  max_epochs: 300
  precision: 16
  replace_sampler_ddp: ${eval:"${datamodule.num_aug_repeats} == 0"}

datamodule:
  batch_size: 128  # Per GPU
  batch_size_eval: ${eval:${.batch_size} * 2}
  image_size: 224
  train_transforms:
    _target_: timm.data.create_transform
    input_size: ${datamodule.image_size}
    is_training: True
    auto_augment: rand-m9-mstd0.5-inc1  # Use AutoAugment policy
    interpolation: bicubic
    re_prob: 0.25  # Random erase prob
    re_mode: pixel  # Random erase mode
  val_transforms:
    _target_: timm.data.create_transform
    input_size: ${datamodule.image_size}
    interpolation: bicubic
    crop_pct: 0.9
  test_transforms: ${.val_transforms}
  mixup:
    _target_: src.datamodules.timm_mixup.TimmMixup
    mixup_alpha: 0.8
    cutmix_alpha: 1.0
    label_smoothing: 0.0  # We're using label smoothing from PyTorch's CrossEntropyLoss
  # The DeiT paper says they use RepeatedAug, but for ViT-S I get 79.7% with RepeatedAug and
  # 80.1% without.
  num_aug_repeats: 0
  num_workers: 11  # For A100 we need a lot of workers; V100 might need fewer

train:
  global_batch_size: 1024
  num_steps_per_epoch: ${div_up:${datamodule.__train_len}, ${train.global_batch_size}}
  optimizer:
    lr: ${eval:5e-4 * ${train.global_batch_size} / 512}
    weight_decay: 0.05
  optimizer_param_grouping:
    bias_weight_decay: False
    normalization_weight_decay: False
  scheduler:
    _target_: src.optim.timm_lr_scheduler.TimmCosineLRScheduler
    t_initial: ${eval:${trainer.max_epochs} * ${train.num_steps_per_epoch}}
    lr_min: 1e-5
    warmup_lr_init: 1e-6
    warmup_t: ${eval:5 * ${train.num_steps_per_epoch}}
    cycle_limit: 1
    t_in_epochs: False
  scheduler_interval: step
  loss_fn:
    _target_: torch.nn.CrossEntropyLoss
    label_smoothing: 0.1
  loss_fn_val:
    _target_: torch.nn.CrossEntropyLoss
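
The derived values in this config reduce to simple arithmetic. A quick sketch of what they evaluate to with the defaults shown (not part of the commit; it assumes div_up is a ceiling-division resolver and eval evaluates the embedded Python expression):

import math

devices, num_nodes, per_gpu_batch_size = 8, 1, 128
global_batch_size = 1024

# trainer.accumulate_grad_batches = div_up(global_batch_size, devices * per_gpu_batch_size * num_nodes)
accumulate_grad_batches = math.ceil(global_batch_size / (devices * per_gpu_batch_size * num_nodes))
assert accumulate_grad_batches == 1  # 8 GPUs x 128 per GPU already reach the global batch size

# optimizer.lr = 5e-4 * global_batch_size / 512, i.e. the linear LR scaling used by DeiT
lr = 5e-4 * global_batch_size / 512
assert lr == 1e-3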
17 changes: 17 additions & 0 deletions configs/experiment/imagenet/deit/deit-b-flash.yaml
@@ -0,0 +1,17 @@
# @package _global_
defaults:
  - /experiment/imagenet/deit/deit-b.yaml
  - override /optimizer: adamw-apex
  - override /callbacks: [default, ema]

model:
  _target_: src.models.vit.vit.vit_base_patch16_224
  drop_path_rate: 0.1
  use_flash_attn: True
  fused_bias_fc: True
  fused_dense_gelu_dense: True
  fused_dropout_add_ln: True
  bf16: ${eval:'"${trainer.precision}" == "bf16"'}

# trainer:
#   strategy: deepspeed_stage_1
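
Roughly what Hydra would instantiate from this config, sketched by hand (not part of the commit). The fused/FlashAttention flags map onto the kwargs of vit_base_patch16_224 added in src/models/vit/vit.py below, and the sketch assumes this repo's CUDA extensions are installed and the model runs in fp16 or bf16 on GPU:

import torch
from src.models.vit.vit import vit_base_patch16_224

model = vit_base_patch16_224(
    drop_path_rate=0.1,
    use_flash_attn=True,            # FlashAttention kernel inside MHA
    fused_bias_fc=True,             # fused bias + GEMM for the attention projections
    fused_dense_gelu_dense=True,    # fused dense -> GELU -> dense MLP
    fused_dropout_add_ln=True,      # fused dropout + residual add + LayerNorm between blocks
    bf16=False,                     # the config sets this True only when trainer.precision == "bf16"
).cuda().half()

with torch.no_grad():
    logits = model(torch.randn(2, 3, 224, 224, device='cuda', dtype=torch.float16))
print(logits.shape)  # torch.Size([2, 1000])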
19 changes: 19 additions & 0 deletions configs/experiment/imagenet/deit/deit-b.yaml
@@ -0,0 +1,19 @@
# @package _global_
defaults:
  - /experiment/imagenet/deit/base.yaml
  # TD [2022-05-27]: the DeiT paper says they don't use EMA, but I'm only able to
  # replicate their numbers with EMA.
  - override /callbacks: [default, ema, flop-count]

model:
  _target_: timm.models.vit_base_patch16_224
  drop_path_rate: 0.1
  num_classes: ${datamodule:num_classes}

datamodule:
  # RepeatedAug is crucial for ViT-B, which seems to overfit without it.
  num_aug_repeats: 3

callbacks:
  ema:
    decay: 0.99996
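
The ema callback maintains an exponential moving average of the model weights; a decay of 0.99996 averages over roughly 1/(1 - 0.99996) = 25,000 optimizer steps. A conceptual sketch of the idea using timm's ModelEmaV2 as a stand-in (the repo's actual callback implementation may differ):

import torch.nn as nn
from timm.utils import ModelEmaV2

model = nn.Linear(10, 10)               # placeholder for the ViT
ema = ModelEmaV2(model, decay=0.99996)  # shadow copy of the weights

for step in range(3):                   # inside the training loop, after each optimizer.step():
    ema.update(model)                   # ema_w = decay * ema_w + (1 - decay) * w
# validation / checkpointing would then read from ema.module instead of model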
15 changes: 15 additions & 0 deletions configs/experiment/imagenet/deit/deit-s.yaml
@@ -0,0 +1,15 @@
# @package _global_
defaults:
  - /experiment/imagenet/deit/base.yaml

seed: 1111

model:
  _target_: timm.models.vit_small_patch16_224
  drop_path_rate: 0.1
  num_classes: ${datamodule:num_classes}

datamodule:
  # The DeiT paper says they use RepeatedAug, but I get 79.7% with RepeatedAug and
  # 80.1% without.
  num_aug_repeats: 0
9 changes: 7 additions & 2 deletions src/datamodules/imagenet.py
@@ -5,6 +5,7 @@
 
 from torch.utils.data import Dataset, DataLoader
 from torch.utils.data.dataloader import default_collate
+from torch.utils.data.distributed import DistributedSampler
 
 from pytorch_lightning import LightningDataModule
 
@@ -188,11 +189,15 @@ def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
 
     def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
         """ The val dataloader """
-        return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval)
+        # If using RepeatAugment, we set trainer.replace_sampler_ddp=False, so we have to
+        # construct the DistributedSampler ourselves.
+        sampler = DistributedSampler(self.dataset_val) if self.num_aug_repeats != 0 else None
+        return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval, sampler=sampler)
 
     def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
         """ The test dataloader """
-        return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval)
+        sampler = DistributedSampler(self.dataset_test) if self.num_aug_repeats != 0 else None
+        return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval, sampler=sampler)
 
     def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: bool = False,
                      mixup: Optional[Callable] = None, sampler=None) -> DataLoader:
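
For context (not shown in this diff, so treat it as an assumption): with num_aug_repeats != 0 the trainer stops injecting its own DDP samplers (replace_sampler_ddp=False), so the train dataloader presumably builds a repeated-augmentation sampler such as timm's RepeatAugSampler itself, while val/test fall back to a plain DistributedSampler as above. A sketch of that pairing:

from timm.data.distributed_sampler import RepeatAugSampler
from torch.utils.data.distributed import DistributedSampler

def build_samplers(train_set, val_set, num_aug_repeats):
    # Assumes torch.distributed is already initialized (both samplers query world size / rank).
    if num_aug_repeats != 0:
        # Each sample is drawn num_aug_repeats times per epoch, each time with a different augmentation.
        train_sampler = RepeatAugSampler(train_set, num_repeats=num_aug_repeats)
        val_sampler = DistributedSampler(val_set)
    else:
        # Lightning's replace_sampler_ddp=True injects standard DistributedSamplers automatically.
        train_sampler, val_sampler = None, None
    return train_sampler, val_sampler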
227 changes: 227 additions & 0 deletions src/models/vit/vit.py
@@ -0,0 +1,227 @@
# Copyright (c) 2022, Tri Dao.
# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
import math
from functools import partial
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_

from timm.models.helpers import named_apply
from timm.models.layers import PatchEmbed

from src.models.modules.mha import MHA
from src.models.modules.mlp import Mlp, FusedDenseGeluDense
from src.models.modules.block import Block


def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc, bf16):
    mixer_cls = partial(MHA, num_heads=num_heads, bias=qkv_bias, dropout=attn_drop,
                        fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn, bf16=bf16)
    return mixer_cls


def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense, bf16):
    inner_dim = int(embed_dim * mlp_ratio)
    if not fused_dense_gelu_dense:
        mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer())
    else:
        mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim, bf16=bf16)
    return mlp_cls


def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path,
                 norm_layer, act_layer, use_flash_attn, fused_bias_fc, fused_dense_gelu_dense,
                 fused_dropout_add_ln, bf16):
    mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn,
                                 fused_bias_fc, bf16)
    mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense, bf16)
    block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer,
                  prenorm=True, resid_dropout=drop_rate, drop_path=drop_path,
                  fused_dropout_add_ln=fused_dropout_add_ln)
    return block


class VisionTransformer(nn.Module):
    """ Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    """
    def __init__(
            self,
            img_size=224,
            patch_size=16,
            in_chans=3,
            num_classes=1000,
            global_pool='token',
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4.,
            qkv_bias=True,
            init_values=None,
            class_token=True,
            no_embed_class=False,
            pre_norm=False,
            fc_norm=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            weight_init='',
            embed_layer=PatchEmbed,
            norm_layer=None,
            act_layer=None,
            use_flash_attn=False,
            fused_bias_fc=False,
            fused_dense_gelu_dense=False,
            fused_dropout_add_ln=False,
            bf16=False
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            global_pool (str): type of global pooling for final sequence (default: 'token')
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            init_values: (float): layer-scale init values
            class_token (bool): use class token
            fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            weight_init (str): weight init scheme
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
            act_layer: (nn.Module): MLP activation layer
        """
        super().__init__()
        assert global_pool == 'token', 'Only support pooling with CLS token'
        assert class_token
        assert init_values is None, 'LayerScale is not supported yet'
        assert weight_init == ''
        assert fc_norm is None
        # pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk
        assert not pre_norm
        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_prefix_tokens = 1 if class_token else 0
        self.no_embed_class = no_embed_class

        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        # We change the order of the residual and the layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Attn / MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add)
        # and the main branch (output of LN). The model definition is unchanged, but the mapping
        # of the nn.LayerNorm weights is changed.
        # This is for performance reasons: we can fuse dropout + add + layer_norm.
        # self.norm_0 is the first layer norm in the model, while self.norm
        # (in the pretrained weights) is the final layer norm.
        self.norm_0 = norm_layer(embed_dim)

        self.blocks = nn.ModuleList([create_block(
            embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path=dpr[i],
            norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
            fused_bias_fc=fused_bias_fc, fused_dense_gelu_dense=fused_dense_gelu_dense,
            fused_dropout_add_ln=fused_dropout_add_ln, bf16=bf16
        ) for i in range(depth)])

        # Classifier Head
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.init_weights(weight_init)

    def init_weights(self, mode=''):
        assert mode == ''
        trunc_normal_(self.pos_embed, std=.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        init_weights_vit_timm(m)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def _pos_embed(self, x):
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + self.pos_embed
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = x + self.pos_embed
        return self.pos_drop(x)

    def forward_features(self, x):
        x = self.patch_embed(x)
        residual = self._pos_embed(x)
        hidden_states = self.norm_0(residual)
        for block in self.blocks:
            hidden_states, residual = block(hidden_states, residual)
        return hidden_states

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            x = x[:, 0]
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def init_weights_vit_timm(module: nn.Module, name: str = ''):
    """ ViT weight initialization, original timm impl (for reproducibility) """
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        module.init_weights()


def vit_base_patch16_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    assert not pretrained
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = VisionTransformer(**model_kwargs)
    return model
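
To make the comment above self.norm_0 concrete, here is a plain-PyTorch sketch of the reordered sublayer (an illustration only; the repo's actual Block lives in src.models.modules.block and may fuse these ops into a single kernel):

import torch
import torch.nn as nn

def reordered_sublayer(fn, norm, dropout, hidden_states, residual):
    # Standard pre-norm order:  LN -> fn (Attn or MLP) -> Dropout -> Add
    # Order used here:          fn -> Dropout -> Add -> LN, carrying (hidden_states, residual),
    # so that dropout + residual-add + LayerNorm can be fused into one kernel.
    residual = residual + dropout(fn(hidden_states))
    hidden_states = norm(residual)
    return hidden_states, residual

# Toy usage mirroring forward_features: norm_0 plays the role of the first LayerNorm,
# and every subsequent sublayer hands both streams forward.
dim = 16
norm_0, norm_1 = nn.LayerNorm(dim), nn.LayerNorm(dim)
mlp, drop = nn.Linear(dim, dim), nn.Dropout(0.1)

residual = torch.randn(2, 5, dim)     # e.g. pos-embedded patch tokens
hidden = norm_0(residual)
hidden, residual = reordered_sublayer(mlp, norm_1, drop, hidden, residual)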
