diff --git a/cnn/imagenet/training.py b/cnn/imagenet/training.py index c6ab1d0..e9b9ea2 100644 --- a/cnn/imagenet/training.py +++ b/cnn/imagenet/training.py @@ -17,6 +17,7 @@ from cnn.mobilenet_imagenet import MobileNet from cnn.mobilenet_imagenet import Butterfly1x1Conv +from cnn.shufflenet_imagenet import ShuffleNet try: from apex.parallel import DistributedDataParallel as DDP @@ -28,7 +29,8 @@ class ModelAndLoss(nn.Module): def __init__(self, arch, loss, pretrained_weights=None, cuda=True, fp16=False, - width=1.0, n_struct_layers=0, struct='D', softmax_struct='D', sm_pooling=1): + width=1.0, n_struct_layers=0, struct='D', softmax_struct='D', sm_pooling=1, + groups=8): super(ModelAndLoss, self).__init__() self.arch = arch @@ -39,6 +41,8 @@ def __init__(self, arch, loss, pretrained_weights=None, cuda=True, fp16=False, softmax_structure=softmax_struct, sm_pooling=sm_pooling) # if args.distilled_param_path: # model.load_state_dict(model.mixed_model_state_dict(args.full_model_path, args.distilled_param_path)) + elif arch == 'shufflenetv1': + model = ShuffleNet(width_mult=width, groups=groups) else: model = models.__dict__[arch]() if pretrained_weights is not None: diff --git a/cnn/imagenet_main.py b/cnn/imagenet_main.py index 20c4a27..ed644c3 100644 --- a/cnn/imagenet_main.py +++ b/cnn/imagenet_main.py @@ -37,7 +37,7 @@ def add_parser_arguments(parser): - custom_model_names = ['mobilenetv1'] + custom_model_names = ['mobilenetv1', 'shufflenetv1'] model_names = sorted(name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name])) + custom_model_names @@ -66,6 +66,8 @@ def add_parser_arguments(parser): metavar='NSL', help='Number of structured layer (default 7)') parser.add_argument('--width', default=1.0, type=float, metavar='WIDTH', help='Width multiplier of the CNN (default 1.0)') + parser.add_argument('--groups', default=8, type=int, + metavar='GROUPS', help='Group parameter of ShuffleNet (default 8)') 
parser.add_argument('--distilled-param-path', default='', type=str, metavar='PATH', help='path to distilled parameters (default: none)') parser.add_argument('--full-model-path', default='', type=str, metavar='PATH', @@ -249,7 +251,8 @@ def _worker_init_fn(id): pretrained_weights=pretrained_weights, cuda = True, fp16 = args.fp16, width=args.width, n_struct_layers=args.n_struct_layers, - struct=args.struct, softmax_struct=args.softmax_struct, sm_pooling=args.sm_pooling) + struct=args.struct, softmax_struct=args.softmax_struct, sm_pooling=args.sm_pooling, + groups=args.groups) if args.arch == 'mobilenetv1' and args.distilled_param_path: model_state = model_and_loss.model.mixed_model_state_dict(args.full_model_path, args.distilled_param_path) diff --git a/cnn/shufflenet_imagenet.py b/cnn/shufflenet_imagenet.py new file mode 100644 index 0000000..efb312f --- /dev/null +++ b/cnn/shufflenet_imagenet.py @@ -0,0 +1,110 @@ +'''ShuffleNet in PyTorch. + +See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details. 
import torch
import torch.nn as nn
import torch.nn.functional as F

import os
import sys

# Put the directory two levels above this file on sys.path so the
# project-local `butterfly` and `cnn.*` imports below can be resolved.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
from butterfly import Butterfly

from cnn.mobilenet_imagenet import _make_divisible
from cnn.mobilenet_imagenet import Butterfly1x1Conv


class ShuffleBlock(nn.Module):
    '''Parameter-free layer that interleaves channels across ``groups`` groups.'''

    def __init__(self, groups):
        super(ShuffleBlock, self).__init__()
        self.groups = groups

    def forward(self, x):
        '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
        N, C, H, W = x.size()
        g = self.groups
        # Split channels into (g, C/g), swap those two axes, then flatten back.
        # The view needs C to be a multiple of g.
        shuffled = x.view(N, g, C // g, H, W).permute(0, 2, 1, 3, 4).contiguous()
        return shuffled.view(N, C, H, W)


class Bottleneck(nn.Module):
    '''ShuffleNet unit: 1x1 group conv -> shuffle -> 3x3 depthwise conv -> 1x1 group conv.

    Args:
        in_planes: number of input channels.
        out_planes: number of channels the unit returns.
        stride: stride of the depthwise conv. With ``stride == 2`` the input
            is average-pooled and concatenated with the conv branch; with any
            other value the input is added to the conv branch unchanged.
        groups: group count of the last 1x1 conv (and of the first one when
            ``grouped_conv_1st_layer`` is true).
        grouped_conv_1st_layer: when false, the first 1x1 conv (and the
            shuffle that follows it) uses a single group.
    '''

    def __init__(self, in_planes, out_planes, stride, groups, grouped_conv_1st_layer=True):
        super(Bottleneck, self).__init__()
        self.stride = stride

        # Bottleneck width is derived from the *full* output width, before the
        # concat adjustment below. `_make_divisible` is defined elsewhere;
        # presumably it rounds to a multiple of `groups` -- confirm.
        mid_planes = _make_divisible(out_planes // 4, groups)
        if stride == 2:
            # The pooled input (in_planes channels) is concatenated in
            # forward(), so the conv branch only produces the remainder.
            out_planes -= in_planes
        g = groups if grouped_conv_1st_layer else 1

        self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_planes)
        self.shuffle1 = ShuffleBlock(groups=g)
        # groups == channels makes this a depthwise 3x3 convolution.
        self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride,
                               padding=1, groups=mid_planes, bias=False)
        # Tag on the depthwise weight; presumably read by the optimizer setup
        # to exempt it from weight decay -- confirm against the training code.
        self.conv2.weight._no_wd = True
        self.bn2 = nn.BatchNorm2d(mid_planes)
        self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Identity shortcut by default; 3x3 average pooling when downsampling.
        self.shortcut = nn.Sequential()
        if stride == 2:
            self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1))

    def forward(self, x):
        '''Run the conv branch and merge it with the shortcut, then apply ReLU.'''
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = self.shuffle1(out)
        out = F.relu(self.bn2(self.conv2(out)), inplace=True)
        out = self.bn3(self.conv3(out))
        res = self.shortcut(x)
        if self.stride == 2:
            # Downsampling unit: stack branch and pooled input along channels.
            merged = torch.cat([out, res], 1)
        else:
            # Residual unit: both tensors must have the same shape.
            merged = out + res
        return F.relu(merged, inplace=True)


class ShuffleNet(nn.Module):
    '''ShuffleNet classifier built from three stages of ``Bottleneck`` units.

    Args:
        num_classes: size of the final linear layer's output.
        groups: group count for the grouped 1x1 convolutions. Must be one of
            the keys of the stage-width table below (1, 2, 3, 4 or 8).
        width_mult: multiplier applied to the stem and stage channel counts
            before they are passed through ``_make_divisible``.

    Raises:
        KeyError: if ``groups`` has no entry in the stage-width table.
    '''

    def __init__(self, num_classes=1000, groups=8, width_mult=1.0):
        super(ShuffleNet, self).__init__()
        num_blocks = [4, 8, 4]
        # Output channels of stage 2 / 3 / 4, keyed by group count.
        groups_to_outplanes = {1: [144, 288, 576],
                               2: [200, 400, 800],
                               3: [240, 480, 960],
                               4: [272, 544, 1088],
                               8: [384, 768, 1536]}
        if groups not in groups_to_outplanes:
            # Same exception type as the plain dict lookup would raise, but
            # with a message that tells the user which values are accepted.
            raise KeyError('groups={!r} is not supported; choose one of {}'.format(
                groups, sorted(groups_to_outplanes)))
        out_planes = [_make_divisible(p * width_mult, groups)
                      for p in groups_to_outplanes[groups]]

        input_channel = _make_divisible(24 * width_mult, groups)
        self.conv1 = nn.Conv2d(3, input_channel, kernel_size=3, stride=2, padding=1, bias=False)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(input_channel)
        # Running input width; _make_layer reads and updates it per block.
        self.in_planes = input_channel
        # NOTE(review): grouped_conv_1st_layer=False is forwarded to every
        # block of stage 2, not only its first block -- confirm this is the
        # intended reading of "no grouped conv for the first layer of stage 2".
        self.stage2 = self._make_layer(out_planes[0], num_blocks[0], groups,
                                       grouped_conv_1st_layer=False)
        self.stage3 = self._make_layer(out_planes[1], num_blocks[1], groups)
        self.stage4 = self._make_layer(out_planes[2], num_blocks[2], groups)
        self.linear = nn.Linear(out_planes[2], num_classes)

    def _make_layer(self, out_planes, num_blocks, groups, grouped_conv_1st_layer=True):
        '''Build one stage: a stride-2 unit followed by ``num_blocks - 1`` stride-1 units.

        Updates ``self.in_planes`` to ``out_planes`` as blocks are appended.
        '''
        layers = []
        for i in range(num_blocks):
            stride = 2 if i == 0 else 1
            layers.append(Bottleneck(self.in_planes, out_planes, stride=stride, groups=groups,
                                     grouped_conv_1st_layer=grouped_conv_1st_layer))
            self.in_planes = out_planes
        return nn.Sequential(*layers)

    def forward(self, x):
        '''Stem (conv, BN, ReLU, max-pool), three stages, spatial mean, linear head.'''
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = self.maxpool(out)
        out = self.stage2(out)
        out = self.stage3(out)
        out = self.stage4(out)
        # Average over the two trailing (spatial) dimensions.
        out = out.mean([2, 3])
        out = self.linear(out)
        return out


def test():
    '''Smoke test: push one random 1x3x224x224 input through a default ShuffleNet and print the result.'''
    net = ShuffleNet()
    x = torch.randn(1, 3, 224, 224)
    y = net(x)
    print(y)

# test()