-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathECAAttention3d.py
44 lines (36 loc) · 1.34 KB
/
ECAAttention3d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import numpy as np
import torch
from torch import nn
from torch.nn import init
from collections import OrderedDict
class ECAAttention(nn.Module):
    """Efficient Channel Attention (ECA) for 5-D (volumetric) feature maps.

    Each channel of the input is global-average-pooled to a scalar. A single
    1-D convolution then slides across the resulting channel vector, mixing
    information between neighbouring channels. The sigmoid of that result
    scales the input per channel.

    Args:
        kernel_size: Width of the 1-D convolution along the channel axis.
            Must be a positive odd integer so that the "same" padding used
            below keeps the number of gates equal to the number of channels.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        # Conv output length is C + 2 * ((k - 1) // 2) - k + 1, which equals C
        # only for odd k. An even k yields C - 1 gates, which either crashes in
        # forward() or, when C == 2, silently broadcasts one gate to both
        # channels.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.gap = nn.AdaptiveAvgPool3d(1)
        self.conv = nn.Conv1d(
            1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2
        )
        self.sigmoid = nn.Sigmoid()

    def init_weights(self):
        """Re-initialise this module's learnable layers in place.

        Convolutions (Conv1d / Conv3d) get Kaiming-normal weights
        (``mode='fan_out'``) and zero bias. BatchNorm3d gets unit weight and
        zero bias. Linear gets N(0, 0.001**2) weights and zero bias.

        This method is not called by ``__init__``; callers must invoke it
        explicitly.
        """
        for m in self.modules():
            # nn.Conv1d is the only parametrised layer this module creates.
            # It must be matched here, or the method changes nothing.
            if isinstance(m, (nn.Conv1d, nn.Conv3d)):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        """Scale each channel of ``x`` by a learned attention gate.

        Args:
            x: Tensor of shape (bs, c, d, h, w).

        Returns:
            Tensor with the same shape as ``x``: ``x`` multiplied per channel
            by a sigmoid gate in (0, 1).
        """
        bs, c = x.shape[:2]
        y = self.gap(x)                 # bs,c,1,1,1
        # Channels become the length axis of the single-channel 1-D conv.
        y = y.reshape(bs, 1, c)         # bs,1,c
        y = self.conv(y)                # bs,1,c
        y = self.sigmoid(y)             # bs,1,c
        y = y.reshape(bs, c, 1, 1, 1)   # bs,c,1,1,1
        return x * y.expand_as(x)
if __name__ == '__main__':
    # Smoke test: run one random volume through the module and print the
    # output shape. The tensor is named `x` rather than `input`, which would
    # shadow the builtin.
    x = torch.randn(1, 64, 32, 20, 24)  # bs, c, d, h, w
    eca = ECAAttention(kernel_size=3)
    output = eca(x)
    print(output.shape)