-
Notifications
You must be signed in to change notification settings - Fork 63
/
Copy pathbuild.py
43 lines (41 loc) · 1.53 KB
/
build.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from timm.models import create_model
from . import focalnet
def build_model(config):
    """Create the model described by ``config`` via timm's ``create_model``.

    The model family is chosen by substring match on ``config.MODEL.TYPE``,
    checked in this order:

    * contains ``"focal"`` -- the image size, drop-path rate and every
      ``config.MODEL.FOCAL.*`` option are forwarded;
    * contains ``"vit"`` -- the image size is forwarded;
    * anything else (``"resnet"`` names included) -- only the shared
      arguments are passed.

    ``pretrained`` and ``num_classes`` are passed for every family.  Only
    the config attributes needed by the selected family are read.

    Args:
        config: Configuration object exposing ``MODEL.TYPE``,
            ``MODEL.PRETRAINED`` and ``MODEL.NUM_CLASSES``; additionally
            ``DATA.IMG_SIZE`` for focal/vit types, and
            ``MODEL.DROP_PATH_RATE`` plus ``MODEL.FOCAL.*`` for focal types.

    Returns:
        The object returned by ``create_model`` for the requested type.
    """
    model_type = config.MODEL.TYPE
    is_pretrained = config.MODEL.PRETRAINED
    print(f"Creating model: {model_type}")

    # Keyword arguments common to every model family.
    model_kwargs = {
        "pretrained": is_pretrained,
        "num_classes": config.MODEL.NUM_CLASSES,
    }

    if "focal" in model_type:
        focal_cfg = config.MODEL.FOCAL
        model_kwargs.update(
            img_size=config.DATA.IMG_SIZE,
            drop_path_rate=config.MODEL.DROP_PATH_RATE,
            focal_levels=focal_cfg.FOCAL_LEVELS,
            focal_windows=focal_cfg.FOCAL_WINDOWS,
            use_conv_embed=focal_cfg.USE_CONV_EMBED,
            use_layerscale=focal_cfg.USE_LAYERSCALE,
            use_postln=focal_cfg.USE_POSTLN,
            use_postln_in_modulation=focal_cfg.USE_POSTLN_IN_MODULATION,
            normalize_modulator=focal_cfg.NORMALIZE_MODULATOR,
        )
    elif "vit" in model_type:
        model_kwargs["img_size"] = config.DATA.IMG_SIZE
    # Every other type (e.g. "resnet") takes only the shared arguments; the
    # former dedicated "resnet" branch was identical to the fallback.

    return create_model(model_type, **model_kwargs)