Commit 7bcbf24: [example] Move

antinucleon committed Oct 25, 2015
1 parent 1e2bd62 commit 7bcbf24
Showing 4 changed files with 106 additions and 2 deletions.
1 change: 0 additions & 1 deletion example/README.md
@@ -7,7 +7,6 @@ Notebooks
* [composite symbol](notebooks/composite_symbol.ipynb) gives you a demo of how to compose a symbolic Inception-BatchNorm network
* [cifar-10 recipe](notebooks/cifar-recipe.ipynb) gives you a step by step demo of how to use MXNet
* [cifar-100](notebooks/cifar-100.ipynb) gives you a demo of how to train a 75.68% accuracy CIFAR-100 model
* [predict with pretrained model](notebooks/predict-with-pretrained-model.ipynb) gives you a demo of using a pretrained Inception-BN Network
* [simple bind](notebooks/simple_bind.ipynb) gives you a demo of some details in the ```mx.model``` module.

Contents
6 changes: 5 additions & 1 deletion example/imagenet/README.md
@@ -12,7 +12,11 @@ Note: A common mistake is forgetting to shuffle the image list.

- [alexnet.py](alexnet.py) : alexnet with 5 convolution layers followed by 3 fully connected layers
- [inception.py](inception.py) : inception + batch norm network
- [inception.py](inception.py) : inception + batch norm network for the standard 1000-class ImageNet problem
- [inception-full.py](inception-full.py) : inception + batch norm network for the full ImageNet set with 21841 classes

## Notebooks
- [predict with pretrained model](predict-with-pretrained-model.ipynb) gives you a demo of using a pretrained Inception-BN Network
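
A minimal sketch of loading such a checkpoint yourself, assuming the `mx.model.FeedForward.load` API; the prefix and round number below are hypothetical placeholders for whatever checkpoint files you actually have:

```python
import mxnet as mx

# Load a checkpoint saved by mx.callback.do_checkpoint; "Inception-BN" and 39
# are placeholder prefix/round values.
model = mx.model.FeedForward.load("Inception-BN", 39, ctx=mx.gpu())
# `val_iter` is an assumed mx.io data iterator over the images to classify.
# prob = model.predict(val_iter)
```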

## Results

101 changes: 101 additions & 0 deletions example/imagenet/inception-full.py
@@ -0,0 +1,101 @@
# pylint: skip-file
import sys
sys.path.insert(0, "../mxnet/python")
import mxnet as mx
import logging
from data import ilsvrc12_iterator


logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Conv -> BatchNorm -> ReLU building block
def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(0, 0), name=None, suffix=''):
    conv = mx.symbol.Convolution(data=data, workspace=512, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, name='conv_%s%s' %(name, suffix))
    bn = mx.symbol.BatchNorm(data=conv, name='bn_%s%s' %(name, suffix))
    act = mx.symbol.Activation(data=bn, act_type='relu', name='relu_%s%s' %(name, suffix))
    return act
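
# Example (a sketch, not part of the network): a 3x3 Conv-BN-ReLU block with
# 64 filters on a hypothetical input symbol `x`:
#   x = mx.symbol.Variable('x')
#   blk = ConvFactory(data=x, num_filter=64, kernel=(3, 3), pad=(1, 1), name='demo')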

def InceptionFactoryA(data, num_1x1, num_3x3red, num_3x3, num_d3x3red, num_d3x3, pool, proj, name):
    # 1x1
    c1x1 = ConvFactory(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_1x1' % name))
    # 3x3 reduce + 3x3
    c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')
    c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_3x3' % name))
    # double 3x3 reduce + double 3x3
    cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1), name=('%s_double_3x3' % name), suffix='_reduce')
    cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_0' % name))
    cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_1' % name))
    # pool + proj
    pooling = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
    cproj = ConvFactory(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_proj' % name))
    # concat
    concat = mx.symbol.Concat(*[c1x1, c3x3, cd3x3, cproj], name='ch_concat_%s_chconcat' % name)
    return concat
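
# The concat output carries num_1x1 + num_3x3 + num_d3x3 + proj channels;
# e.g. the '3a' unit below yields 96 + 96 + 144 + 48 = 384 channels.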

def InceptionFactoryB(data, num_3x3red, num_3x3, num_d3x3red, num_d3x3, name):
    # 3x3 reduce + 3x3
    c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')
    c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name=('%s_3x3' % name))
    # double 3x3 reduce + double 3x3
    cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1), name=('%s_double_3x3' % name), suffix='_reduce')
    cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name=('%s_double_3x3_0' % name))
    cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name=('%s_double_3x3_1' % name))
    # pool
    pooling = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type="max", name=('max_pool_%s_pool' % name))
    # concat
    concat = mx.symbol.Concat(*[c3x3, cd3x3, pooling], name='ch_concat_%s_chconcat' % name)
    return concat
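
# Note: InceptionFactoryA keeps the spatial resolution (stride-1 branches),
# while InceptionFactoryB halves it (stride-2 convolutions plus a stride-2
# max pool), so the B units below mark the stage boundaries.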

def inception(nhidden, grad_scale):
    # data
    data = mx.symbol.Variable(name="data")
    # stage 1
    conv1 = ConvFactory(data=data, num_filter=96, kernel=(7, 7), stride=(2, 2), pad=(3, 3), name='conv1')
    pool1 = mx.symbol.Pooling(data=conv1, kernel=(3, 3), stride=(2, 2), name='pool1', pool_type='max')
    # stage 2
    conv2red = ConvFactory(data=pool1, num_filter=128, kernel=(1, 1), stride=(1, 1), name='conv2red')
    conv2 = ConvFactory(data=conv2red, num_filter=288, kernel=(3, 3), stride=(1, 1), pad=(1, 1), name='conv2')
    pool2 = mx.symbol.Pooling(data=conv2, kernel=(3, 3), stride=(2, 2), name='pool2', pool_type='max')
    # stage 3
    in3a = InceptionFactoryA(pool2, 96, 96, 96, 96, 144, "avg", 48, '3a')
    in3b = InceptionFactoryA(in3a, 96, 96, 144, 96, 144, "avg", 96, '3b')
    in3c = InceptionFactoryB(in3b, 192, 240, 96, 144, '3c')
    # stage 4
    in4a = InceptionFactoryA(in3c, 224, 64, 96, 96, 128, "avg", 128, '4a')
    in4b = InceptionFactoryA(in4a, 192, 96, 128, 96, 128, "avg", 128, '4b')
    in4c = InceptionFactoryA(in4b, 160, 128, 160, 128, 160, "avg", 128, '4c')
    in4d = InceptionFactoryA(in4c, 96, 128, 192, 160, 96, "avg", 128, '4d')
    in4e = InceptionFactoryB(in4d, 128, 192, 192, 256, '4e')
    # stage 5
    in5a = InceptionFactoryA(in4e, 352, 192, 320, 160, 224, "avg", 128, '5a')
    in5b = InceptionFactoryA(in5a, 352, 192, 320, 192, 224, "max", 128, '5b')
    # global avg pooling
    avg = mx.symbol.Pooling(data=in5b, kernel=(7, 7), stride=(1, 1), name="global_pool", pool_type='avg')
    # linear classifier
    flatten = mx.symbol.Flatten(data=avg, name='flatten')
    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=nhidden, name='fc1')
    softmax = mx.symbol.Softmax(data=fc1, name='softmax')
    return softmax

softmax = inception(21841, 1.0)

batch_size = 64
num_gpu = 4
gpus = [mx.gpu(i) for i in range(num_gpu)]
input_shape = (3, 224, 224)

train = ilsvrc12_iterator(batch_size=batch_size, input_shape=input_shape)

model_prefix = "model/Inception-Full"
num_round = 10

logging.info("This script is used to train ImageNet fullset over 21841 classes.")
logging.info("For noraml 1000 classes problem, please use inception.py")

model = mx.model.FeedForward(ctx=gpus, symbol=softmax, num_round=num_round,
                             learning_rate=0.05, momentum=0.9, wd=0.00001)

model.fit(X=train,
          eval_metric="acc",
          epoch_end_callback=[mx.callback.Speedometer(batch_size), mx.callback.log_train_metric(100)],
          iter_end_callback=mx.callback.do_checkpoint(model_prefix))
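
# After training, the checkpoints written by do_checkpoint can be reloaded for
# prediction. A minimal sketch (assumes training saved round `num_round`, and
# `val_iter` is a hypothetical validation data iterator):
#   model = mx.model.FeedForward.load(model_prefix, num_round, ctx=mx.gpu())
#   prob = model.predict(val_iter)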
