-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
107 lines (73 loc) · 3.06 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import time
import torch
import random
import logging
import matplotlib
import numpy as np
from matplotlib import pyplot as plt
from skimage.metrics import structural_similarity
matplotlib.use('agg')
def set_seed(seed):
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
def visualize(inputs, targets, outputs, epoch, idx, cache_dir):
_, axarray = plt.subplots(3, targets.shape[1], figsize=(targets.shape[1] * 5, 10))
for t in range(targets.shape[1]):
axarray[0][t].imshow(inputs[0, t, 0].detach().cpu().numpy(), cmap='gray')
axarray[1][t].imshow(targets[0, t, 0].detach().cpu().numpy(), cmap='gray')
axarray[2][t].imshow(outputs[0, t, 0].detach().cpu().numpy(), cmap='gray')
plt.savefig(os.path.join(cache_dir, '{:03d}-{:03d}.png'.format(epoch, idx)))
plt.close()
def plot_loss(loss_records, loss_type, epoch, plot_dir, step):
plt.plot(range((epoch + 1) // step), loss_records, label=loss_type)
plt.legend()
plt.savefig(os.path.join(plot_dir, '{}_loss_records.png'.format(loss_type)))
plt.close()
def MAE(pred, true):
return np.mean(np.abs(pred - true), axis=(0, 1)).sum()
def MSE(pred, true):
return np.mean((pred - true) ** 2, axis=(0, 1)).sum()
# cite the 'PSNR' code from E3D-LSTM, Thanks!
# https://github.com/google/e3d_lstm/blob/master/src/trainer.py line 39-40
def PSNR(pred, true):
mse = np.mean((np.uint8(pred * 255) - np.uint8(true * 255)) ** 2)
return 20 * np.log10(255) - 10 * np.log10(mse)
def compute_metrics(predictions, targets):
targets = targets.permute(0, 1, 3, 4, 2).detach().cpu().numpy()
predictions = predictions.permute(0, 1, 3, 4, 2).detach().cpu().numpy()
batch_size = predictions.shape[0]
Seq_len = predictions.shape[1]
ssim = 0
for batch in range(batch_size):
for frame in range(Seq_len):
ssim += structural_similarity(targets[batch, frame].squeeze(),
predictions[batch, frame].squeeze())
ssim /= (batch_size * Seq_len)
mse = MSE(predictions, targets)
return mse, ssim
def check_dir(path):
if not os.path.exists(path):
os.makedirs(path)
def make_dir(args):
cache_dir = os.path.join(args.res_dir, 'cache')
check_dir(cache_dir)
model_dir = os.path.join(args.res_dir, 'model')
check_dir(model_dir)
log_dir = os.path.join(args.res_dir, 'log')
check_dir(log_dir)
return cache_dir, model_dir, log_dir
def init_logger(log_dir):
logging.basicConfig(level=logging.INFO,
format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
datefmt='%m-%d %H:%M',
filename=os.path.join(log_dir, time.strftime("%Y_%m_%d") + '.log'),
filemode='w')
console = logging.StreamHandler()
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
console.setFormatter(formatter)
logging.getLogger('').addHandler(console)
return logging