-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathefficientdet.py
292 lines (238 loc) · 12 KB
/
efficientdet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
import torch.nn as nn
import torch
import math
from efficientnet_pytorch import EfficientNet as EffNet
from utils import BBoxTransform, ClipBoxes, Anchors
from loss_function import FocalLoss
from torchvision.ops.boxes import nms as nms_torch
def nms(dets, thresh):
    """Apply torchvision non-maximum suppression to packed detections.

    Args:
        dets: 2-D tensor; columns 0-3 are box coordinates (torchvision expects
            (x1, y1, x2, y2)) and column 4 is the detection score.
        thresh: IoU threshold above which the lower-scoring box is discarded.

    Returns:
        Indices of the kept rows of ``dets``, as returned by torchvision's nms.
    """
    boxes = dets[:, :4]
    box_scores = dets[:, 4]
    return nms_torch(boxes, box_scores, thresh)
class ConvBlock(nn.Module):
    """Depthwise-separable 3x3 conv -> BatchNorm -> ReLU block.

    Channel count and spatial size are preserved (3x3 kernel, stride 1,
    padding 1, followed by a 1x1 conv).

    Args:
        num_channels: number of input and output channels.
        momentum: PyTorch BatchNorm momentum, i.e. the weight given to the NEW
            batch statistic (this is ``1 - decay`` in TensorFlow terms).
        eps: BatchNorm epsilon.
    """

    def __init__(self, num_channels, momentum=3e-4, eps=4e-5):
        super(ConvBlock, self).__init__()
        self.conv = nn.Sequential(
            # Depthwise 3x3: one filter per channel (groups == channels).
            nn.Conv2d(num_channels, num_channels, kernel_size=3, stride=1, padding=1, groups=num_channels),
            # Pointwise 1x1: mixes information across channels.
            nn.Conv2d(num_channels, num_channels, kernel_size=1, stride=1, padding=0),
            # The default 3e-4 corresponds to TF's decay of 0.9997. Passing 0.9997
            # directly (as before) made the running statistics follow almost only
            # the most recent batch, because PyTorch's convention is inverted.
            nn.BatchNorm2d(num_features=num_channels, momentum=momentum, eps=eps),
            nn.ReLU(),
        )

    def forward(self, input):
        """Apply the block to a (B, C, H, W) tensor; the output has the same shape."""
        return self.conv(input)
