Skip to content

Commit

Permalink
Add ShuffleNet
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Sep 15, 2019
1 parent 7685f3c commit 2ec7686
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 3 deletions.
6 changes: 5 additions & 1 deletion cnn/imagenet/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from cnn.mobilenet_imagenet import MobileNet
from cnn.mobilenet_imagenet import Butterfly1x1Conv
from cnn.shufflenet_imagenet import ShuffleNet

try:
from apex.parallel import DistributedDataParallel as DDP
Expand All @@ -28,7 +29,8 @@

class ModelAndLoss(nn.Module):
def __init__(self, arch, loss, pretrained_weights=None, cuda=True, fp16=False,
width=1.0, n_struct_layers=0, struct='D', softmax_struct='D', sm_pooling=1):
width=1.0, n_struct_layers=0, struct='D', softmax_struct='D', sm_pooling=1,
groups=8):
super(ModelAndLoss, self).__init__()
self.arch = arch

Expand All @@ -39,6 +41,8 @@ def __init__(self, arch, loss, pretrained_weights=None, cuda=True, fp16=False,
softmax_structure=softmax_struct, sm_pooling=sm_pooling)
# if args.distilled_param_path:
# model.load_state_dict(model.mixed_model_state_dict(args.full_model_path, args.distilled_param_path))
elif arch == 'shufflenetv1':
model = ShuffleNet(width_mult=width, groups=groups)
else:
model = models.__dict__[arch]()
if pretrained_weights is not None:
Expand Down
7 changes: 5 additions & 2 deletions cnn/imagenet_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@


def add_parser_arguments(parser):
custom_model_names = ['mobilenetv1']
custom_model_names = ['mobilenetv1', 'shufflenetv1']
model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name])) + custom_model_names
Expand Down Expand Up @@ -66,6 +66,8 @@ def add_parser_arguments(parser):
metavar='NSL', help='Number of structured layer (default 7)')
parser.add_argument('--width', default=1.0, type=float,
metavar='WIDTH', help='Width multiplier of the CNN (default 1.0)')
parser.add_argument('--groups', default=8, type=int,
metavar='GROUPS', help='Group parameter of ShuffleNet (default 8)')
parser.add_argument('--distilled-param-path', default='', type=str, metavar='PATH',
help='path to distilled parameters (default: none)')
parser.add_argument('--full-model-path', default='', type=str, metavar='PATH',
Expand Down Expand Up @@ -249,7 +251,8 @@ def _worker_init_fn(id):
pretrained_weights=pretrained_weights,
cuda = True, fp16 = args.fp16,
width=args.width, n_struct_layers=args.n_struct_layers,
struct=args.struct, softmax_struct=args.softmax_struct, sm_pooling=args.sm_pooling)
struct=args.struct, softmax_struct=args.softmax_struct, sm_pooling=args.sm_pooling,
groups=args.groups)

if args.arch == 'mobilenetv1' and args.distilled_param_path:
model_state = model_and_loss.model.mixed_model_state_dict(args.full_model_path, args.distilled_param_path)
Expand Down
110 changes: 110 additions & 0 deletions cnn/shufflenet_imagenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
'''ShuffleNet in PyTorch.
See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

import os, sys
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
from butterfly import Butterfly

from cnn.mobilenet_imagenet import _make_divisible
from cnn.mobilenet_imagenet import Butterfly1x1Conv


class ShuffleBlock(nn.Module):
def __init__(self, groups):
super(ShuffleBlock, self).__init__()
self.groups = groups

def forward(self, x):
'''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]'''
N,C,H,W = x.size()
g = self.groups
return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).contiguous().view(N,C,H,W)


class Bottleneck(nn.Module):
def __init__(self, in_planes, out_planes, stride, groups, grouped_conv_1st_layer=True):
super(Bottleneck, self).__init__()
self.stride = stride

mid_planes = _make_divisible(out_planes // 4, groups)
if stride == 2: # Reduce out_planes due to concat
out_planes -= in_planes
g = groups if grouped_conv_1st_layer else 1 # No grouped conv for the first layer of stage 2
self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False)
self.bn1 = nn.BatchNorm2d(mid_planes)
self.shuffle1 = ShuffleBlock(groups=g)
self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False)
self.conv2.weight._no_wd = True
self.bn2 = nn.BatchNorm2d(mid_planes)
self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False)
self.bn3 = nn.BatchNorm2d(out_planes)

self.shortcut = nn.Sequential()
if stride == 2:
self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1))

def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)), inplace=True)
out = self.shuffle1(out)
out = F.relu(self.bn2(self.conv2(out)), inplace=True)
out = self.bn3(self.conv3(out))
res = self.shortcut(x)
out = F.relu(torch.cat([out,res], 1), inplace=True) if self.stride==2 else F.relu(out+res, inplace=True)
return out


class ShuffleNet(nn.Module):
def __init__(self, num_classes=1000, groups=8, width_mult=1.0):
super(ShuffleNet, self).__init__()
num_blocks = [4, 8, 4]
groups_to_outplanes = {1: [144, 288, 576],
2: [200, 400, 800],
3: [240, 480, 960],
4: [272, 544, 1088],
8: [384, 768, 1536]}
out_planes = groups_to_outplanes[groups]
out_planes = [_make_divisible(p * width_mult, groups) for p in out_planes]

input_channel = _make_divisible(24 * width_mult, groups)
self.conv1 = nn.Conv2d(3, input_channel, kernel_size=3, stride=2, padding=1, bias=False)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.bn1 = nn.BatchNorm2d(input_channel)
self.in_planes = input_channel
self.stage2 = self._make_layer(out_planes[0], num_blocks[0], groups, grouped_conv_1st_layer=False)
self.stage3 = self._make_layer(out_planes[1], num_blocks[1], groups)
self.stage4 = self._make_layer(out_planes[2], num_blocks[2], groups)
self.linear = nn.Linear(out_planes[2], num_classes)

def _make_layer(self, out_planes, num_blocks, groups, grouped_conv_1st_layer=True):
layers = []
for i in range(num_blocks):
stride = 2 if i == 0 else 1
layers.append(Bottleneck(self.in_planes, out_planes, stride=stride, groups=groups,
grouped_conv_1st_layer=grouped_conv_1st_layer))
self.in_planes = out_planes
return nn.Sequential(*layers)

def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)), inplace=True)
out = self.maxpool(out)
out = self.stage2(out)
out = self.stage3(out)
out = self.stage4(out)
out = out.mean([2, 3])
out = self.linear(out)
return out


def test():
net = ShuffleNet()
x = torch.randn(1, 3, 224, 224)
y = net(x)
print(y)

# test()

0 comments on commit 2ec7686

Please sign in to comment.