# Extracted from a git patch whose line breaks had been lost:
#   commit 9aace3c044ef1aba5fb3594d32454bab5a09d679
#   Tri Dao, Mon, 26 Aug 2019 19:18:28 -0700
#   "Support loading MobileNet from distilled parameters"
# The patch adds this method to the MobileNet class in
# cnn/mobilenet_imagenet.py (between `forward` and the module-level `test`).
# In cnn/imagenet_amp.py the same patch adds the `--distilled-param-path` and
# `--full-model-path` options (both default to '') and, for
# `--arch mobilenetv1_struct`, runs
#     model.load_state_dict(
#         model.mixed_model_state_dict(args.full_model_path,
#                                      args.distilled_param_path))
# whenever `args.distilled_param_path` is non-empty.
# NOTE(review): that call site does not check that `--full-model-path` was
# also given; with the default '' the first `torch.load` below will fail.
def mixed_model_state_dict(self, full_model_path, distilled_param_path):
    """Build a state dict mixing full-model weights with distilled ones.

    Parameters whose names also exist in ``self.state_dict()`` are taken
    from the full-model checkpoint; the ``conv2`` parameters of every
    structured layer are then taken from the distilled parameters.

    Args:
        full_model_path: Path to a checkpoint that ``torch.load`` turns into
            a mapping with a ``'state_dict'`` entry (name -> parameter).
        distilled_param_path: Path to a file that ``torch.load`` turns into
            a mapping keyed by ``(layer_name, nblocks)``; each value is a
            mapping from parameter name to parameter.

    Returns:
        A dict from parameter name to parameter, intended to be passed to
        ``load_state_dict``.

    Raises:
        KeyError: If the checkpoint has no ``'state_dict'`` entry, or the
            distilled parameters have no entry for a structured layer.
    """
    # Presumably added by a DataParallel-style wrapper when the checkpoint
    # was saved; verify against the code that wrote the checkpoint.
    prefix = 'module.'
    own_keys = self.state_dict().keys()
    checkpoint = torch.load(full_model_path, map_location='cpu')['state_dict']
    distilled_params = torch.load(distilled_param_path, map_location='cpu')

    state_dict = {}
    for name, param in checkpoint.items():
        # Strip the prefix only where it leads the name. A blanket
        # str.replace would also mangle names that merely contain
        # 'module.' (e.g. '...submodule.weight').
        if name.startswith(prefix):
            name = name[len(prefix):]
        # Keep only parameters that this model actually has.
        if name in own_keys:
            state_dict[name] = param

    for i, struct in enumerate(self.structure):
        # Only support butterfly for now
        if struct.startswith(('odo', 'regular')):
            layer = f'layers.{i}.conv2'
            # Structure names look like '<kind>_<nblocks>...'.
            nblocks = int(struct.split('_')[1])
            structured_param = distilled_params[layer, nblocks]
            state_dict.update({f'{layer}.{name}': param
                               for name, param in structured_param.items()})
    return state_dict