class BiFPN(nn.Module):
    """One bidirectional feature-pyramid layer using fast normalized fusion.

    Consumes five equal-width feature maps (P3..P7) and returns five refined
    maps. Every fusion node combines its inputs with learnable weights that are
    clamped non-negative by a ReLU and divided by their sum plus ``epsilon``.

    NOTE(review): the fixed x2 upsampling / 2x2 max-pooling assumes each level
    is exactly half the spatial size of the level below it — confirm for
    non-standard input resolutions.
    """

    def __init__(self, num_channels, epsilon=1e-4):
        super(BiFPN, self).__init__()
        self.epsilon = epsilon

        # Conv blocks, registered in this order: conv6_up, conv5_up, conv4_up,
        # conv3_up, conv4_down, conv5_down, conv6_down, conv7_down.
        for level in (6, 5, 4, 3):
            setattr(self, 'conv%d_up' % level, ConvBlock(num_channels))
        for level in (4, 5, 6, 7):
            setattr(self, 'conv%d_down' % level, ConvBlock(num_channels))

        # Resampling: nearest x2 upsampling on the top-down path
        # (p6..p3_upsample), 2x2 max-pooling on the bottom-up path (p4..p7_downsample).
        for level in (6, 5, 4, 3):
            setattr(self, 'p%d_upsample' % level, nn.Upsample(scale_factor=2, mode='nearest'))
        for level in (4, 5, 6, 7):
            setattr(self, 'p%d_downsample' % level, nn.MaxPool2d(kernel_size=2))

        # Fusion weights (one entry per fused input) and the ReLU that keeps them
        # non-negative. "_w1" nodes belong to the top-down path, "_w2" to the
        # bottom-up path.
        fusion_specs = (
            ('p6_w1', 2), ('p5_w1', 2), ('p4_w1', 2), ('p3_w1', 2),
            ('p4_w2', 3), ('p5_w2', 3), ('p6_w2', 3), ('p7_w2', 2),
        )
        for attr_name, num_inputs in fusion_specs:
            setattr(self, attr_name, nn.Parameter(torch.ones(num_inputs)))
            setattr(self, attr_name + '_relu', nn.ReLU())

    def _fuse(self, raw_weights, relu, *features):
        """Return the weighted sum of *features* using ReLU'd, sum-normalised weights."""
        weights = relu(raw_weights)
        weights = weights / (torch.sum(weights, dim=0) + self.epsilon)
        fused = weights[0] * features[0]
        # Accumulate left to right, matching ``w0*f0 + w1*f1 + w2*f2``.
        for index in range(1, len(features)):
            fused = fused + weights[index] * features[index]
        return fused

    def forward(self, inputs):
        """Refine the (P3, P4, P5, P6, P7) pyramid; returns a 5-tuple in the same order."""
        p3_in, p4_in, p5_in, p6_in, p7_in = inputs

        # Top-down path: each node fuses its own input with the upsampled node above it.
        p6_up = self.conv6_up(self._fuse(self.p6_w1, self.p6_w1_relu, p6_in, self.p6_upsample(p7_in)))
        p5_up = self.conv5_up(self._fuse(self.p5_w1, self.p5_w1_relu, p5_in, self.p5_upsample(p6_up)))
        p4_up = self.conv4_up(self._fuse(self.p4_w1, self.p4_w1_relu, p4_in, self.p4_upsample(p5_up)))
        p3_out = self.conv3_up(self._fuse(self.p3_w1, self.p3_w1_relu, p3_in, self.p3_upsample(p4_up)))

        # Bottom-up path: each node fuses its original input, its top-down
        # intermediate (where one exists) and the downsampled output below it.
        p4_out = self.conv4_down(
            self._fuse(self.p4_w2, self.p4_w2_relu, p4_in, p4_up, self.p4_downsample(p3_out)))
        p5_out = self.conv5_down(
            self._fuse(self.p5_w2, self.p5_w2_relu, p5_in, p5_up, self.p5_downsample(p4_out)))
        p6_out = self.conv6_down(
            self._fuse(self.p6_w2, self.p6_w2_relu, p6_in, p6_up, self.p6_downsample(p5_out)))
        p7_out = self.conv7_down(
            self._fuse(self.p7_w2, self.p7_w2_relu, p7_in, self.p7_downsample(p6_out)))

        return p3_out, p4_out, p5_out, p6_out, p7_out
class Regressor(nn.Module):
    """Box-regression head applied to a single pyramid level.

    Runs ``num_layers`` stages of (3x3 conv -> in-place ReLU), then a 3x3 conv
    producing 4 values per anchor. The output is flattened so that each row
    holds one anchor's 4 regression values.
    """

    def __init__(self, in_channels, num_anchors, num_layers):
        super(Regressor, self).__init__()
        stages = []
        for _ in range(num_layers):
            stages += [
                nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
                nn.ReLU(True),
            ]
        self.layers = nn.Sequential(*stages)
        self.header = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)

    def forward(self, inputs):
        """Map a (B, C, H, W) feature map to a (B, H * W * num_anchors, 4) tensor."""
        features = self.header(self.layers(inputs))
        batch_size = features.shape[0]
        # NCHW -> NHWC so that the per-location anchor channels become contiguous rows.
        return features.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 4)
class Classifier(nn.Module):
    """Classification head applied to a single pyramid level.

    Runs ``num_layers`` stages of (3x3 conv -> in-place ReLU), then a 3x3 conv
    producing ``num_classes`` logits per anchor, followed by a sigmoid. The
    output is flattened so that each row holds one anchor's class probabilities.
    """

    def __init__(self, in_channels, num_anchors, num_classes, num_layers):
        super(Classifier, self).__init__()
        self.num_anchors = num_anchors
        self.num_classes = num_classes
        stages = []
        for _ in range(num_layers):
            stages.extend((
                nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
                nn.ReLU(True),
            ))
        self.layers = nn.Sequential(*stages)
        self.header = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
        self.act = nn.Sigmoid()

    def forward(self, inputs):
        """Map a (B, C, H, W) feature map to a (B, H * W * num_anchors, num_classes) tensor."""
        probabilities = self.act(self.header(self.layers(inputs)))
        batch_size = probabilities.shape[0]
        # Channels are laid out anchor-major (anchor * num_classes + class), so
        # after moving channels last a single view groups them per anchor.
        channels_last = probabilities.permute(0, 2, 3, 1).contiguous()
        return channels_last.view(batch_size, -1, self.num_classes)
