diff --git a/configs/experiment/imagenet/deit/base.yaml b/configs/experiment/imagenet/deit/base.yaml
new file mode 100644
index 0000000..7da547a
--- /dev/null
+++ b/configs/experiment/imagenet/deit/base.yaml
@@ -0,0 +1,73 @@
+# @package _global_
+defaults:
+  - override /trainer: default # choose trainer from 'configs/trainer/'
+  - override /model: null
+  - override /datamodule: imagenet
+  - override /optimizer: adamw
+  - override /scheduler: null
+  - override /callbacks: default
+  - override /metrics: [acc, acctop5]
+  - override /logger: wandb
+
+seed: 1111
+
+trainer:
+  accelerator: gpu
+  devices: 8
+  num_nodes: 1
+  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${.devices} * ${datamodule.batch_size} * ${.num_nodes}}}
+  max_epochs: 300
+  precision: 16
+  replace_sampler_ddp: ${eval:"${datamodule.num_aug_repeats} == 0"}
+
+datamodule:
+  batch_size: 128  # Per GPU
+  batch_size_eval: ${eval:${.batch_size} * 2}
+  image_size: 224
+  train_transforms:
+    _target_: timm.data.create_transform
+    input_size: ${datamodule.image_size}
+    is_training: True
+    auto_augment: rand-m9-mstd0.5-inc1  # Use AutoAugment policy
+    interpolation: bicubic
+    re_prob: 0.25  # Random erase prob
+    re_mode: pixel  # Random erase mode
+  val_transforms:
+    _target_: timm.data.create_transform
+    input_size: ${datamodule.image_size}
+    interpolation: bicubic
+    crop_pct: 0.9
+  test_transforms: ${.val_transforms}
+  mixup:
+    _target_: src.datamodules.timm_mixup.TimmMixup
+    mixup_alpha: 0.8
+    cutmix_alpha: 1.0
+    label_smoothing: 0.0  # We're using label smoothing from PyTorch's CrossEntropyLoss
+  # DeiT paper says they use RepeatedAug, but for ViT-S I get 79.7% with RepeatedAug and
+  # 80.1% without.
+  num_aug_repeats: 0
+  num_workers: 11  # For A100 we need a lot of workers, V100 might need fewer
+
+train:
+  global_batch_size: 1024
+  num_steps_per_epoch: ${div_up:${datamodule.__train_len}, ${train.global_batch_size}}
+  optimizer:
+    lr: ${eval:5e-4 * ${train.global_batch_size} / 512}
+    weight_decay: 0.05
+  optimizer_param_grouping:
+    bias_weight_decay: False
+    normalization_weight_decay: False
+  scheduler:
+    _target_: src.optim.timm_lr_scheduler.TimmCosineLRScheduler
+    t_initial: ${eval:${trainer.max_epochs} * ${train.num_steps_per_epoch}}
+    lr_min: 1e-5
+    warmup_lr_init: 1e-6
+    warmup_t: ${eval:5 * ${train.num_steps_per_epoch}}
+    cycle_limit: 1
+    t_in_epochs: False
+  scheduler_interval: step
+  loss_fn:
+    _target_: torch.nn.CrossEntropyLoss
+    label_smoothing: 0.1
+  loss_fn_val:
+    _target_: torch.nn.CrossEntropyLoss
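Note: the interpolations in base.yaml encode the usual DeiT scaling rules: gradient accumulation is derived from the global batch size and the per-device batch size, and the learning rate follows lr = 5e-4 * global_batch_size / 512. A minimal sketch of that arithmetic under the defaults above, assuming `div_up` is ceiling division and `eval` evaluates the expression (as the resolver names suggest):

import math

# Worked example of the batch-size / LR arithmetic in base.yaml above.
# Assumes `div_up` is ceiling division and `eval` evaluates the expression.
devices, num_nodes = 8, 1             # trainer.devices, trainer.num_nodes
per_gpu_batch_size = 128              # datamodule.batch_size
global_batch_size = 1024              # train.global_batch_size

accumulate_grad_batches = math.ceil(global_batch_size / (devices * per_gpu_batch_size * num_nodes))
lr = 5e-4 * global_batch_size / 512   # DeiT linear LR scaling rule

print(accumulate_grad_batches, lr)    # -> 1 0.001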
diff --git a/configs/experiment/imagenet/deit/deit-b-flash.yaml b/configs/experiment/imagenet/deit/deit-b-flash.yaml
new file mode 100644
index 0000000..ba79c55
--- /dev/null
+++ b/configs/experiment/imagenet/deit/deit-b-flash.yaml
@@ -0,0 +1,17 @@
+# @package _global_
+defaults:
+  - /experiment/imagenet/deit/deit-b.yaml
+  - override /optimizer: adamw-apex
+  - override /callbacks: [default, ema]
+
+model:
+  _target_: src.models.vit.vit.vit_base_patch16_224
+  drop_path_rate: 0.1
+  use_flash_attn: True
+  fused_bias_fc: True
+  fused_dense_gelu_dense: True
+  fused_dropout_add_ln: True
+  bf16: ${eval:'"${trainer.precision}" == "bf16"'}
+
+# trainer:
+#   strategy: deepspeed_stage_1
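The model target above points at the vit_base_patch16_224 constructor added in src/models/vit/vit.py later in this diff. Instantiated directly, the config corresponds roughly to the sketch below (illustrative only; the flash/fused options require the custom CUDA kernels used by MHA and FusedDenseGeluDense):

from src.models.vit.vit import vit_base_patch16_224

# Rough Python equivalent of the Hydra config above (sketch only).
model = vit_base_patch16_224(
    drop_path_rate=0.1,
    use_flash_attn=True,
    fused_bias_fc=True,
    fused_dense_gelu_dense=True,
    fused_dropout_add_ln=True,
    bf16=False,  # the config derives this from trainer.precision == "bf16"
)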
diff --git a/configs/experiment/imagenet/deit/deit-b.yaml b/configs/experiment/imagenet/deit/deit-b.yaml
new file mode 100644
index 0000000..e6a1b23
--- /dev/null
+++ b/configs/experiment/imagenet/deit/deit-b.yaml
@@ -0,0 +1,19 @@
+# @package _global_
+defaults:
+  - /experiment/imagenet/deit/base.yaml
+  # TD [2022-05-27]: the DeiT paper says they don't use EMA but I'm only able to
+  # replicate their numbers with EMA
+  - override /callbacks: [default, ema, flop-count]
+
+model:
+  _target_: timm.models.vit_base_patch16_224
+  drop_path_rate: 0.1
+  num_classes: ${datamodule:num_classes}
+
+datamodule:
+  # RepeatedAug is crucial for ViT-B as it seems to overfit.
+  num_aug_repeats: 3
+
+callbacks:
+  ema:
+    decay: 0.99996
diff --git a/configs/experiment/imagenet/deit/deit-s.yaml b/configs/experiment/imagenet/deit/deit-s.yaml
new file mode 100644
index 0000000..5579759
--- /dev/null
+++ b/configs/experiment/imagenet/deit/deit-s.yaml
@@ -0,0 +1,15 @@
+# @package _global_
+defaults:
+  - /experiment/imagenet/deit/base.yaml
+
+seed: 1111
+
+model:
+  _target_: timm.models.vit_small_patch16_224
+  drop_path_rate: 0.1
+  num_classes: ${datamodule:num_classes}
+
+datamodule:
+  # DeiT paper says they use RepeatedAug, but I get 79.7% with RepeatedAug and
+  # 80.1% without.
+  num_aug_repeats: 0
diff --git a/src/datamodules/imagenet.py b/src/datamodules/imagenet.py
index ddbbdea..db73ab4 100644
--- a/src/datamodules/imagenet.py
+++ b/src/datamodules/imagenet.py
@@ -5,6 +5,7 @@
 
 from torch.utils.data import Dataset, DataLoader
 from torch.utils.data.dataloader import default_collate
+from torch.utils.data.distributed import DistributedSampler
 
 from pytorch_lightning import LightningDataModule
 
@@ -188,11 +189,15 @@ def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
 
     def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
         """ The val dataloader """
-        return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval)
+        # If using RepeatAugment, we set trainer.replace_sampler_ddp=False, so we have to
+        # construct the DistributedSampler ourselves.
+        sampler = DistributedSampler(self.dataset_val) if self.num_aug_repeats != 0 else None
+        return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval, sampler=sampler)
 
     def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
         """ The test dataloader """
-        return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval)
+        sampler = DistributedSampler(self.dataset_test) if self.num_aug_repeats != 0 else None
+        return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval, sampler=sampler)
 
     def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: bool = False,
                      mixup: Optional[Callable] = None, sampler=None) -> DataLoader:
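The datamodule change above is needed because turning off replace_sampler_ddp (required for RepeatAugment on the train set) also stops Lightning from injecting a DistributedSampler into the val/test loaders, so each rank would otherwise iterate over the full evaluation set. A minimal standalone sketch of the sharding behaviour; num_replicas/rank are passed explicitly here only so it runs without an initialized process group (inside the DataModule they are picked up from torch.distributed):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy dataset of 10 items, sharded across 2 "ranks" (illustration only).
dataset = TensorDataset(torch.arange(10))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)
print([batch[0].tolist() for batch in loader])  # this rank sees only 5 of the 10 items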
diff --git a/src/models/vit/vit.py b/src/models/vit/vit.py
new file mode 100644
index 0000000..4c7a7fb
--- /dev/null
+++ b/src/models/vit/vit.py
@@ -0,0 +1,227 @@
+# Copyright (c) 2022, Tri Dao.
+# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+import math
+from functools import partial
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.init import trunc_normal_
+
+from timm.models.helpers import named_apply
+from timm.models.layers import PatchEmbed
+
+from src.models.modules.mha import MHA
+from src.models.modules.mlp import Mlp, FusedDenseGeluDense
+from src.models.modules.block import Block
+
+
+def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc, bf16):
+    mixer_cls = partial(MHA, num_heads=num_heads, bias=qkv_bias, dropout=attn_drop,
+                        fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn, bf16=bf16)
+    return mixer_cls
+
+
+def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense, bf16):
+    inner_dim = int(embed_dim * mlp_ratio)
+    if not fused_dense_gelu_dense:
+        mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer())
+    else:
+        mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim, bf16=bf16)
+    return mlp_cls
+
+
+def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path,
+                 norm_layer, act_layer, use_flash_attn, fused_bias_fc, fused_dense_gelu_dense,
+                 fused_dropout_add_ln, bf16):
+    mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn,
+                                 fused_bias_fc, bf16)
+    mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense, bf16)
+    block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer,
+                  prenorm=True, resid_dropout=drop_rate, drop_path=drop_path,
+                  fused_dropout_add_ln=fused_dropout_add_ln)
+    return block
+
+
+class VisionTransformer(nn.Module):
+    """ Vision Transformer
+    A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
+        - https://arxiv.org/abs/2010.11929
+    """
+
+    def __init__(
+            self,
+            img_size=224,
+            patch_size=16,
+            in_chans=3,
+            num_classes=1000,
+            global_pool='token',
+            embed_dim=768,
+            depth=12,
+            num_heads=12,
+            mlp_ratio=4.,
+            qkv_bias=True,
+            init_values=None,
+            class_token=True,
+            no_embed_class=False,
+            pre_norm=False,
+            fc_norm=None,
+            drop_rate=0.,
+            attn_drop_rate=0.,
+            drop_path_rate=0.,
+            weight_init='',
+            embed_layer=PatchEmbed,
+            norm_layer=None,
+            act_layer=None,
+            use_flash_attn=False,
+            fused_bias_fc=False,
+            fused_dense_gelu_dense=False,
+            fused_dropout_add_ln=False,
+            bf16=False
+    ):
+        """
+        Args:
+            img_size (int, tuple): input image size
+            patch_size (int, tuple): patch size
+            in_chans (int): number of input channels
+            num_classes (int): number of classes for classification head
+            global_pool (str): type of global pooling for final sequence (default: 'token')
+            embed_dim (int): embedding dimension
+            depth (int): depth of transformer
+            num_heads (int): number of attention heads
+            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+            qkv_bias (bool): enable bias for qkv if True
+            init_values: (float): layer-scale init values
+            class_token (bool): use class token
+            fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
+            drop_rate (float): dropout rate
+            attn_drop_rate (float): attention dropout rate
+            drop_path_rate (float): stochastic depth rate
+            weight_init (str): weight init scheme
+            embed_layer (nn.Module): patch embedding layer
+            norm_layer: (nn.Module): normalization layer
+            act_layer: (nn.Module): MLP activation layer
+        """
+        super().__init__()
+        assert global_pool == 'token', 'Only support pooling with CLS token'
+        assert class_token
+        assert init_values is None, 'LayerScale is not supported yet'
+        assert weight_init == ''
+        assert fc_norm is None
+        # pre_norm seems redundant, as there's a LayerNorm right at the start of each block
+        assert not pre_norm
+        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
+        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+        act_layer = act_layer or nn.GELU
+
+        self.num_classes = num_classes
+        self.global_pool = global_pool
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.num_prefix_tokens = 1 if class_token else 0
+        self.no_embed_class = no_embed_class
+
+        self.patch_embed = embed_layer(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
+        )
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
+        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
+        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+
+        # We change the order of residual and layer norm:
+        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
+        # Attn / MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add)
+        # and the main branch (output of LN). The model definition is unchanged, but the mapping
+        # of the nn.LayerNorm weights is changed.
+        # This is for performance reasons: we can fuse dropout + add + layer_norm.
+        # self.norm_0 is the first layer norm in the model, while self.norm
+        # (in the pretrained weights) is the final layer norm.
+        self.norm_0 = norm_layer(embed_dim)
+
+        self.blocks = nn.ModuleList([create_block(
+            embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path=dpr[i],
+            norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
+            fused_bias_fc=fused_bias_fc, fused_dense_gelu_dense=fused_dense_gelu_dense,
+            fused_dropout_add_ln=fused_dropout_add_ln, bf16=bf16
+        ) for i in range(depth)])
+
+        # Classifier Head
+        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+        self.init_weights(weight_init)
+
+    def init_weights(self, mode=''):
+        assert mode == ''
+        trunc_normal_(self.pos_embed, std=.02)
+        if self.cls_token is not None:
+            nn.init.normal_(self.cls_token, std=1e-6)
+        named_apply(init_weights_vit_timm, self)
+
+    def _init_weights(self, m):
+        # this fn left here for compat with downstream users
+        init_weights_vit_timm(m)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}
+
+    def _pos_embed(self, x):
+        if self.no_embed_class:
+            # deit-3, updated JAX (big vision)
+            # position embedding does not overlap with class token, add then concat
+            x = x + self.pos_embed
+            if self.cls_token is not None:
+                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        else:
+            # original timm, JAX, and deit vit impl
+            # pos_embed has entry for class token, concat then add
+            if self.cls_token is not None:
+                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+            x = x + self.pos_embed
+        return self.pos_drop(x)
+
+    def forward_features(self, x):
+        x = self.patch_embed(x)
+        residual = self._pos_embed(x)
+        hidden_states = self.norm_0(residual)
+        for block in self.blocks:
+            hidden_states, residual = block(hidden_states, residual)
+        return hidden_states
+
+    def forward_head(self, x, pre_logits: bool = False):
+        if self.global_pool:
+            x = x[:, 0]
+        return x if pre_logits else self.head(x)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ''):
+    """ ViT weight initialization, original timm impl (for reproducibility) """
+    if isinstance(module, nn.Linear):
+        trunc_normal_(module.weight, std=.02)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+    elif hasattr(module, 'init_weights'):
+        module.init_weights()
+
+
+def vit_base_patch16_224(pretrained=False, **kwargs):
+    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    assert not pretrained
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+    model = VisionTransformer(**model_kwargs)
+    return model
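The comment in VisionTransformer.__init__ about reordering residual and layer norm is the main structural difference from timm's implementation. Below is a minimal sketch of that ordering with plain PyTorch modules (nn.Identity stands in for the real MHA/Mlp, and drop_path and the fused dropout+add+layer_norm kernel are omitted), showing the (hidden_states, residual) interface that forward_features relies on:

import torch
import torch.nn as nn

class ReorderedBlockSketch(nn.Module):
    """Illustrative only: Attn / MLP -> Dropout -> Add -> LN, as described above.

    `mixer` and `mlp` are placeholders for the real MHA / Mlp modules; drop_path
    and the fused dropout + add + layer_norm kernel are omitted.
    """
    def __init__(self, dim, dropout=0.0):
        super().__init__()
        self.mixer = nn.Identity()   # stands in for MHA
        self.mlp = nn.Identity()     # stands in for Mlp
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, hidden_states, residual):
        # Attention branch: mixer -> dropout -> add -> LN
        residual = self.dropout1(self.mixer(hidden_states)) + residual
        hidden_states = self.norm1(residual)
        # MLP branch: mlp -> dropout -> add -> LN
        residual = self.dropout2(self.mlp(hidden_states)) + residual
        hidden_states = self.norm2(residual)
        return hidden_states, residual

# Mirrors forward_features: norm_0 plays the role of the first LN, and the last
# block's LN corresponds to the final `norm` in timm's pretrained weights.
blocks = nn.ModuleList([ReorderedBlockSketch(dim=768) for _ in range(2)])
norm_0 = nn.LayerNorm(768)
residual = torch.randn(1, 197, 768)  # e.g. patch embeddings + pos_embed for 224x224, patch 16
hidden_states = norm_0(residual)
for block in blocks:
    hidden_states, residual = block(hidden_states, residual)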