-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path four_simEps_model.py
28 lines (21 loc) · 1.05 KB
/
four_simEps_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#Import the required libraries
import os
import torch
import torch.nn as nn
import functools
from utilities.m_net_factory import net_factory
from utilities.model_initialization import kaiming_normal_init_weight, xavier_normal_init_weight, xavier_uniform_init_weight, sparse_init_weight
# Shared configuration for every model built below. These values are passed
# straight through to utilities.m_net_factory.net_factory.
# NOTE(review): in_chns / class_num presumably mean input channels and number
# of output classes — confirm against net_factory.
NET_TYPE = 'unet_f'
IN_CHANNELS = 3
NUM_CLASSES = 4


def _build_model(init_fn=None):
    """Build one ``NET_TYPE`` network and wrap it in ``nn.DataParallel``.

    Args:
        init_fn: Optional callable applied to the freshly built network before
            it is wrapped, e.g. one of the initialisers imported from
            ``utilities.model_initialization`` such as
            ``xavier_normal_init_weight``. It must return the (re)initialised
            model. When ``None`` (the default), the network keeps whatever
            weights ``net_factory`` produced.

    Returns:
        The network wrapped in ``nn.DataParallel``.
    """
    net = net_factory(net_type=NET_TYPE, in_chns=IN_CHANNELS,
                      class_num=NUM_CLASSES)
    if init_fn is not None:
        net = init_fn(net)
    return nn.DataParallel(net)


# Four separately constructed models. All of them currently keep
# net_factory's own initialisation. To give a model a different initial
# weighting, pass an initialiser, e.g. ``_build_model(sparse_init_weight)``.
model1 = _build_model()
model2 = _build_model()
model3 = _build_model()
model4 = _build_model()