-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSOM.py
85 lines (68 loc) · 2.91 KB
/
SOM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import torch.nn as nn
from torchvision.utils import save_image
class SelfOrganizingMap(nn.Module):
    """Kohonen Self-Organizing Map (SOM) over a 2-D grid of units.

    Each grid unit owns a prototype (weight) vector of dimension ``x_size``.
    Training pulls the best-matching unit (BMU) and its grid neighbors toward
    each input sample, with a learning rate and neighborhood radius that decay
    linearly over training.
    """

    def __init__(self, x_size, out_size=(10, 10), lr=0.3, sigma=None):
        """
        :param x_size: dimensionality of the input vectors
        :param out_size: (rows, cols) of the output grid
        :param lr: initial learning rate, decayed linearly during training
        :param sigma: initial neighborhood radius; defaults to half the
            larger grid dimension when None
        """
        super(SelfOrganizingMap, self).__init__()
        self.x_size = x_size
        self.out_size = out_size
        self.lr = lr
        if sigma is None:
            self.sigma = max(out_size) / 2
        else:
            self.sigma = float(sigma)
        # weight[:, k] is the prototype vector of grid unit k (row-major order).
        # requires_grad=False: the SOM is trained by the explicit update rule in
        # self_organizing(), not by autograd.
        self.weight = nn.Parameter(torch.randn(x_size, out_size[0] * out_size[1]), requires_grad=False)
        # locations[k] is the (row, col) grid coordinate of unit k.
        self.locations = nn.Parameter(torch.Tensor(list(self.get_map_index())), requires_grad=False)
        # Kept for backward compatibility with callers that may use it;
        # forward() now computes distances explicitly (see note there).
        self.pdist_fn = nn.PairwiseDistance(p=2)

    def get_map_index(self):
        '''Yield the (row, col) coordinate of every grid unit in row-major order.'''
        for x in range(self.out_size[0]):
            for y in range(self.out_size[1]):
                yield (x, y)

    def _neighborhood_fn(self, x, current_sigma):
        '''Gaussian-like neighborhood weighting: e^(-(x / sigma^2)).

        :param x: squared grid distances to the BMU
        :param current_sigma: current (decayed) neighborhood radius
        '''
        x = x.div(current_sigma ** 2).neg().exp()
        return x

    def forward(self, x):
        '''
        Find the location of the best matching unit for each sample.

        :param x: data, shape (batch, x_size)
        :return: (bmu_locations of shape (batch, 1, 2),
                  mean minimum distance as a Python float)
        '''
        batch_size = x.size()[0]
        x = x.view(batch_size, -1, 1)
        batch_weight = self.weight.expand(batch_size, -1, -1)
        # BUG FIX: this used nn.PairwiseDistance, which in older PyTorch
        # reduced over dim 1 but in current PyTorch reduces over the LAST
        # dim.  With shapes (B, D, 1) vs (B, D, N) that collapses the unit
        # axis and yields (B, D) instead of (B, N), so the BMU lookup below
        # indexes by feature instead of by unit.  Compute the Euclidean
        # distance over the feature dim explicitly; result is (B, N) on
        # every PyTorch version.
        dists = torch.norm(batch_weight - x, p=2, dim=1)
        # Find best matching unit per sample
        losses, bmu_indexes = dists.min(dim=1, keepdim=True)
        bmu_locations = self.locations[bmu_indexes]
        return bmu_locations, losses.sum().div_(batch_size).item()

    def self_organizing(self, x, current_iter, max_iter):
        '''
        Apply one batch update of the Self-Organizing Map training rule.

        :param x: training data, shape (batch, x_size)
        :param current_iter: current epoch of total epoch
        :param max_iter: total epoch
        :return: loss (mean minimum distance, from forward())
        '''
        batch_size = x.size()[0]
        # Linearly decay learning rate and neighborhood radius toward zero.
        # NOTE(review): at current_iter == max_iter sigma hits 0 and the
        # neighborhood function divides by zero; the usual training loop
        # (range(max_iter)) never reaches that point.
        iter_correction = 1.0 - current_iter / max_iter
        lr = self.lr * iter_correction
        sigma = self.sigma * iter_correction
        # Find best matching unit for each sample
        bmu_locations, loss = self.forward(x)
        # Squared grid distance from every unit to each sample's BMU: (B, N)
        distance_squares = self.locations.float() - bmu_locations.float()
        distance_squares.pow_(2)
        distance_squares = torch.sum(distance_squares, dim=2)
        # Per-unit step size, scaled by neighborhood closeness: (B, 1, N)
        lr_locations = self._neighborhood_fn(distance_squares, sigma)
        lr_locations.mul_(lr).unsqueeze_(1)
        # Move every unit toward the inputs, averaged over the batch.
        delta = lr_locations * (x.unsqueeze(2) - self.weight)
        delta = delta.sum(dim=0)
        delta.div_(batch_size)
        self.weight.data.add_(delta)
        return loss

    def save_result(self, dir, im_size=(0, 0, 0)):
        '''
        Save the learned prototypes as one image grid.

        :param dir: output file path, passed straight to torchvision's save_image
            (name shadows the builtin but is kept for interface compatibility)
        :param im_size: (channels, height, width) to reshape each unit's
            weight vector into
        '''
        images = self.weight.view(im_size[0], im_size[1], im_size[2], self.out_size[0] * self.out_size[1])
        # Reorder to (n_units, C, H, W) as save_image expects.
        images = images.permute(3, 0, 1, 2)
        save_image(images, dir, normalize=True, padding=1, nrow=self.out_size[0])