training_realBall.py

import os
import torch
import hydra
import logging

from omegaconf import DictConfig

from models.sceneRepresentation import Scene
from dataset.dataset import ImageDataset_realData, Dataloader
from optimization.optimizers import optimizersScene
from optimization.loss import Losses
from util.initialValues import estimate_initial_vals_ball
from util.visualization import VisualizationRealBall
from util.util import setSeeds

log = logging.getLogger(__name__)

CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs")
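

# Hydra entry point: trains the scene representation on the real ball sequence.
# The "realBall" config is resolved from the local configs/ directory (CONFIG_DIR).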
@hydra.main(config_path=CONFIG_DIR, config_name="realBall")
def main(cfg: DictConfig):
    # Select device
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # Seed
    setSeeds(cfg.seed)
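
    # Train/test split: the test set consists of the frames not sampled for training (train_data.unused_inds).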
    train_data = ImageDataset_realData(
        **cfg.data,
        max_samples=cfg.data.samples_train
    )
    test_data = ImageDataset_realData(
        **cfg.data,
        indices=train_data.unused_inds
    )
    train_dataloader = Dataloader(train_data, **cfg.dataloader)

    tspan = train_data.t_steps.to(device)
    tspan_test = test_data.t_steps.to(device)
    log.info("Done loading data")

    # Initialize model
    model = Scene(**cfg.scene.background)
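    # Estimate initial values for the thrown ball from the object masks, pixel coordinates, and time steps.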
    init_values_estimate = estimate_initial_vals_ball(
        train_data.get_full_mask(),
        train_data.get_pixel_coords(),
        tspan.cpu()
    )

    # Seed again to ensure consistent initialization
    # (different architectures before this point consume different amounts of randomness)
    setSeeds(cfg.seed)
    model.add_thrownObject(
        **init_values_estimate,
        **cfg.scene.local_representation
    )

    # Move to device
    model.to(device)

    # Initialize the optimizers
    optimizers = optimizersScene(model, cfg.optimizer)
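
    # The visualization helper writes tensorboard scalars and renders test frames during training (used below).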
    if cfg.logging.enable:
        visualization = VisualizationRealBall()

    model.train()

    # Initialize loss
    losses = Losses(cfg.loss)
    losses.regularize_mask = False

    # Seed again to ensure consistent initialization
    # (different architectures before this point consume different amounts of randomness)
    setSeeds(cfg.seed)

    # Run training loop
    log.info("Start training loop")
    for epoch in range(0, cfg.optimizer.epochs):
        losses.zero_losses()
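
        # Enable mask regularization and the homography once their configured warm-up epochs are reached.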
        if epoch == cfg.loss.regularize_after_epochs:
            losses.regularize_mask = True
        if cfg.homography.enable and epoch == cfg.homography.enable_after_epochs:
            model.use_homography = True

        for data in train_dataloader:
            # Read data
            coords = data["coords"].to(device)
            colors_gt = data["im_vals"].to(device)

            # Adjust number of frames used (online training)
            if cfg.online_training.enable:
                l_tInterval = min(epoch // cfg.online_training.stepsize + cfg.online_training.start_length, len(tspan))
                tspan_cur = tspan[:l_tInterval]
                colors_gt = colors_gt[:, :l_tInterval, :]
            else:
                # Use the full time span; this also keeps l_tInterval defined for the logging call below.
                l_tInterval = len(tspan)
                tspan_cur = tspan

            # Zero gradients
            optimizers.zero_grad()
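            # Recompute the transformations for the current time span before the forward pass.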
            model.update_trafo(tspan_cur)

            # Training step
            output = model(coords)
            losses.compute_losses(output, colors_gt)
            losses.backward()
            optimizers.optimizer_step()

        # Adjust learning rates
        optimizers.lr_scheduler_step()

        # Write to tensorboard
        if epoch % cfg.logging.logging_interval == 0 and cfg.logging.enable:
            visualization.log_scalars(epoch, losses, model, l_tInterval, train_data)
            log.info("Scalars logged. Epoch: " + str(epoch))

        # Render test frames
        if epoch % cfg.logging.test_interval == 0 and cfg.logging.enable:
            model.eval()
            visualization.render_test(epoch, model, tspan, train_data, tspan_test, test_data)
            model.train()
            log.info("Rendering test done. Epoch: " + str(epoch))

        # Checkpoint
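        # ckpt.pth is written to the current working directory, which Hydra sets to the run's output folder by default.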
        if (
            epoch % cfg.logging.checkpoint_interval == 0
            and epoch > 0
        ):
            log.info("Storing checkpoint. Epoch " + str(epoch))
            torch.save(model.state_dict(), os.path.join(os.path.abspath(''), 'ckpt.pth'))

    log.info("Storing final checkpoint. Epoch " + str(epoch))
    torch.save(model.state_dict(), os.path.join(os.path.abspath(''), 'ckpt.pth'))


if __name__ == "__main__":
    main()