import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
if False:#set to True to use a StyleGAN-type equalized convolution design instead of plain nn.Conv2d
class EC(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
super().__init__()
self.conv=nn.Conv2d(in_channels, out_channels,kernel_size,stride=stride,padding=padding,bias=False)
self.bias=nn.Parameter(torch.zeros(1,out_channels,1,1))
fan_in = self.conv.weight.data.size(1) * self.conv.weight.data[0][0].numel()
self.mult = np.sqrt(2/ fan_in)
def forward(self,x):
return self.mult*(self.conv(x) + self.bias)
else:
EC = nn.Conv2d
#calculate the negative squared Euclidean distance matrix, B x Nq x Nk
#q is query, BxNxC
#k is key, BxCxN
def mdista3(q,k):
x= torch.bmm(2*q, k)
x.sub_( torch.sum(q ** 2, 2).unsqueeze(2))
x.sub_(torch.sum(k ** 2, 1).unsqueeze(1))
return x
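#Illustrative sketch (added for clarity, not part of the original module): mdista3
#returns the negative squared Euclidean distance between every query and key vector,
#so it should agree with -torch.cdist(...)**2 up to floating point error.
#The helper name below is hypothetical.
def _mdista3_example():
    q = torch.randn(2, 5, 8)  # B x N x C queries
    k = torch.randn(2, 8, 7)  # B x C x N keys
    ref = -torch.cdist(q, k.permute(0, 2, 1)) ** 2
    assert torch.allclose(mdista3(q, k), ref, atol=1e-4)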
#calculate full attention
# v is value, BxCxN
# q is query, BxNqxCq
# k is key, BxCxN
# output is BxCxNq
def batt(q ,k ,v ,euclid,onlyA=False,mask=None):
    if euclid : # negative squared Euclidean distance as affinity
s = q.shape[2] ** 0.25
energy = mdista3 (q/s ,k/s)
else:
energy = torch.bmm(q/np.sqrt(q.shape[2]) ,k )#BNC x BCN
attention = F.softmax(energy-energy.detach().max(), dim=-1) # Bx N x N
if mask is not None:
attention=attention*mask
attention=attention/(1e-20+attention.sum(-1,keepdim=True))#renormalize
if onlyA:
return attention
out = torch.bmm(v, attention.permute(0, 2, 1))
return out
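#Note: batt_head below references CMask, which is not defined in this file.
#The definition here is an assumed reconstruction of a standard causal mask
#(query position i may only attend to key positions j <= i), matching the call
#sites where qi is 1 x Nq and kj is 1 x Nk; it broadcasts against the B x Nq x Nk
#attention matrix inside batt.
def CMask(qi, kj):
    return (kj.unsqueeze(1) <= qi.unsqueeze(2)).float()  # 1 x Nq x Nk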
def batt_head(q,k,v,euclid, nhead,causal=False):
if nhead==1:
mask=None
if causal:
qi=torch.arange(q.shape[1],device=q.device).view(1,-1)
kj = torch.arange(k.shape[2],device=k.device).view(1, -1)
mask=CMask(qi,kj)
return batt(q,k,v,euclid=euclid,mask=mask)
batch = k.shape[0]
Dk = k.shape[1] // nhead
Dv = v.shape[1] // nhead
q = torch.cat(q.split(Dk, 2), 0)
k = torch.cat(k.split(Dk, 1), 0)
v = torch.cat(v.split(Dv, 1), 0)
mask=None
if causal:
qi=torch.arange(q.shape[1],device=q.device).view(1,-1)
kj = torch.arange(k.shape[2],device=k.device).view(1, -1)
mask=CMask(qi,kj)
return torch.cat(batt(q,k,v,euclid=euclid,mask=mask).split(batch, 0), 1)
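#Illustrative sketch (added, not in the original file): calling the exact attention
#helper directly on small flattened tensors, following the shape conventions above
#(q is B x N x C, k and v are B x C x N). The helper name is hypothetical.
def _batt_head_example():
    B, C, N = 2, 16, 64
    q = torch.randn(B, N, C)
    k = torch.randn(B, C, N)
    v = torch.randn(B, C, N)
    out = batt_head(q, k, v, euclid=True, nhead=4)  # B x C x N
    assert out.shape == (B, C, N)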
#@param A is an attention matrix, size BxqNxkN, where B is batch size, qN is the count of query spatial positions, kN is the count of key spatial positions
#@param aspect is the aspect ratio of the query tensor, the ratio of w=shape[2] to h=shape[3] of the images -- float values are also supported
#@param kappa tunes the size of the relevant keysets
def GPA_phase1(A, aspect=2, splitM=16,kappa=3):
B=A.shape[0]#batch size
N=A.shape[1]#query dimension
idx = torch.arange(N,device=A.device)#full index structure
h = int(np.sqrt(N // aspect))
idx = idx.view(1,1,int(aspect*h),h)#spatially reshaped full index structure
s= min(h,int(aspect*h))//splitM#cell size -- total of K=splitM*splitM*aspect cells ; s = sqrt(N/K)
grid = F.unfold(idx.float(),kernel_size=(s,s),stride=(s,s)).long()#1xs*s x K , where K is count of blocks, s*s*K=N
grid = grid.permute(0,2,1) ## 1 x K x N/K , values are indices of tensor
K = grid.shape[1]#count of partition cells
allk = A.permute(0,2,1).view(B,-1,int(aspect*h),h)#keys as channels, query spatial
allk = F.unfold(allk,kernel_size=(s,s),stride=(s,s))##B x (Nk*s*s) x K
allk = allk.view(B,A.shape[2],s*s,K)##so all keys for the query grids
topk = allk.mean(2).permute(0,2,1).argsort(dim=2, descending=True)[:, :, :int(kappa )]
out={"grid":grid,"topk":topk,"s":s,"size":(int(aspect*h),h),"aspect":aspect,"hsize":None,"Nk":A.shape[2]}
return out
# calculate attention by using partitions of keys and queries
# @param partition is the dictionary returned by GPA_phase1
# @param queries is a 4d tensor BxCxHxW; keys and values are size BxCxN
def GPA_phase2(queries, keys, values, partition, euclid,causal=False):
B = queries.shape[0] # batchsize
    N = partition["grid"].shape[1] * partition["grid"].shape[2] # number of spatial positions in the downsampled query tensor
K = partition["grid"].shape[1] # count partitions
Cq=queries.shape[1]
try:
Nq=queries.shape[2]*queries.shape[3]
S = int(np.sqrt(Nq / N)) # scale change, in height or width of pixels
s = partition["s"] * S # local grid cell size in large spatial tensor
Q = F.unfold(queries, kernel_size=(s, s), stride=(s, s))#b c*s*s K
except Exception as e:
print (e)
print (N,Nq,partition)
print (queries.shape,partition["size"][0] * S, partition["size"][1] * S,\
partition["size"][0] * S* partition["size"][1] * S,"ss",S,s)
raise Exception
Q=Q.view(B,Cq,s*s,K).permute(0,1,3,2)#B C K s*s
topk = up4d(partition["topk"], S,aspect=partition["aspect"],N=partition["Nk"],hsize=partition["hsize"])
#topk is size BxKx upsampled relevant keyset elements
# now move the K partitions as batches
#print ("queries",queries.shape,"Q",Q.shape,"stats",Cq,K,Nq//K)
Q = Q.view(B,Cq,K,Nq//K).permute(0,2,1,3).contiguous().view(B*K,Q.shape[1],Nq//K)
Nk = topk.shape[1] * topk.shape[2]##the keys chosen for all query partition cells, flattened partitions
topv = topk.view(B, 1, Nk).expand(B, values.shape[1], Nk)#copy same for all channels
topk = topk.view(B, 1, Nk).expand(B,queries.shape[1],Nk)
#print ("gather k sanity",keys.shape,values.shape,topk.max(),topk.min())
KEY = torch.gather(keys,dim=2,index=topk)#keys for each partition
V = torch.gather(values, dim=2, index=topv)
KEY = KEY.view(B, KEY.shape[1], K, Nk // K).permute(0, 2, 1, 3).contiguous().view(B * K, KEY.shape[1], Nk // K)
V = V.view(B, V.shape[1], K, Nk // K).permute(0, 2, 1, 3).contiguous().view(B * K, V.shape[1], Nk // K)
batch_att = batt(Q.permute(0, 2, 1), KEY, V, euclid=euclid)
att = batch_att.view(B,K,values.shape[1],-1).permute(0,2,3,1).view(B,-1,K)##Bx Cx N//Kx K -> B x C*N/K xK
##batch_att is in the partition indices order (unfolded)- fold to give back to square image
return F.fold(att, output_size=(partition["size"][0] * S, partition["size"][1] * S), kernel_size=(s , s ), stride=(s , s ))
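#Illustrative end-to-end sketch (added for clarity, not in the original file): it
#mirrors what GPA_module._doAtt does. A coarse attention map A is computed on
#S-times downsampled queries/keys, GPA_phase1 picks the kappa most relevant key
#cells per query cell, and GPA_phase2 evaluates attention at the original
#resolution using only those keys. All sizes below are assumptions for the demo.
def _gpa_pipeline_example():
    B, C, H, W = 1, 16, 64, 64  # square tensors, so aspect = 1
    S = 2  # downsampling factor between the coarse and fine level
    x = torch.randn(B, C, H, W)
    kv = x.view(B, C, -1)  # keys/values at full resolution, B x C x N
    qd = F.avg_pool2d(x, S).view(B, C, -1)  # coarse level, B x C x (N / S**2)
    A = batt(qd.permute(0, 2, 1), qd, None, euclid=True, onlyA=True)
    part = GPA_phase1(A, aspect=1, splitM=16, kappa=2)
    part["hsize"] = W // S  # width of the downsampled key tensor
    out = GPA_phase2(x, kv, kv, part, euclid=True)  # B x C x H x W
    assert out.shape == (B, C, H, W)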
# @param idx are indices of relevant keys
# @param S is the ratio between the small and large tensor
# @param N is the total number of spatial positions in the small tensor
# @param aspect is the aspect ratio of width and height of the image tensors, e.g. aspect=2 means the tensors have size h*aspect x h = N
# @param hsize alternatively gives the horizontal pixel size directly, which allows different aspect ratios for query and key
# returns the indices concatenated in the last dimension -- the set of upsampled spatial indices
def up4d(idx, S, N=None, aspect=1,hsize=None):
if hsize is None:
h = int(np.sqrt(N // aspect))
else:
h=hsize
xw = S * (idx // h)
xh = S * (idx % h) # get 2 coords from idx; also multiply by S since we have larger grid
buf = []
for dw in range(S):
for dh in range(S):
buf.append((xw + dw) * S * h + xh + dh)
idx2 = torch.cat(buf,dim=-1)
return idx2 # S**2 more entries than idx
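#Small worked example (added for clarity): with S=2 and a 4x4 coarse key grid
#(h=4, N=16), coarse index 5 sits at row 1, column 1, so it expands to the 2x2
#block of fine-grid positions (2,2),(2,3),(3,2),(3,3) on the 8x8 grid, i.e. flat
#indices 18, 19, 26, 27. The helper name is hypothetical.
def _up4d_example():
    idx = torch.tensor([[[5]]])  # B x K x kappa = 1 x 1 x 1
    fine = up4d(idx, S=2, N=16, aspect=1)  # 1 x 1 x 4
    assert sorted(fine.view(-1).tolist()) == [18, 19, 26, 27]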
#@param c is channels of value and key
#@param c2 is channels of query
#@param m is the desired number of output channels, can be <= c2
#@param euclid controls whether we use dot product or Euclidean vector distance for the affinity matrix
#@param nhead, nheadlow are the head counts for the attention layers
#@param outmod controls how we residually fuse the attention output with the query features; 1 concatenates channels and uses a 1x1 conv, 0 adds the residual directly
#@param SHIERARCHY is the spatial size threshold above which we use approximate rather than full attention
#@param kappa is the number of relevant keys
#@param splitM is the square root of the partitioning granularity, splitM x splitM cells
class GPA_module(nn.Module):
    def __init__(self, c,c2,m,euclid=True,nhead=1,nheadlow=1,outmod=1,SHIERARCHY = 64,kappa=2,splitM=32,causal=False):
super().__init__()
C=EC#styleGAN 2 inspired
n=max(c,c2)
self.n=n
self.m=m
def net(c,m):
return C(in_channels=c, out_channels=m,kernel_size=1)
self.VAL = C(in_channels=c, out_channels=m,kernel_size=1)##value effectively
self.KEY = net(c, n)
self.Q = net(c2, n)#C(in_channels=c2, out_channels=n, kernel_size=1)#query
self.outmod=outmod
if outmod==0:
self.gamma = nn.Parameter(torch.zeros(1) + 0.1)
elif outmod==1:
self.concatproj=C(c2+ m,m,kernel_size=1)#,3,padding=1)
else:
raise Exception('wrong param')
self.euclid=euclid
self.nhead=nhead
self.nheadlow=nheadlow
self.SHIERARCHY = SHIERARCHY
self.kappa = kappa
        self.splitM=splitM
        self.causal=causal#used by the full-attention path in forward()
#wrapper for multihead GPA form
def _doAtt_head(self,Q,KEY,VAL,fac,aspect,psize,nhead):
if nhead==1:
return self._doAtt(Q,KEY,VAL,fac,aspect,psize)
batch = KEY.shape[0]
Dk = KEY.shape[1] // nhead
Dv = VAL.shape[1] // nhead
Q = torch.cat(Q.split(Dk, 1), 0)
KEY = torch.cat(KEY.split(Dk, 1), 0)
VAL = torch.cat(VAL.split(Dv, 1), 0)
return torch.cat(self._doAtt(Q,KEY,VAL,fac,aspect,psize,chSplit=nhead).split(batch, 0), 1)
    #both phases of single-head GPA
    #@param Q,KEY,VAL are 4d tensors, BxCxHxW
    #@param d is the factor by which to downsample for phase 1
    #@param aspect is the aspect ratio of Q -- use if rectangular rather than square
    #@param psize is the W dimension of the downsampled KEY tensor -- use if rectangular rather than square
    #@param chSplit is a technical detail to adjust channel counts for multi-head attention
def _doAtt(self,Q,KEY,VAL,d,aspect,psize,chSplit=1):
m_batchsize=Q.shape[0]
with torch.no_grad(): ##no grad w.r.t. indexing variable
Qdown=F.avg_pool2d(Q, d).view(m_batchsize, self.n//chSplit, -1).permute(0, 2, 1)
Kdown = F.avg_pool2d(KEY, d).view(m_batchsize, self.n//chSplit, -1)
A = batt(Qdown,Kdown, None, onlyA=True, euclid=self.euclid)#full attention at lower level
partition = GPA_phase1(A, aspect=aspect, kappa=self.kappa,splitM=self.splitM)
partition["hsize"] = psize # hack to allow different aspect ratio key tensors
VAL = VAL.view(m_batchsize, self.m//chSplit, -1) ##batch x m x hw
KEY = KEY.view(m_batchsize, self.n//chSplit, -1)
att = GPA_phase2(Q, KEY, VAL, partition, euclid=self.euclid)
return att
    #report the quality of the GPA approximation: the fraction of the total attention affinity captured by the relevant keys
def err_stat(self,xKeyValue, xQuery):
#print ("errstat call",xQuery.shape)
fac = xQuery.shape[2] // self.SHIERARCHY
if fac <=1:#so full attention used here
return 1
aspect = xQuery.shape[2] / float(xQuery.shape[3])
nhead=self.nhead
psize = xKeyValue.shape[3] // fac
Q = self.Q(xQuery)
KEY = self.KEY(xKeyValue)
m_batchsize = Q.shape[0]
if nhead>1:
Dk = KEY.shape[1] // nhead
Q = torch.cat(Q.split(Dk, 1), 0)
KEY = torch.cat(KEY.split(Dk, 1), 0)
m_batchsize*=nhead#batch size increased in this case
Qdown = F.avg_pool2d(Q, fac).view(m_batchsize, self.n // nhead, -1).permute(0, 2, 1)
Kdown = F.avg_pool2d(KEY, fac).view(m_batchsize, self.n // nhead, -1)
A = batt(Qdown, Kdown, None, onlyA=True, euclid=self.euclid)
partition = GPA_phase1(A, aspect=aspect, kappa=self.kappa,splitM=self.splitM)
partition["hsize"] = psize
_,_,out3=analyzeGPA(queries=Q.view(m_batchsize,self.n // nhead,-1), keys=KEY.view(m_batchsize,self.n // nhead,-1), A=partition, euclid=self.euclid)
x=np.array(out3)
print ("GPA_stat px %s affinity of relevant keys %0.4f"%(str(xQuery.shape),x.sum(1).mean()))
return x.sum(1).mean()
#@param xQuery is query tensor
#@param xKeyValue is conditioning information
def forward(self, xKeyValue, xQuery):
fac = xQuery.shape[2] // self.SHIERARCHY # make dependent on query dimensions
aspect = xQuery.shape[2] / float(xQuery.shape[3])#query image aspect ratio
m_batchsize, _, width, height = xQuery.size()
        Q = self.Q(xQuery) #query 1x1 conv
KEY = self.KEY(xKeyValue) #key
VAL = self.VAL(xKeyValue) # value
if fac <=1:##if small -- directly use batt_head
VAL = VAL.view(m_batchsize, self.m, -1) #flatten #batch x m x hw
Q = Q.view(m_batchsize, self.n, -1) ##batch x n x hw
KEY = KEY.view(m_batchsize, self.n, -1)
att = batt_head(Q.permute(0,2,1),KEY,VAL,nhead=self.nheadlow,euclid=self.euclid,causal=self.causal).view(m_batchsize,-1, width, height)
else:
att = self._doAtt_head(Q, KEY, VAL, fac, aspect,psize=xKeyValue.shape[3] // fac, nhead=self.nhead)
if self.outmod==1:
return xQuery[:, :self.m] + self.concatproj(torch.cat([xQuery, att], 1))
return xQuery[:, :self.m] + self.gamma * att#default is 0
#print some statistics for the GPA approximation
# input tensors of size BxCxN - queries and keys flattened to one spatial dimension
#A is dictionary with the GPA_phase1 information
def analyzeGPA(queries, keys, A, euclid):
# q is BNC, k is BCN
def att(q, k, euclid):
return batt(q, k, None, euclid=euclid, onlyA=True)
out1 = []
out2 = []
out3 = []
B = queries.shape[0] # batchsize
Nq = queries.shape[2] # grid.shape[1]*grid.shape[2]
    N = A["grid"].shape[1] * A["grid"].shape[2] # number of spatial positions in the downsampled query tensor
S = int(np.sqrt(Nq / N)) # scale change, in height or width of pixels
    K = A["grid"].shape[1] # count of partitions
s = A["s"] * S
Q = F.unfold(queries.view(B, queries.shape[1], A["size"][0] * S, A["size"][1] * S), kernel_size=(s, s),
stride=(s, s)) # b c*s*s K
Q = Q.view(B, queries.shape[1], s * s, K).permute(0, 1, 3, 2) # B C K s*s
topk = up4d(A["topk"], S, N, aspect=A["aspect"])
    ##sample a random batch, partition cell and spatial index -- any element from Q
    ##calculate the full attention
    ##calculate also the partial attention over the relevant keys only
for b in range(B):
for k in range(K):
            qe = Q[b, :, k, np.random.randint(s * s)].view(1, 1, -1) # random index inside s*s cell #or set to 0 for the first index
ke_all = keys[b:b + 1, :, :]
idx = topk[b, k, :] ##all indices of partition keys
ke = keys[b:b + 1, :, idx]
a_all = att(qe, ke_all, euclid).squeeze() # size 1x1xNk -- affinity of query with all keys
a_top = att(qe, ke, euclid).squeeze() # size 1x1x (N/K) -- affinity of query just with relevant subset keys
sub = a_all[idx]
if b==0 and k==0 and False:
print ("Analyzing GPA: all keys",a_all.shape,a_top.shape,a_top.sum().item(),"qk down level",qe.shape,ke.shape,ke_all.shape,"up level",keys.shape)
out1.append(a_all.cpu().numpy())
out2.append(a_top.cpu().numpy())
out3.append(sub.cpu().numpy())
return out1, out2, out3
if __name__ == "__main__":
    c=128 # test with this many channels for query, key, value
g2 = GPA_module(c, c, c, kappa=2).cuda()
g4 = GPA_module(c, c, c, kappa=4).cuda()
#test on a random noise tensor
embed = nn.Sequential(EC(3,c,1),nn.ReLU()).cuda()#random filters
x=torch.zeros((1,3,256,256)).cuda().uniform_(-1,1)
x=embed(x)
##print statistics -- how well do the relevant keys approximate the full attention
##the larger kappa gets, the more accurate the approximation of GPA is
with torch.no_grad():
g2.err_stat(x,x)
g4.err_stat(x, x)
#test on a natural image, random image from the DeepFashion dataset
from PIL import Image
im = Image.open('DFimage.png')
x=np.array(im)/255.0*2-1
x=torch.FloatTensor(x).cuda().unsqueeze(0).permute(0,3,1,2)
x = embed(x)
    ##test GPA approximation -- since natural images better fit assumptions A1,A2,A3 from the paper, the approximation is typically better
with torch.no_grad():
g2.err_stat(x,x)
g4.err_stat(x, x)
##typical usage of the GPA module
out = g2(x, x)
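    #added usage note: the residual output keeps the query's spatial size, with m channels
    print("GPA output shape", out.shape)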