made.py
"""
Bugs: I have test (order-agnostic, connect-agnostic) combination in many ways but,
could not improve baseline model performance with any agnostic techniques.
"""
import random

import torch
from torch import nn
from torch.nn import functional as F
from torch.distributions import Bernoulli


class MaskLinear(nn.Linear):
    """Linear layer whose weight is element-wise masked at forward time."""

    def __init__(self, *args, mask=None, **kwargs):
        super().__init__(*args, **kwargs)
        # mask: (n_mask, out_features, in_features); a buffer moves with the
        # module across devices but is not trained
        self.register_buffer('mask', mask)
        self.mask_index = 0

    def set_mask_index(self, idx):
        # named to avoid shadowing nn.Module.apply
        self.mask_index = idx

    def forward(self, x):
        # mask the weight at forward time instead of overwriting
        # self.weight.data in place, so masked-out weights are not destroyed
        return F.linear(x, self.weight * self.mask[self.mask_index], self.bias)
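
# Usage sketch (illustrative values, not from the original file): with a
# single (1, 2, 3) mask, output unit 0 sees only input 0, and output unit 1
# sees inputs 0 and 1:
#   m = torch.tensor([[[1., 0., 0.], [1., 1., 0.]]])
#   layer = MaskLinear(3, 2, mask=m)
#   y = layer(torch.rand(4, 3))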


class MADE(nn.Module):
    SEED = 1000

    def __init__(self, d=784, h=8000, l=1, n_mask=1):
        super().__init__()
        assert n_mask == 1, f"The number of masks must be 1, not {n_mask}."
        self.d = d              # input (and output) dimension
        self.h = h              # hidden width
        self.l = l + 1          # number of linear layers
        self.n_mask = n_mask
        self.dims = [d] + [h] * l + [d]
        self.masks, self.orders = self.generate_mask()
        self.layers = nn.ModuleList([
            MaskLinear(self.dims[i], self.dims[i + 1], mask=self.masks[i])
            for i in range(len(self.dims) - 1)
        ])
        self.act = nn.ReLU(inplace=True)

    def generate_mask(self):
        orders = list()
        masks = [torch.zeros(self.n_mask, self.dims[i + 1], self.dims[i]) for i in range(self.l)]
        for mask_idx in range(self.n_mask):
            # 1. fix the seed for reproducibility
            g = torch.Generator().manual_seed(self.SEED + mask_idx)
            # 2. generate random connectivity: each unit gets a degree, and a
            #    unit may only see units whose degree is not larger than its own;
            #    hidden degrees lie in [low, d-2] so every hidden unit can feed
            #    at least one output
            last_order = first_order = torch.randperm(self.d, generator=g)
            connections = [first_order]
            for i, dim in enumerate(self.dims[1:-1]):
                low = int(connections[i].min()) if i > 0 else 0
                connections.append(torch.randint(low, self.d - 1, (dim,), generator=g))
            connections.append(last_order)
            # 3. generate the masks: <= into hidden layers, strict < at the output
            for layer_idx in range(self.l - 1):
                masks[layer_idx][mask_idx] = connections[layer_idx][None, :] <= connections[layer_idx + 1][:, None]
            masks[-1][mask_idx] = connections[-2][None, :] < connections[-1][:, None]
            # 4. store the sampling order: orders[m][i] is the pixel with degree i
            orders.append(connections[0].argsort())
        return masks, orders

    def choose_mask(self, idx):
        for layer in self.layers:
            layer.set_mask_index(idx)

    def forward_once(self, x):
        shape = x.shape
        x = x.reshape(shape[0], -1)
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i != self.l - 1:
                x = self.act(x)
        x = x.reshape(shape)
        p = torch.sigmoid(x)
        sample = Bernoulli(probs=p).sample()
        return x, sample

    def forward(self, x):
        # note: `if self.train():` would always be truthy (train() returns the
        # module) and would flip the module into training mode as a side effect
        if self.training:
            # connectivity-agnostic training: pick a random mask per batch
            mask_idx = random.randint(0, self.n_mask - 1)
            self.choose_mask(mask_idx)
            return self.forward_once(x)
        else:
            # at eval time, average logits and samples over all masks
            logits, samples = list(), list()
            for mask_idx in range(self.n_mask):
                self.choose_mask(mask_idx)
                logit, sample = self.forward_once(x)
                logits.append(logit)
                samples.append(sample)
            logit = sum(logits) / len(logits)
            sample = sum(samples) / len(samples)
            return logit, sample

    @torch.no_grad()
    def sample(self, shape, device, *args, **kwargs):
        # autoregressive sampling in degree order; assumes C == 1 so d == H * W
        B, C, H, W = shape
        mask_idx = random.randint(0, self.n_mask - 1)
        x = torch.full(shape, -1.0, device=device)
        self.choose_mask(mask_idx)
        order = self.orders[mask_idx]
        for pos in range(H * W):
            _, sample = self.forward_once(x)
            j = int(order[pos])
            x[:, :, j // W, j % W] = sample.reshape(B, -1)[:, j:j + 1]
        return x
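

# A minimal sanity check, not part of the original training code: MADE's
# defining property is that output i may depend only on inputs that come
# strictly earlier in the ordering. The gradient test below is one standard
# way to verify the masks; `check_autoregressive` is an illustrative name.
def check_autoregressive(model, d=784):
    model.eval()
    order = model.orders[0]            # order[i] = pixel whose degree is i
    x = torch.rand(1, d, requires_grad=True)
    logit, _ = model.forward_once(x)
    pos = d // 2                       # probe a position mid-ordering
    pixel = int(order[pos])
    logit.reshape(-1)[pixel].backward()
    grad = x.grad.reshape(-1)
    # pixels at or after `pos` in the ordering must receive zero gradient
    later = order[pos:]
    assert torch.all(grad[later] == 0), "output depends on later inputs"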


if __name__ == '__main__':
    x = torch.rand(3, 1, 28, 28)
    f = MADE()
    logit, sample = f(x)
    generated = f.sample((3, 1, 28, 28), 'cpu')
    print(logit.shape)
    print(generated.shape)
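    # sanity check (illustrative): verify the autoregressive property of the
    # default model with the gradient test defined above
    check_autoregressive(f)
    print('autoregressive check passed')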