-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathyolo_nn_builders.py
38 lines (33 loc) · 1.36 KB
/
yolo_nn_builders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from torch import nn
from torchvision.models import resnet34, resnet18
from torchvision.models import ResNet18_Weights
class YOLOResNetFullyConnected(nn.Module):
    """YOLOv1-style detector: frozen ResNet-18 backbone + fully connected head.

    Structure: ResNet-18 with ImageNet weights (its last two child modules
    removed) followed by ``Linear -> LeakyReLU -> Linear -> Sigmoid``.

    The first linear layer is hard-coded to ``14 * 14 * 512`` input features,
    so the backbone output must flatten to exactly that many values per
    sample. NOTE(review): with torchvision's ResNet-18 this should correspond
    to 448 x 448 inputs -- confirm against the data pipeline.

    Args:
        S: Size of each of the two grid axes of the output tensor.
        B: Number of boxes per grid cell; each box contributes 5 values.
        num_classes: Number of class scores per grid cell.
    """

    def __init__(self, S, B, num_classes):
        super().__init__()
        self.S = S
        self.B = B
        self.num_classes = num_classes

        # The full ResNet stays registered as an attribute so the module tree
        # (and therefore the state_dict keys) is unchanged for existing
        # checkpoints.
        # NOTE(review): its last two children are not used in forward() and
        # are not frozen below, so their parameters still show up in
        # ``self.parameters()`` with requires_grad=True.
        self.resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

        # TODO: it would be nice to have the backbone in a separate module.
        # Backbone: every child of the ResNet except the last two.
        self.backbone = nn.Sequential(*list(self.resnet.children())[:-2])

        # Freeze the backbone weights for training.
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Fully connected head; per-cell output width is B * 5 + num_classes.
        self.fc_layers = nn.Sequential(
            nn.Linear(14 * 14 * 512, 4096),
            nn.LeakyReLU(0.1),
            nn.Linear(4096, self.S * self.S * (self.B * 5 + self.num_classes)),
            nn.Sigmoid(),  # squashes every output into the 0~1 range
        )

    def forward(self, x):
        """Run the backbone and head on a batch.

        Args:
            x: Input batch accepted by the backbone; its feature map must
                flatten to ``14 * 14 * 512`` values per sample.

        Returns:
            Tensor of shape ``(batch, S, S, B * 5 + num_classes)`` holding
            sigmoid outputs.
        """
        features = self.backbone(x)
        # flatten() rather than view(): view() raises when the tensor's
        # strides do not allow a zero-copy flatten (e.g. channels_last
        # layouts) and for an empty batch, whereas flatten() copies only
        # when it has to and yields the same values otherwise.
        flat = features.flatten(start_dim=1)
        preds = self.fc_layers(flat)
        return preds.reshape(
            -1, self.S, self.S, self.B * 5 + self.num_classes
        )