-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathattention_module.py
38 lines (28 loc) · 1.26 KB
/
attention_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from keras.layers import multiply, Permute, Concatenate, \
Conv2D, Lambda
from keras import backend as K
def spatial_attention(input_feature):
    """Spatial attention submodule of CBAM (Woo et al., 2018).

    Pools the input across the channel axis (mean and max), stacks the two
    maps, and passes them through a 7x7 sigmoid conv to produce a single
    spatial attention map, which rescales the input element-wise.

    Parameters
    ----------
    input_feature : 4-D Keras tensor
        Feature map in either `channels_first` (N, C, H, W) or
        `channels_last` (N, H, W, C) layout, per `K.image_data_format()`.

    Returns
    -------
    Keras tensor
        The input multiplied by the learned spatial attention map; same
        shape and data layout as `input_feature`.
    """
    kernel_size = 7  # 7x7 conv as recommended in the CBAM paper

    if K.image_data_format() == "channels_first":
        # Work internally in channels_last so the channel axis is always 3.
        cbam_feature = Permute((2, 3, 1))(input_feature)
    else:
        cbam_feature = input_feature

    # Channel-wise average and max pooling; each yields (N, H, W, 1).
    # K.int_shape is the public API (the private `_keras_shape` attribute
    # used previously does not exist on tf.keras tensors).
    avg_pool = Lambda(lambda x: K.mean(x, axis=3, keepdims=True))(cbam_feature)
    assert K.int_shape(avg_pool)[-1] == 1
    max_pool = Lambda(lambda x: K.max(x, axis=3, keepdims=True))(cbam_feature)
    assert K.int_shape(max_pool)[-1] == 1

    # Stack the two pooled maps into a 2-channel descriptor.
    concat = Concatenate(axis=3)([avg_pool, max_pool])
    assert K.int_shape(concat)[-1] == 2

    # Single-filter sigmoid conv produces the (N, H, W, 1) attention map.
    cbam_feature = Conv2D(filters=1,
                          kernel_size=kernel_size,
                          strides=1,
                          padding='same',
                          activation='sigmoid',
                          kernel_initializer='he_normal',
                          use_bias=False)(concat)
    assert K.int_shape(cbam_feature)[-1] == 1

    if K.image_data_format() == "channels_first":
        # Restore the original (N, C, H, W) layout before rescaling.
        cbam_feature = Permute((3, 1, 2))(cbam_feature)

    # Broadcast-multiply: every channel is scaled by the same spatial map.
    return multiply([input_feature, cbam_feature])