train_ttyd_core.py
from utils_source_free.general_imports import *


def main(config_arguments):
    # Setting the seeds
    torch.manual_seed(1234)
    random.seed(1234)
    np.random.seed(1234)

    # Define the logging
    torch.autograd.set_detect_anomaly(True)
    writer = SummaryWriter('runs_eccv/{}'.format(f"{config_arguments['tensorboard_folder']}/{config_arguments['name']}"))
    file_path_config = os.path.join(config_arguments["resume_path"], "config.yaml")
    config = read_yaml_file(file_path_config)
    if config is not None:
        # Access the data in the dictionary
        print(file_path_config)
        print(f"Loaded Config: {config}")
    else:
        print("Failed to read the config YAML file.")
        return 0
    logging.getLogger().setLevel(config["logging"])
    config["parameter"] = config_arguments

    ##############################################################################################
    # Selection of the setting that is used
    config = config_adapter(config)
    config["ignore_class"] = 0
    mapping_info = sf_class_mapping_loader(source_dataset=config["source_dataset_name"], target_dataset=config["target_dataset_name"])
    summation_matrix = summation_matrix_generator(mapping_info)
    ##############################################################################################

    ### Iterate over the additional arguments
    for k, v in config_arguments.items():
        logging.info(f"{k}: {v}")
    logging.info("Creating the network")
    config['training_batch_size'] = config["parameter"]["batch_size"]
    config["test_batch_size"] = 16
    savedir_root = f"ckpts_bn/{config_arguments['name']}"
    os.makedirs(savedir_root, exist_ok=True)
    config["ns_dataset_version"] = 'v1.0-trainval'
    config["network_backbone"] = 'TorchSparseMinkUNet_learned'
    name_shift_inverse = {}
    for key, value in name_shift.items():
        name_shift_inverse[value] = key
    config["da_fixed_head_path_model"] = config["parameter"]["resume_path"]

    # Device
    device = torch.device(config['device'])
    if config["device"] == "cuda":
        torch.backends.cudnn.benchmark = True
    bb_dir_root = get_bbdir_root(config)

    # Create the network
    latent_size = config["network_latent_size"]
    backbone = config["network_backbone"]
    in_channels_source, _, in_channels_target, _ = da_get_inputs(config)
logging.info("Creating the network")
def network_function():
return networks.Network(in_channels=in_channels_source, latent_size=latent_size, backbone=backbone,\
voxel_size=config["voxel_size"], dual_seg_head = config["dual_seg_head"], target_in_channels=in_channels_target, config=config)
### Final network
net_final = network_function()
ckpt_path = os.path.join(bb_dir_root, 'source_only.pth')
logging.info(f"CKPT -- Load ckpt from {ckpt_path}")
#Load the checkpoint for the backbone
checkpoint = torch.load(ckpt_path, map_location=device)
#Updating the checkpoint
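    # Keys found in name_shift_inverse are renamed back to the local naming scheme;
    # BatchNorm "num_batches_tracked" counters and "point_transforms" weights are skipped.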
    checkpoint_new = {}
    for key in checkpoint["state_dict"].keys():
        if key in name_shift_inverse:
            checkpoint_new[name_shift_inverse[key]] = checkpoint["state_dict"][key]
        elif "num_batches_tracked" in key or "point_transforms" in key:
            pass
        else:
            checkpoint_new[key] = checkpoint["state_dict"][key]
    try:
        net_final.load_state_dict(checkpoint_new)
    except Exception:
        logging.info("Loaded parameters do not exactly match the network architecture, switching to load_state_dict with strict=False.")
        net_final.load_state_dict(checkpoint_new, strict=False)
logging.info(f"Network -- Number of parameters {count_parameters(net_final)}")
target_DatasetClass = get_dataset(eval("datasets."+config["target_dataset_name"]))
val_number = 1 #1: verifying split, 2 train split, else: test split
dataloader_dict = da_sf_get_dataloader(target_DatasetClass, config, net_final, network_function, val=val_number, train_shuffle=True, keep_orignal_data=False)
target_train_loader = dataloader_dict ["target_train_loader"]
target_test_loader = dataloader_dict ["target_test_loader"]
os.makedirs(savedir_root, exist_ok=True)
save_config_file(eval(str(config)), os.path.join(savedir_root, "config.yaml"))

    # Create the loss layers
    loss_layer = torch.nn.BCEWithLogitsLoss()
    weights_ss = torch.ones(config["nb_classes_inference"])
    list_ignore_classes = ignore_selection(config["ignore_idx"])
    for idx_ignore_class in list_ignore_classes:
        weights_ss[idx_ignore_class] = 0
    logging.info(f"Ignored classes {list_ignore_classes}")
    logging.info(f"Weights of the different classes {weights_ss}")
    weights_ss = weights_ss.to(device)
    ce_loss_layer = torch.nn.CrossEntropyLoss(weight=weights_ss)

    net_final.eval()
    net_final.to(device)

    list_parameter_to_update = []
    list_parameter_others = []  # Second group of selected parameters, e.g. when the LL and the backbone are scaled differently
    net_final, list_parameter_to_update, list_parameter_others = \
        configure_freeze_models(net_final, config, list_parameter_to_update, list_parameter_others)
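    # Keep all BatchNorm layers in eval mode so their running statistics are not updated by the target data.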
    for l_name, l_module in net_final.named_modules():
        if isinstance(l_module, torch.nn.modules.batchnorm._BatchNorm):
            l_module.eval()

    class_prior = np.zeros((1))
    class_prior, names_list = class_prior_class_names(config, logging)
    # Renormalize the class prior so that it sums exactly to 1
    class_prior = class_prior / np.sum(class_prior)
    logging.info(f"We use a distribution of {class_prior}")

    ent_loss_thr = np.array(config["parameter"]["ent_loss_thr"]).astype(np.float64)
    ent_loss_thr = torch.from_numpy(ent_loss_thr).type(torch.FloatTensor).to(device)
    div_loss_thr = np.array(config["parameter"]["div_loss_thr"]).astype(np.float64)
    div_loss_thr = torch.from_numpy(div_loss_thr).type(torch.FloatTensor).to(device)
    summation_matrix = summation_matrix.to(device)
    class_prior = torch.from_numpy(class_prior).float().to(device)  # cast to float32 to match the network outputs
    kl_loss = torch.nn.KLDivLoss(reduction="batchmean")
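    # KLDivLoss with reduction="batchmean" expects log-probabilities as input and probabilities as target.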
    net_final.to(device)
if config["parameter"]["finetune"] and config["parameter"]["fintune_setting"]=="classic":
logging.info("Classifier get updated with 10X higher LR than backbone.")
optimizer = torch.optim.AdamW([{"params": list_parameter_to_update, "lr":config["parameter"]["learning_rate"]},\
{"params": list_parameter_others, "lr":config["parameter"]["learning_rate"] / 10.0}]) #Backbone is updated with a 10x smaller learning rate
else:
optimizer = torch.optim.AdamW([{"params": list_parameter_to_update}],config["parameter"]["learning_rate"])
if config["parameter"]["lr_scheduler"]:
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20005, eta_min=0)
logging.info(f"Network -- Number of finally optimized parameters {count_parameters(net_final)}")
    train_iter_trg = enumerate(target_train_loader)
    for i in range(config["parameter"]["nb_iterations"]):
        if i % config["parameter"]["val_intervall"] == 0:
            logging.info(i)
            # if "code_test" in config and config["code_test"]:
            return_data_val_target_mapped = \
                validation_non_premap(net_final, config, target_test_loader, epoch=0, disable_log=False, device=device, list_ignore_classes=[0])
            logging.info(f"mIoU: {return_data_val_target_mapped['test_seg_head_miou']}")
            writer.add_scalar("validation.seg_mIou", return_data_val_target_mapped['test_seg_head_miou'], i)
            logging.info(f"Per class {return_data_val_target_mapped['seg_iou_per_class']}")
            for q in range(len(names_list)):
                writer.add_scalar(f"validation.seg_Iou_{names_list[q]}", return_data_val_target_mapped['seg_iou_per_class'][q], i)
            # After validation, set the BatchNorm layers back to eval mode
            for l_name, l_module in net_final.named_modules():
                if isinstance(l_module, torch.nn.modules.batchnorm._BatchNorm):
                    l_module.eval()
        if i % config["parameter"]["ckpt_intervall"] == 0:
            torch.save({"state_dict": net_final.state_dict()}, os.path.join(savedir_root, f"model_{i}.pth"))
        try:
            _, target_data = next(train_iter_trg)
        except StopIteration:
            # Restart the target loader once it is exhausted
            train_iter_trg = enumerate(target_train_loader)
            _, target_data = next(train_iter_trg)
        target_data = dict_to_device(target_data, device)
        optimizer.zero_grad()
        _, output_seg, _ = net_final.forward_mapped_learned(target_data)
        loss_seg = None
        #### Entropy loss
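        # Prediction entropy is penalized only above ent_loss_thr; the ReLU clamps the loss to zero below the threshold.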
        loss_ent = minent_entropy_loss(output_seg)
        loss_ent = F.relu(loss_ent - ent_loss_thr, inplace=False)
        loss_seg = loss_ent
        writer.add_scalar("training.entropy_loss", loss_seg, i)
        #### Diversity loss
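        # Predictions are mapped to the target label set, averaged over all points (dropping the ignore class at
        # index 0), and the KL divergence to the class prior is penalized only above div_loss_thr.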
        nb_points = output_seg.shape[0]
        # Mapping to the new class output
        output_seg = output_seg[:, :, 0] @ summation_matrix
        input = F.softmax(output_seg[:, 1:], dim=1).sum(dim=0) / nb_points
        input_log = torch.log(input)
        loss_kl = kl_loss(input_log, class_prior).float()  # cast to float32 without leaving the current device
        loss_kl = F.relu(loss_kl - div_loss_thr, inplace=False)
        div_loss = loss_kl
        writer.add_scalar("training.diversity_loss", div_loss, i)
        loss_seg = loss_seg + div_loss
        writer.add_scalar("training.seg_loss", loss_seg, i)
        loss_seg.backward()
        optimizer.step()
        del loss_seg


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='TTYD core training script.')
    # General settings
    parser.add_argument('--name', '-n', type=str, required=True)
    parser.add_argument('--setting', '-ds', type=str, required=True, default="NS2SK")
    parser.add_argument('--resume_path', '-p', type=str, default="cvpr24_results/REP0_ns_semantic_TorchSparseMinkUNet_InterpAllRadiusNoDirsNet_1.0_trainSplit")
    parser.add_argument('--tensorboard_folder', '-tf', type=str, default="DASF")
    parser.add_argument('--bn_layer', '-l', type=str, default="standard")
    # Learning parameters
    parser.add_argument('--learning_rate', '-lr', type=float, default=0.001)
    parser.add_argument('--batch_size', '-bs', type=int, default=4)
    parser.add_argument('--nb_iterations', '-i', type=int, default=20010)
    parser.add_argument('--ckpt_intervall', type=int, default=1000)
    parser.add_argument('--val_intervall', type=int, default=1000)
    parser.add_argument('--lr_scheduler', '-ls', type=bool, default=False)
    # Select what to finetune
    parser.add_argument('--finetune', '-f', type=bool, default=False)
    parser.add_argument('--fintune_setting', '-fs', type=str, choices=['LL', 'classic', 'll_and_scalable_finetune', 'shot_finetune', 'complete_finetune'], default='LL')
    # Clipping thresholds for the losses
    parser.add_argument('--ent_loss_thr', '-eth', type=float, default=0.04)
    parser.add_argument('--div_loss_thr', '-dth', type=float, default=0.04)
    opts = parser.parse_args()
    config_arguments = {}
    # Experiment credentials
    config_arguments["name"] = opts.name
    config_arguments["tensorboard_folder"] = opts.tensorboard_folder
    config_arguments["resume_path"] = opts.resume_path
    # Training settings
    config_arguments["setting"] = opts.setting
    config_arguments["bn_layer"] = opts.bn_layer
    config_arguments["finetune"] = opts.finetune
    config_arguments["fintune_setting"] = opts.fintune_setting
    config_arguments["learning_rate"] = opts.learning_rate
    config_arguments["batch_size"] = opts.batch_size
    config_arguments["nb_iterations"] = opts.nb_iterations
    # Evaluation settings
    config_arguments["ckpt_intervall"] = opts.ckpt_intervall
    config_arguments["val_intervall"] = opts.val_intervall
    # Clipping
    config_arguments["ent_loss_thr"] = opts.ent_loss_thr
    config_arguments["div_loss_thr"] = opts.div_loss_thr
    config_arguments["lr_scheduler"] = opts.lr_scheduler

    main(config_arguments)