diff --git a/configs/ViTDet/ViTDet-ViTAE-Small-100e.py b/configs/ViTDet/ViTDet-ViTAE-Small-100e.py index 632dedd..539b912 100644 --- a/configs/ViTDet/ViTDet-ViTAE-Small-100e.py +++ b/configs/ViTDet/ViTDet-ViTAE-Small-100e.py @@ -35,6 +35,7 @@ NC_group=[1, 1, 96], RC_group=[1, 1, 1], mlp_ratio=3., + window=[2, 5, 8, 13] ), neck=dict( type='FPN', diff --git a/mmdet/models/backbones/vitae_s.py b/mmdet/models/backbones/vitae_s.py index 7f44fd6..4e89b91 100644 --- a/mmdet/models/backbones/vitae_s.py +++ b/mmdet/models/backbones/vitae_s.py @@ -40,7 +40,7 @@ def forward(self, x, size): class BasicLayer(nn.Module): def __init__(self, img_size=224, in_chans=3, embed_dims=64, token_dims=64, downsample_ratios=4, kernel_size=7, RC_heads=1, NC_heads=6, dilations=[1, 2, 3, 4], RC_tokens_type='performer', NC_tokens_type='transformer', RC_group=1, NC_group=64, NC_depth=2, dpr=0.1, mlp_ratio=4., qkv_bias=True, - qk_scale=None, drop=0, attn_drop=0., norm_layer=nn.LayerNorm, window_size=(14, 14), use_checkpoint=False, interval=3): + qk_scale=None, drop=0, attn_drop=0., norm_layer=nn.LayerNorm, window_size=(14, 14), use_checkpoint=False, window=[2, 5, 8, 13]): super().__init__() self.img_size = img_size self.in_chans = in_chans @@ -72,8 +72,8 @@ def __init__(self, img_size=224, in_chans=3, embed_dims=64, token_dims=64, downs NormalCell(token_dims, NC_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=dpr[i] if isinstance(dpr, list) else dpr, norm_layer=norm_layer, group=NC_group, tokens_type=NC_tokens_type, - img_size=img_size // downsample_ratios, window_size=window_size if (i + 1) % interval != 0 else full_image_size, - window=True if (i + 1) % interval != 0 else False) + img_size=img_size // downsample_ratios, window_size=window_size if i in window else full_image_size, + window=True if i in window else False) for i in range(NC_depth)]) def forward(self, x, size): @@ -129,7 +129,7 @@ def __init__(self, frozen_stages=-1, use_checkpoint=False, load_ema=True, - interval=3, + window=[2, 5, 8, 13], pretrained=None): super().__init__() @@ -169,7 +169,7 @@ def __init__(self, self.kernel_size[i], self.RC_heads[i], self.NC_heads[i], self.dilaions[i], self.RC_tokens_type[i], self.NC_tokens_type[i], self.RC_group[i], self.NC_group[i], self.NC_depth[i], dpr[startDpr:self.NC_depth[i]+startDpr], mlp_ratio=self.mlp_ratio[i], qkv_bias=self.qkv_bias[i], qk_scale=self.qk_scale[i], drop=self.drop[i], attn_drop=self.attn_drop[i], - norm_layer=self.norm_layer[i], window_size=window_size, use_checkpoint=use_checkpoint, interval=interval) + norm_layer=self.norm_layer[i], window_size=window_size, use_checkpoint=use_checkpoint, window=window) ) img_size = img_size // self.downsample_ratios[i] in_chans = self.tokens_dims[i]