train_test.py (forked from xheon/panoptic-reconstruction)
import sys
import time
from collections import OrderedDict
from pathlib import Path

import torch

from lib import config, data, logger, modeling, solver, utils
from lib.structures.field_list import collect

# Make the SPSG raycaster importable (RaycastRGBD is not used directly in this script)
sys.path.append('/usr/src/app/spsg/torch')
from utils.raycast_rgbd.raycast_rgbd import RaycastRGBD
# Load the 3D training configuration
config.merge_from_file('configs/front3d_train_3d.yaml')

# Build the model and move it to the configured device
model = modeling.PanopticReconstruction()
device = torch.device(config.MODEL.DEVICE)
model.to(device, non_blocking=True)
model.log_model_info()
model.fix_weights()
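# Assumption: fix_weights() freezes a pretrained subset of the network so that
# only the remaining parameters are trained; see lib/modeling for the exact behavior.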
# Setup optimizer, scheduler, checkpointer
optimizer = torch.optim.Adam(model.parameters(), config.SOLVER.BASE_LR,
                             betas=(config.SOLVER.BETA_1, config.SOLVER.BETA_2),
                             weight_decay=config.SOLVER.WEIGHT_DECAY)
scheduler = solver.WarmupMultiStepLR(optimizer, config.SOLVER.STEPS, config.SOLVER.GAMMA,
                                     warmup_factor=1,
                                     warmup_iters=0,
                                     warmup_method="linear")
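# Note: with warmup_factor=1 and warmup_iters=0 the warmup phase is effectively
# disabled, so this reduces to a plain multi-step decay at config.SOLVER.STEPS.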
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {trainable_params}")
output_path = Path('output')
checkpointer = utils.DetectronCheckpointer(model, optimizer, scheduler, output_path)
# Load the checkpoint
checkpoint_data = checkpointer.load()
checkpoint_arguments = {"iteration": 0}
if config.SOLVER.LOAD_SCHEDULER:
    checkpoint_arguments.update(checkpoint_data)
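# Note: checkpoint_arguments is only forwarded to checkpointer.save() below;
# the training loop's iteration counter always restarts at 0 in this script.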
# TODO: move to checkpointer?
if config.MODEL.PRETRAIN2D:
    pretrain_2d = torch.load(config.MODEL.PRETRAIN2D)
    model.load_state_dict(pretrain_2d["model"])
# Dataloader
dataloader = data.setup_dataloader(config.DATASETS.TRAIN)
print(f"Number of training batches: {len(dataloader)}")

# Switch training mode
model.switch_training()
iteration = 0
iteration_end = time.time()
for idx, (image_ids, targets) in enumerate(dataloader):
    assert targets is not None, "error during data loading"
    data_time = time.time() - iteration_end

    # Get input images
    images = collect(targets, "color")

    # Pass through model (the try/except can be re-enabled to skip failing samples)
    # try:
    losses, results = model(images, targets)
    # except Exception as e:
    #     print(e, "skipping", image_ids[0])
    #     del targets, images
    #     continue
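    # `losses` is a nested mapping {group: {loss_name: tensor}}; the loop below
    # flattens it into a single scalar for backpropagation.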
    # Accumulate total loss, skipping non-tensor, NaN, and Inf terms
    total_loss = 0.0  # becomes a torch.Tensor once the first valid loss is added
    log_meters = OrderedDict()
    rgb_loss = 0.0
    for loss_group in losses.values():
        for loss_name, loss in loss_group.items():
            if loss_name == "rgb":
                rgb_loss = loss
            if torch.is_tensor(loss) and not torch.isnan(loss) and not torch.isinf(loss):
                total_loss += loss
                log_meters[loss_name] = loss.item()
    # Loss backpropagation, optimizer & scheduler step
    optimizer.zero_grad()
    if torch.is_tensor(total_loss):
        total_loss.backward()
        optimizer.step()
        scheduler.step()
        log_meters["total"] = total_loss.item()
    else:
        # No valid loss this iteration; skip the optimizer step
        log_meters["total"] = total_loss
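    # MinkowskiEngine recommendation: clearing the CUDA cache each iteration
    # helps avoid GPU memory fragmentation with sparse-tensor workloads.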
    torch.cuda.empty_cache()
    # Save checkpoint
    if iteration % config.SOLVER.CHECKPOINT_PERIOD == 0:
        checkpointer.save(f"model_{iteration:07d}", **checkpoint_arguments)

    # Save an additional checkpoint after finishing a hierarchy level
    last_training_stage = model.set_current_training_stage(iteration)
    if last_training_stage is not None:
        checkpointer.save(f"model_{last_training_stage}_{iteration:07d}", **checkpoint_arguments)
        logger.info(f"Finished {last_training_stage} hierarchy level")
    iteration += 1
    iteration_end = time.time()

    # Overwrite the progress line in place; persist it every 20 iterations
    print(f"\riteration: {iteration}, total_loss: {total_loss}, rgb_loss: {rgb_loss}", end="")
    if iteration % 20 == 0:
        print(f"\riteration: {iteration}, total_loss: {total_loss}, rgb_loss: {rgb_loss}")

    # Debug: uncomment to stop after a few batches
    # if idx > 4:
    #     break