class EfficientNet(nn.Module):
    """Pretrained EfficientNet-B0 trimmed down to a multi-scale feature extractor.

    The classification tail (head conv, final BN, pooling, dropout, FC) is
    deleted. ``forward`` returns the outputs of the stride-2 MBConv blocks,
    skipping the first one.
    """

    def __init__(self, ):
        super(EfficientNet, self).__init__()
        model = EffNet.from_pretrained('efficientnet-b0')
        # The detector never uses the classification tail; drop it.
        for unused_attr in ('_conv_head', '_bn1', '_avg_pooling', '_dropout', '_fc'):
            delattr(model, unused_attr)
        self.model = model

    def forward(self, x):
        """Run the stem and all MBConv blocks; return the selected feature maps as a list."""
        net = self.model
        x = net._swish(net._bn0(net._conv_stem(x)))

        blocks = net._blocks
        num_blocks = len(blocks)
        base_rate = net._global_params.drop_connect_rate
        collected = []
        for position, block in enumerate(blocks):
            # Drop-connect rate grows linearly with block depth when enabled;
            # a falsy base rate is passed through unchanged.
            rate = base_rate * (float(position) / num_blocks) if base_rate else base_rate
            x = block(x, drop_connect_rate=rate)
            # NOTE(review): this compares against a list; it relies on the
            # library storing the depthwise stride as [2, 2] — confirm on upgrade.
            if block._depthwise_conv.stride == [2, 2]:
                collected.append(x)
        return collected[1:]
