From 75a911117767d5390745d1e3ce1c2f2495c7dd13 Mon Sep 17 00:00:00 2001 From: osmr Date: Thu, 9 Aug 2018 09:01:06 +0300 Subject: [PATCH] Working on pretrained model loading --- __init__.py | 0 gluon/models/model_store.py | 80 +++++++++++++++++++++++++++++++ gluon/models/others/densenet.py | 2 +- gluon/models/others/mobilenet.py | 4 +- gluon/models/others/resnet.py | 4 +- gluon/models/others/squeezenet.py | 2 +- gluon/models/resnet.py | 26 ++++++++-- 7 files changed, 107 insertions(+), 11 deletions(-) create mode 100644 __init__.py create mode 100644 gluon/models/model_store.py diff --git a/__init__.py b/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gluon/models/model_store.py b/gluon/models/model_store.py new file mode 100644 index 000000000..918a8042a --- /dev/null +++ b/gluon/models/model_store.py @@ -0,0 +1,80 @@ +""" + Model store which provides pretrained models. +""" + +__all__ = ['get_model_file'] + +import os +import zipfile +import logging +from mxnet.gluon.utils import download, check_sha1 + +_model_sha1 = {name: (error, checksum, repo_release_tag) for name, error, checksum, repo_release_tag in [ + ('resnet18', '1005', 'c3152c77b769a05d06b2ca5a2aeeb79cb4e08ee1', 'v0.0.1'), + ('resnet34', '0793', '5b875f4934da8d83d44afc30d8e91362d3103115', 'v0.0.1')]} + +imgclsmob_repo_url = 'https://github.com/osmr/tmp1' + + +def get_model_name_suffix_data(model_name): + if model_name not in _model_sha1: + raise ValueError('Pretrained model for {name} is not available.'.format(name=model_name)) + error, sha1_hash, repo_release_tag = _model_sha1[model_name] + return error, sha1_hash, repo_release_tag + + +def get_model_file(model_name, + local_model_store_dir_path=os.path.join('~', '.mxnet', 'models')): + """ + Return location for the pretrained on local file system. This function will download from online model zoo when + model cannot be found or has mismatch. The root directory will be created if it doesn't exist. + + Parameters + ---------- + model_name : str + Name of the model. + local_model_store_dir_path : str, default $MXNET_HOME/models + Location for keeping the model parameters. + + Returns + ------- + file_path + Path to the requested pretrained model file. + """ + error, sha1_hash, repo_release_tag = get_model_name_suffix_data(model_name) + short_sha1 = sha1_hash[:8] + file_name = '{name}-{error}-{short_sha1}.params'.format( + name=model_name, + error=error, + short_sha1=short_sha1) + local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path) + file_path = os.path.join(local_model_store_dir_path, file_name) + if os.path.exists(file_path): + if check_sha1(file_path, sha1_hash): + return file_path + else: + logging.warning('Mismatch in the content of model file detected. Downloading again.') + else: + logging.info('Model file not found. Downloading to {}.'.format(file_path)) + + if not os.path.exists(local_model_store_dir_path): + os.makedirs(local_model_store_dir_path) + + zip_file_path = file_path + '.zip' + repo_url = os.environ.get('MXNET_GLUON_REPO', imgclsmob_repo_url) + download( + url='{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip'.format( + repo_url=repo_url, + repo_release_tag=repo_release_tag, + file_name=file_name), + path=zip_file_path, + overwrite=True) + with zipfile.ZipFile(zip_file_path) as zf: + zf.extractall(local_model_store_dir_path) + os.remove(zip_file_path) + + if check_sha1(file_path, sha1_hash): + return file_path + else: + raise ValueError('Downloaded file has different hash. Please try again.') + diff --git a/gluon/models/others/densenet.py b/gluon/models/others/densenet.py index b83635129..08353b6c8 100644 --- a/gluon/models/others/densenet.py +++ b/gluon/models/others/densenet.py @@ -140,7 +140,7 @@ def get_densenet(num_layers, pretrained=False, ctx=cpu(), net = DenseNet(num_init_features, growth_rate, block_config, **kwargs) if pretrained: from ..model_store import get_model_file - net.load_parameters(get_model_file('densenet%d'%(num_layers), root=root), ctx=ctx) + net.load_parameters(get_model_file('densenet%d' % (num_layers), local_model_store_dir_path=root), ctx=ctx) return net def densenet121(**kwargs): diff --git a/gluon/models/others/mobilenet.py b/gluon/models/others/mobilenet.py index d7a1328f1..f39c1707f 100644 --- a/gluon/models/others/mobilenet.py +++ b/gluon/models/others/mobilenet.py @@ -213,7 +213,7 @@ def get_mobilenet(multiplier, pretrained=False, ctx=cpu(), if version_suffix in ('1.00', '0.50'): version_suffix = version_suffix[:-1] net.load_parameters( - get_model_file('mobilenet%s' % version_suffix, root=root), ctx=ctx) + get_model_file('mobilenet%s' % version_suffix, local_model_store_dir_path=root), ctx=ctx) return net @@ -245,7 +245,7 @@ def get_mobilenet_v2(multiplier, pretrained=False, ctx=cpu(), if version_suffix in ('1.00', '0.50'): version_suffix = version_suffix[:-1] net.load_parameters( - get_model_file('mobilenetv2_%s' % version_suffix, root=root), ctx=ctx) + get_model_file('mobilenetv2_%s' % version_suffix, local_model_store_dir_path=root), ctx=ctx) return net diff --git a/gluon/models/others/resnet.py b/gluon/models/others/resnet.py index 9f935692b..ebc831b87 100644 --- a/gluon/models/others/resnet.py +++ b/gluon/models/others/resnet.py @@ -385,8 +385,8 @@ def get_resnet(version, num_layers, pretrained=False, ctx=cpu(), net = resnet_class(block_class, layers, channels, **kwargs) if pretrained: from ..model_store import get_model_file - net.load_parameters(get_model_file('resnet%d_v%d'%(num_layers, version), - root=root), ctx=ctx) + net.load_parameters(get_model_file('resnet%d_v%d' % (num_layers, version), + local_model_store_dir_path=root), ctx=ctx) return net def resnet18_v1(**kwargs): diff --git a/gluon/models/others/squeezenet.py b/gluon/models/others/squeezenet.py index 93d4520bb..d68b5774c 100644 --- a/gluon/models/others/squeezenet.py +++ b/gluon/models/others/squeezenet.py @@ -134,7 +134,7 @@ def get_squeezenet(version, pretrained=False, ctx=cpu(), net = SqueezeNet(version, **kwargs) if pretrained: from ..model_store import get_model_file - net.load_parameters(get_model_file('squeezenet%s'%version, root=root), ctx=ctx) + net.load_parameters(get_model_file('squeezenet%s' % version, local_model_store_dir_path=root), ctx=ctx) return net def squeezenet1_0(**kwargs): diff --git a/gluon/models/resnet.py b/gluon/models/resnet.py index 2cc79623c..839fa27e2 100644 --- a/gluon/models/resnet.py +++ b/gluon/models/resnet.py @@ -7,6 +7,7 @@ 'resnet18_wd4', 'resnet34', 'resnet50', 'resnet50b', 'resnet101', 'resnet101b', 'resnet152', 'resnet152b', 'resnet200', 'resnet200b'] +import os from mxnet import cpu from mxnet.gluon import nn, HybridBlock @@ -402,6 +403,7 @@ def get_resnet(blocks, width_scale=1.0, pretrained=False, ctx=cpu(), + root=os.path.join('~', '.mxnet', 'models'), **kwargs): """ Create ResNet model with specific parameters. @@ -418,6 +420,8 @@ def get_resnet(blocks, Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. + root : str, default '~/.mxnet/models' + Location for keeping the model parameters. """ if blocks == 10: @@ -458,16 +462,28 @@ def get_resnet(blocks, channels = [[int(cij * width_scale) for cij in ci] for ci in channels] init_block_channels = int(init_block_channels * width_scale) - if pretrained: - raise ValueError("Pretrained model is not supported") - - return ResNet( + net = ResNet( channels=channels, init_block_channels=init_block_channels, bottleneck=bottleneck, conv1_stride=conv1_stride, **kwargs) + if pretrained: + if blocks in [18]: + from .model_store import get_model_file + net.load_parameters( + filename=get_model_file( + model_name='resnet{}{}'.format(blocks, '' if conv1_stride else 'b'), + local_model_store_dir_path=root), + ctx=ctx) + + else: + raise ValueError("Pretrained model is not supported") + #pass + + return net + def resnet10(**kwargs): """ @@ -745,7 +761,7 @@ def _test(): for model in models: - net = model() + net = model(pretrained=True) ctx = mx.cpu() net.initialize(ctx=ctx)