forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsymbol_dsb.py
32 lines (29 loc) · 1.53 KB
/
symbol_dsb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import find_mxnet
import mxnet as mx
def get_symbol(num_classes=121):
    """Return a convolutional classification network as an MXNet symbol graph.

    The graph reads from a variable named ``"data"`` and stacks three
    convolutional stages (a ReLU after every convolution, a pooling layer
    closing each stage), then flatten -> dropout(p=0.25) ->
    fully-connected classifier -> ``SoftmaxOutput`` named ``"softmax"``.

    Args:
        num_classes: Number of output units of the final fully-connected
            layer, i.e. the number of classes scored by the softmax.
            Defaults to 121.

    Returns:
        The ``SoftmaxOutput`` symbol at the head of the network.
    """
    net = mx.sym.Variable("data")
    # stage 1: two 5x5 "same"-padded convolutions, then 3x3/2 max pooling
    net = mx.sym.Convolution(data=net, kernel=(5, 5), num_filter=32, pad=(2, 2))
    net = mx.sym.Activation(data=net, act_type="relu")
    net = mx.sym.Convolution(data=net, kernel=(5, 5), num_filter=64, pad=(2, 2))
    net = mx.sym.Activation(data=net, act_type="relu")
    net = mx.sym.Pooling(data=net, pool_type="max", kernel=(3, 3), stride=(2, 2))
    # stage 2: three 3x3 "same"-padded convolutions, then 3x3/2 max pooling
    net = mx.sym.Convolution(data=net, kernel=(3, 3), num_filter=64, pad=(1, 1))
    net = mx.sym.Activation(data=net, act_type="relu")
    net = mx.sym.Convolution(data=net, kernel=(3, 3), num_filter=64, pad=(1, 1))
    net = mx.sym.Activation(data=net, act_type="relu")
    net = mx.sym.Convolution(data=net, kernel=(3, 3), num_filter=128, pad=(1, 1))
    net = mx.sym.Activation(data=net, act_type="relu")
    net = mx.sym.Pooling(data=net, pool_type="max", kernel=(3, 3), stride=(2, 2))
    # stage 3: two 3x3 "same"-padded convolutions, then 9x9 average pooling
    net = mx.sym.Convolution(data=net, kernel=(3, 3), num_filter=256, pad=(1, 1))
    net = mx.sym.Activation(data=net, act_type="relu")
    net = mx.sym.Convolution(data=net, kernel=(3, 3), num_filter=256, pad=(1, 1))
    net = mx.sym.Activation(data=net, act_type="relu")
    # NOTE(review): the fixed 9x9 window ties this graph to one input
    # resolution (presumably chosen so the stage-3 map is pooled down to
    # ~1x1) -- confirm against the data shape the training script feeds in.
    net = mx.sym.Pooling(data=net, pool_type="avg", kernel=(9, 9), stride=(1, 1))
    # stage 4: classifier head
    net = mx.sym.Flatten(data=net)
    net = mx.sym.Dropout(data=net, p=0.25)
    # Size the output layer from the argument; it was previously hard-coded
    # to 121, which silently ignored any non-default num_classes.
    net = mx.sym.FullyConnected(data=net, num_hidden=num_classes)
    net = mx.sym.SoftmaxOutput(data=net, name='softmax')
    return net