forked from zorzi-s/projectRegularization
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtraining_utils.py
82 lines (64 loc) · 2.27 KB
/
training_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import numpy as np
import cv2
import glob
from tqdm import tqdm
import random
from skimage import io
from skimage.segmentation import mark_boundaries
import random
import time
import datetime
import sys
from torch.autograd import Variable
import torch
import numpy as np
import gdal
import variables as var
def sample_images(sample_index, img, masks):
    """Save a debug PNG with each mask's edges drawn over the image batch.

    The output is a grid with one row per entry of ``masks`` and one column
    per batch element of ``img``.  For every cell the mask is reduced to a
    per-pixel argmax over its second axis, Canny edges of that map are
    computed, and the edges are drawn in yellow on a copy of the image.
    The grid is written to ``var.DEBUG_DIR + "debug_<sample_index>.png"``.

    Args:
        sample_index: Value formatted into the output file name.
        img: 4-D torch tensor, presumably (batch, channels, height, width)
            with values in [0, 1] since it is scaled by 255 and cast to
            uint8 -- TODO confirm against the caller.
        masks: Sequence of 4-D torch tensors laid out like ``img``; the
            second axis is the one the argmax is taken over.  The sequence
            is only read: it is not modified by this function.
    """
    batch = img.shape[0]

    # Channels-last uint8 copy of the images.  ``detach()`` keeps this
    # working for tensors that are still attached to the autograd graph.
    frames = np.uint8(img.detach().permute(0, 2, 3, 1).cpu().numpy() * 255)

    line_mode = "inner"
    rows = []
    for mask in masks:
        # Work on a local array so the caller's list keeps its tensors.
        scores = mask.detach().permute(0, 2, 3, 1).cpu().numpy()
        # Index 0 maps to 0 and index 1 to 255; presumably the masks are
        # two-class, larger indices would wrap in the uint8 cast.
        labels = np.uint8(np.argmax(scores, axis=-1) * 255)

        tiles = []
        for b in range(batch):
            edges = cv2.Canny(labels[b, :, :], 0, 255)
            # mark_boundaries returns floats in [0, 1]; rescale to 0..255.
            tile = mark_boundaries(
                np.copy(frames[b, :, :, :]), edges, color=(1, 1, 0), mode=line_mode
            ) * 255
            tiles.append(tile)
        # Batch elements side by side form one row of the grid.
        rows.append(np.concatenate(tiles, 1))

    # Stack the per-mask rows vertically and write the result.
    grid = np.uint8(np.concatenate(rows, 0))
    io.imsave(var.DEBUG_DIR + "debug_%s.png" % str(sample_index), grid)
class LossBuffer():
    """Keep the most recent values and report their running mean."""

    def __init__(self, max_size=100):
        """Create an empty buffer holding at most ``max_size`` values."""
        self.max_size = max_size
        self.data = []

    def push(self, data):
        """Append ``data``, drop the oldest value on overflow, return the mean.

        The mean is taken over the values left in the buffer after the push.
        """
        window = self.data
        window.append(data)
        overflow = len(window) > self.max_size
        if overflow:
            # Rebind to a new list that omits the oldest entry.
            window = window[1:]
            self.data = window
        return sum(window) / len(window)
class LambdaLR():
    """Learning-rate multiplier: constant 1.0, then linear decay to 0.0."""

    def __init__(self, n_batches, decay_start_batch):
        """Store the schedule bounds.

        ``n_batches`` is the batch index at which the factor reaches zero;
        ``decay_start_batch`` is the last index that still returns 1.0.
        The assertion requires ``decay_start_batch < n_batches``.
        """
        remaining = n_batches - decay_start_batch
        assert (remaining > 0), "Decay must start before the training session ends!"
        self.decay_start_batch = decay_start_batch
        self.n_batches = n_batches

    def step(self, batch):
        """Return the multiplier for ``batch``, a float in [0.0, 1.0]."""
        # Up to and including the decay start the rate is left untouched.
        if not batch > self.decay_start_batch:
            return 1.0
        elapsed = batch - self.decay_start_batch
        span = self.n_batches - self.decay_start_batch
        factor = 1.0 - elapsed / span
        # Past the end of the schedule the factor is floored at zero.
        return factor if factor > 0 else 0.0