import torch
import torch.nn as nn
import torch.nn.functional as F


class Residual(nn.Module):
    """Skip connection around a sub-module; pads the residual if the
    sub-module shrinks the spatial dimensions."""

    def __init__(self, sequential):
        super().__init__()
        self.sequential = sequential

    def forward(self, x):
        res = self.sequential(x)
        # Compare spatial sizes explicitly instead of adding inside a bare
        # `except:`, which would also swallow unrelated errors.
        if res.shape[2:] != x.shape[2:]:
            diffY = x.size(2) - res.size(2)
            diffX = x.size(3) - res.size(3)
            res = F.pad(
                res, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]
            )
        return x + res
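
# Usage sketch (channel counts are illustrative, not from the original file):
# an unpadded 3x3 conv returns a residual 2 pixels smaller in H and W, so the
# padding branch above restores it before the addition.
#   block = Residual(nn.Sequential(nn.Conv2d(16, 16, 3)))
#   block(torch.randn(1, 16, 32, 32)).shape  # -> (1, 16, 32, 32)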


class ChannelShuffle(nn.Module):
    def __init__(self, groups):
        super().__init__()
        self.groups = groups

    def forward(self, x):
        batchsize, num_channels, height, width = x.size()
        channels_per_group = num_channels // self.groups
        # reshape
        x = x.view(batchsize, self.groups, channels_per_group, height, width)
        x = torch.transpose(x, 1, 2).contiguous()
        # flatten
        x = x.view(batchsize, -1, height, width)
        return x
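
# Illustrative shuffle (assumed example, not from the original file): with 4
# channels [a0, a1, b0, b1] in groups a and b, ChannelShuffle(groups=2)
# reorders them to [a0, b0, a1, b1], mixing information across groups that a
# following grouped convolution would otherwise keep separate.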


class GhostModule(nn.Module):
    def __init__(
        self,
        inp,
        oup,
        kernel=(1, 1),
        dw_kernel=3,
        act=True,
    ):
        super().__init__()
        self.oup = oup
        # Generate roughly half the output channels with a regular conv ...
        init_channels = (oup + 1) // 2
        self.primary_conv = nn.Sequential(
            nn.Conv2d(
                inp,
                init_channels,
                kernel,
                padding=(kernel[0] // 2, kernel[1] // 2),
                bias=False,
            ),
            nn.InstanceNorm2d(init_channels),
            nn.ReLU(inplace=True) if act else nn.Identity(),
        )
        # ... and derive the rest with a cheap depthwise conv on those maps.
        self.cheap_operation = nn.Sequential(
            nn.Conv2d(
                init_channels,
                init_channels,
                dw_kernel,
                1,
                dw_kernel // 2,
                groups=init_channels,
                bias=False,
            ),
            nn.InstanceNorm2d(init_channels),
            nn.ReLU(inplace=True) if act else nn.Identity(),
        )

    def forward(self, x):
        x1 = self.primary_conv(x)
        x2 = self.cheap_operation(x1)
        out = torch.cat([x1, x2], dim=1)
        # Slice in case oup is odd and 2 * init_channels overshoots by one.
        return out[:, : self.oup, :, :]
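
# e.g. (hypothetical sizes) GhostModule(16, 24) builds 12 "intrinsic" maps
# with the 1x1 primary conv, derives 12 "ghost" maps with the depthwise 3x3,
# and concatenates them into the 24 requested output channels.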


class ConvNormAct(nn.Module):
    def __init__(self, ins, outs, kernel, stride=1, groups=1, act=True, bias=False):
        super().__init__()
        if isinstance(kernel, int):
            kernel = (kernel, kernel)
        self.conv = nn.Sequential(
            nn.Conv2d(
                ins,
                outs,
                kernel,
                stride=stride,
                padding=(kernel[0] // 2, kernel[1] // 2),
                groups=groups,
                bias=bias,
            ),
            nn.InstanceNorm2d(outs),
            nn.ReLU(inplace=True) if act else nn.Identity(),
        )

    def forward(self, x):
        return self.conv(x)


class DWSeparableConv(nn.Module):
    def __init__(self, ins, outs, kernel):
        super().__init__()
        self.conv = nn.Sequential(
            # depthwise conv over each input channel, then 1x1 pointwise mix
            ConvNormAct(ins, ins, kernel, groups=ins),
            ConvNormAct(ins, outs, 1),
        )

    def forward(self, x):
        return self.conv(x)
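

# Minimal smoke-test sketch (assumed shapes and channel counts, not part of
# the original module): run a forward pass through each block and check the
# expected output shapes.
if __name__ == "__main__":
    x = torch.randn(2, 16, 32, 32)
    assert Residual(nn.Sequential(ConvNormAct(16, 16, 3)))(x).shape == x.shape
    assert ChannelShuffle(groups=4)(x).shape == x.shape
    assert GhostModule(16, 24)(x).shape == (2, 24, 32, 32)
    assert DWSeparableConv(16, 32, 3)(x).shape == (2, 32, 32, 32)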