
Commit b5380f6

register ViTAE_S
Annbless committed Apr 19, 2022
1 parent 640c567 commit b5380f6
Showing 2 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion configs/ViTDet/ViTDet-ViTAE-Small-100e.py
@@ -35,7 +35,7 @@
         NC_group=[1, 1, 96],
         RC_group=[1, 1, 1],
         mlp_ratio=3.,
-        window=[2, 5, 8, 13]
+        globalBlock=[2, 5, 8, 13]
     ),
     neck=dict(
         type='FPN',
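
Note that this rename is not purely cosmetic: as the backbone diff below shows, the indices listed under the old window argument selected the blocks that used windowed attention, while the same indices under globalBlock now select the blocks that attend over the full feature map. A minimal sketch of the updated config entry, assuming the usual mmdet model = dict(backbone=dict(...)) nesting (only the keys visible in the hunk above come from the commit; the rest is illustrative):

model = dict(
    backbone=dict(
        # ... other ViTAE-S arguments ...
        NC_group=[1, 1, 96],
        RC_group=[1, 1, 1],
        mlp_ratio=3.,
        # indices of NormalCell blocks that attend over the full feature map
        globalBlock=[2, 5, 8, 13],
    ),
    neck=dict(
        type='FPN',
        # ...
    ),
)
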
10 changes: 5 additions & 5 deletions mmdet/models/backbones/vitae_s.py
@@ -40,7 +40,7 @@ def forward(self, x, size):
 class BasicLayer(nn.Module):
     def __init__(self, img_size=224, in_chans=3, embed_dims=64, token_dims=64, downsample_ratios=4, kernel_size=7, RC_heads=1, NC_heads=6, dilations=[1, 2, 3, 4],
                  RC_tokens_type='performer', NC_tokens_type='transformer', RC_group=1, NC_group=64, NC_depth=2, dpr=0.1, mlp_ratio=4., qkv_bias=True,
-                 qk_scale=None, drop=0, attn_drop=0., norm_layer=nn.LayerNorm, window_size=(14, 14), use_checkpoint=False, window=[2, 5, 8, 13]):
+                 qk_scale=None, drop=0, attn_drop=0., norm_layer=nn.LayerNorm, window_size=(14, 14), use_checkpoint=False, globalBlock=[2, 5, 8, 13]):
         super().__init__()
         self.img_size = img_size
         self.in_chans = in_chans
@@ -72,8 +72,8 @@ def __init__(self, img_size=224, in_chans=3, embed_dims=64, token_dims=64, downs
             NormalCell(token_dims, NC_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
                        drop_path=dpr[i] if isinstance(dpr, list) else dpr, norm_layer=norm_layer,
                        group=NC_group, tokens_type=NC_tokens_type,
-                       img_size=img_size // downsample_ratios, window_size=window_size if i in window else full_image_size,
-                       window=True if i in window else False)
+                       img_size=img_size // downsample_ratios, window_size=window_size if i not in globalBlock else full_image_size,
+                       window=True if i not in globalBlock else False)
             for i in range(NC_depth)])

     def forward(self, x, size):
@@ -129,7 +129,7 @@ def __init__(self,
                  frozen_stages=-1,
                  use_checkpoint=False,
                  load_ema=True,
-                 window=[2, 5, 8, 13],
+                 globalBlock=[2, 5, 8, 13],
                  pretrained=None):
         super().__init__()

@@ -169,7 +169,7 @@ def __init__(self,
                     self.kernel_size[i], self.RC_heads[i], self.NC_heads[i], self.dilaions[i],
                     self.RC_tokens_type[i], self.NC_tokens_type[i], self.RC_group[i], self.NC_group[i], self.NC_depth[i], dpr[startDpr:self.NC_depth[i]+startDpr],
                     mlp_ratio=self.mlp_ratio[i], qkv_bias=self.qkv_bias[i], qk_scale=self.qk_scale[i], drop=self.drop[i], attn_drop=self.attn_drop[i],
-                    norm_layer=self.norm_layer[i], window_size=window_size, use_checkpoint=use_checkpoint, window=window)
+                    norm_layer=self.norm_layer[i], window_size=window_size, use_checkpoint=use_checkpoint, globalBlock=globalBlock)
             )
             img_size = img_size // self.downsample_ratios[i]
             in_chans = self.tokens_dims[i]
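
To make the new behaviour concrete, here is a small self-contained sketch (not the repository code) of the per-block rule BasicLayer applies after this commit: NormalCell indices listed in globalBlock attend over the full feature map, every other block keeps local windowed attention. The names pick_attention_mode and full_image_size, the (56, 56) value, and the depth of 14 are illustrative assumptions.

# Minimal sketch of the post-commit selection rule (illustrative, not the repo code).
# Blocks whose index appears in globalBlock attend over the whole feature map
# (window=False, full-image window size); all other blocks use local windowed
# attention with the configured window_size.

def pick_attention_mode(i, globalBlock, window_size=(14, 14), full_image_size=(56, 56)):
    """Return (window_size, use_window) for NormalCell block i; full_image_size is an assumed placeholder."""
    if i in globalBlock:
        return full_image_size, False   # global attention block
    return window_size, True            # windowed attention block

if __name__ == "__main__":
    globalBlock = [2, 5, 8, 13]         # value registered by this commit
    for i in range(14):                 # stage depth assumed to be 14 for illustration
        size, windowed = pick_attention_mode(i, globalBlock)
        print(f"block {i:2d}: window={windowed}, window_size={size}")
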
