Showing 6 changed files with 358 additions and 2 deletions.
@@ -0,0 +1,73 @@
# @package _global_
defaults:
  - override /trainer: default # choose trainer from 'configs/trainer/'
  - override /model: null
  - override /datamodule: imagenet
  - override /optimizer: adamw
  - override /scheduler: null
  - override /callbacks: default
  - override /metrics: [acc, acctop5]
  - override /logger: wandb

seed: 1111

trainer:
  accelerator: gpu
  devices: 8
  num_nodes: 1
  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${.devices} * ${datamodule.batch_size} * ${.num_nodes}}}
  max_epochs: 300
  precision: 16
  replace_sampler_ddp: ${eval:"${datamodule.num_aug_repeats} == 0"}

datamodule:
  batch_size: 128 # Per GPU
  batch_size_eval: ${eval:${.batch_size} * 2}
  image_size: 224
  train_transforms:
    _target_: timm.data.create_transform
    input_size: ${datamodule.image_size}
    is_training: True
    auto_augment: rand-m9-mstd0.5-inc1 # Use AutoAugment policy
    interpolation: bicubic
    re_prob: 0.25 # Random erase prob
    re_mode: pixel # Random erase mode
  val_transforms:
    _target_: timm.data.create_transform
    input_size: ${datamodule.image_size}
    interpolation: bicubic
    crop_pct: 0.9
  test_transforms: ${.val_transforms}
  mixup:
    _target_: src.datamodules.timm_mixup.TimmMixup
    mixup_alpha: 0.8
    cutmix_alpha: 1.0
    label_smoothing: 0.0 # We're using label smoothing from PyTorch's CrossEntropyLoss
  # DeiT paper says they use RepeatedAug, but for ViT-S I get 79.7% with RepeatedAug and
  # 80.1% without.
  num_aug_repeats: 0
  num_workers: 11 # For A100 we need a lot of workers; V100 might need fewer

train:
  global_batch_size: 1024
  num_steps_per_epoch: ${div_up:${datamodule.__train_len}, ${train.global_batch_size}}
  optimizer:
    lr: ${eval:5e-4 * ${train.global_batch_size} / 512}
    weight_decay: 0.05
  optimizer_param_grouping:
    bias_weight_decay: False
    normalization_weight_decay: False
  scheduler:
    _target_: src.optim.timm_lr_scheduler.TimmCosineLRScheduler
    t_initial: ${eval:${trainer.max_epochs} * ${train.num_steps_per_epoch}}
    lr_min: 1e-5
    warmup_lr_init: 1e-6
    warmup_t: ${eval:5 * ${train.num_steps_per_epoch}}
    cycle_limit: 1
    t_in_epochs: False
  scheduler_interval: step
  loss_fn:
    _target_: torch.nn.CrossEntropyLoss
    label_smoothing: 0.1
  loss_fn_val:
    _target_: torch.nn.CrossEntropyLoss
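
A note on the interpolations above: div_up and eval are custom OmegaConf resolvers that the repo is assumed to register at startup; they are not part of this diff. With the values in this config, accumulate_grad_batches resolves to div_up(1024, 8 * 128 * 1) = 1, and the scaled learning rate to 5e-4 * 1024 / 512 = 1e-3. A minimal sketch of how such resolvers might be registered:

from omegaconf import OmegaConf

# Hypothetical registration, not the repo's actual code:
# div_up does ceiling division, eval evaluates a Python expression string.
OmegaConf.register_new_resolver("div_up", lambda x, y: (x + y - 1) // y)
OmegaConf.register_new_resolver("eval", eval)

# e.g. steps per epoch for ImageNet-1k (1,281,167 training images) at global batch size 1024:
print(OmegaConf.create({"steps": "${div_up:1281167, 1024}"}).steps)  # 1252
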
@@ -0,0 +1,17 @@
# @package _global_
defaults:
  - /experiment/imagenet/deit/deit-b.yaml
  - override /optimizer: adamw-apex
  - override /callbacks: [default, ema]

model:
  _target_: src.models.vit.vit.vit_base_patch16_224
  drop_path_rate: 0.1
  use_flash_attn: True
  fused_bias_fc: True
  fused_dense_gelu_dense: True
  fused_dropout_add_ln: True
  bf16: ${eval:'"${trainer.precision}" == "bf16"'}

# trainer:
#   strategy: deepspeed_stage_1
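
This variant swaps in the repo's own ViT implementation with FlashAttention and fused kernels enabled; the _target_ points at the vit module added later in this diff. For reference, a _target_-style block like this is typically turned into an object with hydra.utils.instantiate. A minimal sketch, assuming the src package and its fused/FlashAttention extensions are installed:

from hydra.utils import instantiate

model_cfg = {
    "_target_": "src.models.vit.vit.vit_base_patch16_224",
    "drop_path_rate": 0.1,
    "use_flash_attn": True,
    "fused_bias_fc": True,
    "fused_dense_gelu_dense": True,
    "fused_dropout_add_ln": True,
    "bf16": False,  # in the config this is derived from trainer.precision
}
model = instantiate(model_cfg)  # calls vit_base_patch16_224(drop_path_rate=0.1, ...)
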
@@ -0,0 +1,19 @@
# @package _global_
defaults:
  - /experiment/imagenet/deit/base.yaml
  # TD [2022-05-27]: the DeiT paper says they don't use EMA but I'm only able to
  # replicate their numbers with EMA
  - override /callbacks: [default, ema, flop-count]

model:
  _target_: timm.models.vit_base_patch16_224
  drop_path_rate: 0.1
  num_classes: ${datamodule:num_classes}

datamodule:
  # RepeatedAug is crucial for ViT-B as it seems to overfit.
  num_aug_repeats: 3

callbacks:
  ema:
    decay: 0.99996
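
The ema callback is referenced by name only; its implementation is not part of this diff. For intuition, an exponential moving average of the weights with decay 0.99996 averages over roughly 1 / (1 - 0.99996) = 25,000 recent optimizer steps. A minimal sketch of the usual update rule (hypothetical helper, not the repo's callback):

import torch

@torch.no_grad()
def ema_update(ema_model: torch.nn.Module, model: torch.nn.Module, decay: float = 0.99996):
    # ema_w <- decay * ema_w + (1 - decay) * w, applied parameter-wise after each optimizer step
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1 - decay)
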
@@ -0,0 +1,15 @@
# @package _global_
defaults:
  - /experiment/imagenet/deit/base.yaml

seed: 1111

model:
  _target_: timm.models.vit_small_patch16_224
  drop_path_rate: 0.1
  num_classes: ${datamodule:num_classes}

datamodule:
  # DeiT paper says they use RepeatedAug, but I get 79.7% with RepeatedAug and
  # 80.1% without.
  num_aug_repeats: 0
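
num_aug_repeats controls repeated augmentation. When it is non-zero, the datamodule is expected to supply its own distributed sampler, which is why the base config only lets Lightning replace the sampler when num_aug_repeats == 0 (replace_sampler_ddp: ${eval:"${datamodule.num_aug_repeats} == 0"}). A sketch of what that sampler choice might look like with timm's RepeatAugSampler; this is an assumption about the datamodule's internals, not code from this diff:

from torch.utils.data import DistributedSampler
from timm.data.distributed_sampler import RepeatAugSampler

def make_train_sampler(dataset, num_aug_repeats: int):
    # With repeated augmentation, each selected image appears num_aug_repeats times per epoch
    # (with different augmentations), so the default DistributedSampler must not be used.
    if num_aug_repeats > 0:
        return RepeatAugSampler(dataset, num_repeats=num_aug_repeats)
    return DistributedSampler(dataset, shuffle=True)
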
@@ -0,0 +1,227 @@
# Copyright (c) 2022, Tri Dao.
# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
import math
from functools import partial
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_

from timm.models.helpers import named_apply
from timm.models.layers import PatchEmbed

from src.models.modules.mha import MHA
from src.models.modules.mlp import Mlp, FusedDenseGeluDense
from src.models.modules.block import Block


def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc, bf16):
    mixer_cls = partial(MHA, num_heads=num_heads, bias=qkv_bias, dropout=attn_drop,
                        fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn, bf16=bf16)
    return mixer_cls


def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense, bf16):
    inner_dim = int(embed_dim * mlp_ratio)
    if not fused_dense_gelu_dense:
        mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer())
    else:
        mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim, bf16=bf16)
    return mlp_cls


def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path,
                 norm_layer, act_layer, use_flash_attn, fused_bias_fc, fused_dense_gelu_dense,
                 fused_dropout_add_ln, bf16):
    mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn,
                                 fused_bias_fc, bf16)
    mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense, bf16)
    block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer,
                  prenorm=True, resid_dropout=drop_rate, drop_path=drop_path,
                  fused_dropout_add_ln=fused_dropout_add_ln)
    return block


class VisionTransformer(nn.Module):
    """ Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    """
    def __init__(
            self,
            img_size=224,
            patch_size=16,
            in_chans=3,
            num_classes=1000,
            global_pool='token',
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4.,
            qkv_bias=True,
            init_values=None,
            class_token=True,
            no_embed_class=False,
            pre_norm=False,
            fc_norm=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            weight_init='',
            embed_layer=PatchEmbed,
            norm_layer=None,
            act_layer=None,
            use_flash_attn=False,
            fused_bias_fc=False,
            fused_dense_gelu_dense=False,
            fused_dropout_add_ln=False,
            bf16=False
    ):
""" | ||
Args: | ||
img_size (int, tuple): input image size | ||
patch_size (int, tuple): patch size | ||
in_chans (int): number of input channels | ||
num_classes (int): number of classes for classification head | ||
global_pool (str): type of global pooling for final sequence (default: 'token') | ||
embed_dim (int): embedding dimension | ||
depth (int): depth of transformer | ||
num_heads (int): number of attention heads | ||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim | ||
qkv_bias (bool): enable bias for qkv if True | ||
init_values: (float): layer-scale init values | ||
class_token (bool): use class token | ||
fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) | ||
drop_rate (float): dropout rate | ||
attn_drop_rate (float): attention dropout rate | ||
drop_path_rate (float): stochastic depth rate | ||
weight_init (str): weight init scheme | ||
embed_layer (nn.Module): patch embedding layer | ||
norm_layer: (nn.Module): normalization layer | ||
act_layer: (nn.Module): MLP activation layer | ||
""" | ||
        super().__init__()
        assert global_pool == 'token', 'Only support pooling with CLS token'
        assert class_token
        assert init_values is None, 'LayerScale is not supported yet'
        assert weight_init == ''
        assert fc_norm is None
        # pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk
        assert not pre_norm
        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_prefix_tokens = 1 if class_token else 0
        self.no_embed_class = no_embed_class

        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        # We change the order of the residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Attn / MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add) and
        # the main branch (output of LN). The model definition is unchanged, but the mapping of the
        # nn.LayerNorm weights is changed.
        # This is for performance reasons: we can fuse dropout + add + layer_norm.
        # self.norm_0 is the first layer norm in the model, while self.norm
        # (in the pretrained weight) is the final layer norm.
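        # Schematically, per layer (a sketch of the reordering described above):
        #   standard pre-norm block: x -> LN -> Mixer/MLP -> Dropout -> Add(residual)
        #   this implementation:     (hidden, residual) -> Mixer/MLP -> Dropout -> Add -> LN -> (hidden', residual')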
        self.norm_0 = norm_layer(embed_dim)

        self.blocks = nn.ModuleList([create_block(
            embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path=dpr[i],
            norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
            fused_bias_fc=fused_bias_fc, fused_dense_gelu_dense=fused_dense_gelu_dense,
            fused_dropout_add_ln=fused_dropout_add_ln, bf16=bf16
        ) for i in range(depth)])

        # Classifier Head
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.init_weights(weight_init)

    def init_weights(self, mode=''):
        assert mode == ''
        trunc_normal_(self.pos_embed, std=.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        init_weights_vit_timm(m)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def _pos_embed(self, x):
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + self.pos_embed
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = x + self.pos_embed
        return self.pos_drop(x)

    def forward_features(self, x):
        x = self.patch_embed(x)
        residual = self._pos_embed(x)
        hidden_states = self.norm_0(residual)
        for block in self.blocks:
            hidden_states, residual = block(hidden_states, residual)
        return hidden_states

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            x = x[:, 0]
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def init_weights_vit_timm(module: nn.Module, name: str = ''):
    """ ViT weight initialization, original timm impl (for reproducibility) """
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        module.init_weights()


def vit_base_patch16_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    assert not pretrained
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = VisionTransformer(**model_kwargs)
    return model
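
For reference, a minimal usage sketch of the factory above (assuming the repo's src package is importable; with the default flags the model falls back to the unfused PyTorch code paths, so no FlashAttention or fused-kernel extensions are required):

import torch
from src.models.vit.vit import vit_base_patch16_224

model = vit_base_patch16_224(num_classes=1000).eval()
with torch.no_grad():
    images = torch.randn(2, 3, 224, 224)   # a batch of two 224x224 RGB images
    logits = model(images)
print(logits.shape)  # expected: torch.Size([2, 1000])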