-
Notifications
You must be signed in to change notification settings - Fork 37
/
Copy pathIterativeSmooth.py
89 lines (74 loc) · 3.94 KB
/
IterativeSmooth.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
import torch.nn as nn
import numpy as np
import os
import math
def gauss(t, r=0, window_size=3):
    """
    Gaussian weight of the point at index r relative to the current point t.

    @param t is the index of the current point
    @param r is the index of the point inside the window (default 0)
    @param window_size is the largest distance |r - t| that still receives a
        non-zero weight; it also sets the width of the Gaussian
    @return exp(-9 * (r - t)**2 / window_size**2) when |r - t| <= window_size,
        otherwise 0
    """
    distance = r - t
    if np.abs(distance) > window_size:
        # Points outside the window do not contribute.
        return 0
    return np.exp(-9 * distance ** 2 / window_size ** 2)
def generateSmooth(originPath, kernel=None, repeat=20, lambda_t=100):
    """
    Iteratively smooth a path along the time axis.

    Every iteration applies, simultaneously for all frames t, the update

        P_t <- (C_t + lambda_t * sum_o w[o, t] * P_{t+o})
               / (1 + lambda_t * sum_o |w[o, t]|)

    where C is ``originPath``, P is the previous iterate and o ranges over the
    temporal offsets (-3, -2, -1, +1, +2, +3). Neighbours t + o that fall
    outside [0, T) are dropped, so frames near the ends use fewer neighbours
    (and sequences shorter than 7 frames are handled too).

    @param originPath tensor of shape (B, C, T, H, W); the original code
        comment documents C == 1, but every channel is smoothed with the same
        weights, so any C works
    @param kernel optional weights of shape (B, 6, T, H, W) (batch/spatial
        dims of size 1 broadcast); channel k holds the weight for offset
        (-3, -2, -1, 1, 2, 3)[k]. Defaults to ``gauss(offset)`` shared across
        batch, time and space.
    @param repeat number of update iterations
    @param lambda_t weight of the neighbour (smoothness) term relative to the
        data term C (default 100, the previously hard-coded value)
    @return smoothed tensor with the shape and dtype of ``originPath``
        (``originPath`` itself when ``repeat`` is 0)
    """
    # Offset handled by each kernel channel; offset 0 is the data term and
    # therefore has no kernel channel.
    offsets = (-3, -2, -1, 1, 2, 3)
    num_frames = originPath.shape[2]

    if kernel is None:
        weights = torch.Tensor([gauss(i)
                                for i in range(-3, 4)]).to(originPath.device)
        # Drop the centre weight so that channel k matches offsets[k].
        weights = torch.cat([weights[:3], weights[4:]])
        weights = weights.view(1, len(offsets), 1, 1, 1)
        # Repeat over batch, time and space but NOT over channels: repeating
        # by originPath.shape[1] would create 6 * C kernel channels for C > 1
        # and corrupt the weight sums below.
        kernel = weights.repeat(originPath.shape[0], 1, *originPath.shape[2:])
    abskernel = torch.abs(kernel)

    smooth = originPath
    for _ in range(repeat):
        weighted = torch.zeros_like(smooth)               # sum_o w * P_{t+o}
        weight_sum = torch.zeros_like(abskernel[:, :1])   # sum_o |w|
        for channel, offset in enumerate(offsets):
            # Output frames t whose neighbour t + offset lies inside [0, T).
            lo = max(0, -offset)
            hi = num_frames - max(0, offset)
            if lo >= hi:
                continue  # sequence too short to have this neighbour at all
            w = kernel[:, channel:channel + 1, lo:hi]
            weighted[:, :, lo:hi] += w * smooth[:, :, lo + offset:hi + offset]
            weight_sum[:, :, lo:hi] += abskernel[:, channel:channel + 1, lo:hi]
        # Cast back so the result keeps originPath's dtype even when the
        # kernel has a different floating-point precision.
        smooth = ((originPath + lambda_t * weighted)
                  / (1 + lambda_t * weight_sum)).to(originPath.dtype)
    return smooth