class EfficientDet(nn.Module):
    """EfficientDet detector: EfficientNet backbone, stacked BiFPN layers, shared heads.

    Calling convention:
      * ``model([images, annotations])`` (a list or tuple of two) -> result of
        ``FocalLoss`` for the batch (training).
      * ``model(images)`` (the image tensor itself) -> ``[scores, class_ids, boxes]``
        for the FIRST image of the batch after score thresholding and NMS (inference).
    """

    def __init__(self, num_anchors=9, num_classes=20, compound_coef=0):
        super(EfficientDet, self).__init__()
        self.compound_coef = compound_coef
        # BiFPN width indexed by compound coefficient (8 entries; coef > 7 raises IndexError).
        self.num_channels = [64, 88, 112, 160, 224, 288, 384, 384][self.compound_coef]
        # 1x1 projections of the three backbone maps to the BiFPN width.
        # NOTE(review): input widths 40/80/192 are assumed to match the backbone's
        # returned feature maps — confirm if the backbone variant changes.
        self.conv3 = nn.Conv2d(40, self.num_channels, kernel_size=1, stride=1, padding=0)
        self.conv4 = nn.Conv2d(80, self.num_channels, kernel_size=1, stride=1, padding=0)
        self.conv5 = nn.Conv2d(192, self.num_channels, kernel_size=1, stride=1, padding=0)
        # P6 and P7 are produced by stride-2 3x3 convs on top of C5 and P6 respectively.
        self.conv6 = nn.Conv2d(192, self.num_channels, kernel_size=3, stride=2, padding=1)
        self.conv7 = nn.Sequential(nn.ReLU(),
                                   nn.Conv2d(self.num_channels, self.num_channels, kernel_size=3, stride=2, padding=1))
        self.bifpn = nn.Sequential(*[BiFPN(self.num_channels) for _ in range(min(2 + self.compound_coef, 8))])
        self.num_classes = num_classes
        self.regressor = Regressor(in_channels=self.num_channels, num_anchors=num_anchors,
                                   num_layers=3 + self.compound_coef // 3)
        self.classifier = Classifier(in_channels=self.num_channels, num_anchors=num_anchors, num_classes=num_classes,
                                     num_layers=3 + self.compound_coef // 3)
        self.anchors = Anchors()
        self.regressBoxes = BBoxTransform()
        self.clipBoxes = ClipBoxes()
        self.focalLoss = FocalLoss()

        # He (fan-out) normal init for convs, identity init for BatchNorm. This
        # runs BEFORE the backbone is attached below, so the pretrained backbone
        # weights are not overwritten.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        # With zero weights and bias -log((1 - p) / p), the classifier's sigmoid
        # output starts at exactly p = 0.01 for every class; the regressor starts at 0.
        prior = 0.01
        self.classifier.header.weight.data.fill_(0)
        self.classifier.header.bias.data.fill_(-math.log((1.0 - prior) / prior))
        self.regressor.header.weight.data.fill_(0)
        self.regressor.header.bias.data.fill_(0)
        self.backbone_net = EfficientNet()

    def freeze_bn(self):
        """Put every BatchNorm2d layer into eval mode (running stats stop updating)."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def forward(self, inputs):
        """Compute the training loss or run inference; see the class docstring."""
        # Only an explicit [images, annotations] pair means training. Testing
        # len() alone would treat an inference tensor holding exactly 2 images
        # as a training pair and try to unpack it.
        is_training = isinstance(inputs, (list, tuple)) and len(inputs) == 2
        if is_training:
            img_batch, annotations = inputs
        else:
            img_batch = inputs

        c3, c4, c5 = self.backbone_net(img_batch)
        p3 = self.conv3(c3)
        p4 = self.conv4(c4)
        p5 = self.conv5(c5)
        p6 = self.conv6(c5)
        p7 = self.conv7(p6)

        features = self.bifpn([p3, p4, p5, p6, p7])

        # Per-level head outputs are concatenated along the anchor dimension.
        regression = torch.cat([self.regressor(feature) for feature in features], dim=1)
        classification = torch.cat([self.classifier(feature) for feature in features], dim=1)
        anchors = self.anchors(img_batch)

        if is_training:
            return self.focalLoss(classification, regression, anchors, annotations)
        return self._postprocess(classification, regression, anchors, img_batch)

    def _postprocess(self, classification, regression, anchors, img_batch):
        """Decode boxes, drop low scores (<= 0.05) and apply class-agnostic NMS (IoU 0.5).

        Only batch element 0 is post-processed. Returns
        ``[scores, class_ids, boxes]``; all three are empty (on the same device
        as ``classification``) when no anchor clears the score threshold.
        """
        # Presumably turns anchors + regression offsets into image-space boxes
        # and clips them to the image — see BBoxTransform / ClipBoxes.
        transformed_anchors = self.regressBoxes(anchors, regression)
        transformed_anchors = self.clipBoxes(transformed_anchors, img_batch)

        # Best class probability per anchor, shape (B, num_anchors_total, 1).
        scores = torch.max(classification, dim=2, keepdim=True)[0]
        scores_over_thresh = (scores > 0.05)[0, :, 0]
        if not scores_over_thresh.any():
            device = classification.device
            return [torch.zeros(0, device=device),
                    torch.zeros(0, dtype=torch.long, device=device),
                    torch.zeros(0, 4, device=device)]

        classification = classification[:, scores_over_thresh, :]
        transformed_anchors = transformed_anchors[:, scores_over_thresh, :]
        scores = scores[:, scores_over_thresh, :]

        # NMS on (boxes | best score) rows, ignoring class identity.
        anchors_nms_idx = nms(torch.cat([transformed_anchors, scores], dim=2)[0, :, :], 0.5)
        nms_scores, nms_class = classification[0, anchors_nms_idx, :].max(dim=1)
        return [nms_scores, nms_class, transformed_anchors[0, anchors_nms_idx, :]]
if __name__ == '__main__':
    def count_parameters(model):
        """Return the total number of scalar values in *model*'s trainable parameters."""
        total = 0
        for param in model.parameters():
            if param.requires_grad:
                total += param.numel()
        return total

    detector = EfficientDet(num_classes=80)
    print(count_parameters(detector))