Skip to content

Commit

Permalink
Working on pretrained model loading
Browse files Browse the repository at this point in the history
  • Loading branch information
osmr committed Aug 9, 2018
1 parent abe1e5e commit 75a9111
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 11 deletions.
Empty file added __init__.py
Empty file.
80 changes: 80 additions & 0 deletions gluon/models/model_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
Model store which provides pretrained models.
"""

__all__ = ['get_model_file']

import os
import zipfile
import logging
from mxnet.gluon.utils import download, check_sha1

# Registry of released pretrained weights, keyed by model name.
# Each value is a (error, checksum, repo_release_tag) tuple:
#   error             -- a 4-digit suffix embedded in the weight file name
#                        (presumably a validation error-rate tag; confirm
#                        against the release notes),
#   checksum          -- SHA1 of the extracted ``.params`` file, used to
#                        validate cached and freshly downloaded weights,
#   repo_release_tag  -- GitHub release tag that hosts the zip archive.
_model_sha1 = {name: (error, checksum, repo_release_tag) for name, error, checksum, repo_release_tag in [
    ('resnet18', '1005', 'c3152c77b769a05d06b2ca5a2aeeb79cb4e08ee1', 'v0.0.1'),
    ('resnet34', '0793', '5b875f4934da8d83d44afc30d8e91362d3103115', 'v0.0.1')]}

# Base URL of the GitHub repository whose releases host the weight archives.
imgclsmob_repo_url = 'https://github.com/osmr/tmp1'


def get_model_name_suffix_data(model_name):
    """
    Look up the registry record for a pretrained model.

    Parameters
    ----------
    model_name : str
        Name of the model to look up.

    Returns
    -------
    tuple of (str, str, str)
        The (error suffix, SHA1 checksum, repository release tag) triple
        registered for the model.

    Raises
    ------
    ValueError
        If no pretrained weights are registered under `model_name`.
    """
    record = _model_sha1.get(model_name)
    if record is None:
        raise ValueError('Pretrained model for {name} is not available.'.format(name=model_name))
    return record


def get_model_file(model_name,
                   local_model_store_dir_path=os.path.join('~', '.mxnet', 'models')):
    """
    Return the local path of a pretrained model parameter file, downloading it from the online
    model zoo when the file is missing or its checksum does not match. The root directory is
    created if it does not exist.

    Parameters
    ----------
    model_name : str
        Name of the model.
    local_model_store_dir_path : str, default $MXNET_HOME/models
        Location for keeping the model parameters.

    Returns
    -------
    file_path : str
        Path to the requested pretrained model file.

    Raises
    ------
    ValueError
        If the model has no registered pretrained weights, or the downloaded
        file fails SHA1 verification.
    """
    error, sha1_hash, repo_release_tag = get_model_name_suffix_data(model_name)
    short_sha1 = sha1_hash[:8]
    # File name convention: <model>-<error>-<first 8 hex of sha1>.params
    file_name = '{name}-{error}-{short_sha1}.params'.format(
        name=model_name,
        error=error,
        short_sha1=short_sha1)
    local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path)
    file_path = os.path.join(local_model_store_dir_path, file_name)
    if os.path.exists(file_path):
        if check_sha1(file_path, sha1_hash):
            # Cached copy is intact; reuse it without touching the network.
            return file_path
        logging.warning('Mismatch in the content of model file detected. Downloading again.')
    else:
        logging.info('Model file not found. Downloading to {}.'.format(file_path))

    # Create the destination directory, tolerating concurrent creation by
    # another process (a bare exists()+makedirs() pair is racy and would
    # raise if the directory appeared between the check and the call).
    try:
        os.makedirs(local_model_store_dir_path)
    except OSError:
        if not os.path.isdir(local_model_store_dir_path):
            raise

    zip_file_path = file_path + '.zip'
    # MXNET_GLUON_REPO lets users redirect downloads to a mirror.
    repo_url = os.environ.get('MXNET_GLUON_REPO', imgclsmob_repo_url)
    download(
        url='{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip'.format(
            repo_url=repo_url,
            repo_release_tag=repo_release_tag,
            file_name=file_name),
        path=zip_file_path,
        overwrite=True)
    # Remove the temporary archive even if extraction fails, so a corrupt
    # download does not leave a stale .zip behind.
    try:
        with zipfile.ZipFile(zip_file_path) as zf:
            zf.extractall(local_model_store_dir_path)
    finally:
        os.remove(zip_file_path)

    if check_sha1(file_path, sha1_hash):
        return file_path
    else:
        raise ValueError('Downloaded file has different hash. Please try again.')

2 changes: 1 addition & 1 deletion gluon/models/others/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def get_densenet(num_layers, pretrained=False, ctx=cpu(),
net = DenseNet(num_init_features, growth_rate, block_config, **kwargs)
if pretrained:
from ..model_store import get_model_file
net.load_parameters(get_model_file('densenet%d'%(num_layers), root=root), ctx=ctx)
net.load_parameters(get_model_file('densenet%d' % (num_layers), local_model_store_dir_path=root), ctx=ctx)
return net

def densenet121(**kwargs):
Expand Down
4 changes: 2 additions & 2 deletions gluon/models/others/mobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def get_mobilenet(multiplier, pretrained=False, ctx=cpu(),
if version_suffix in ('1.00', '0.50'):
version_suffix = version_suffix[:-1]
net.load_parameters(
get_model_file('mobilenet%s' % version_suffix, root=root), ctx=ctx)
get_model_file('mobilenet%s' % version_suffix, local_model_store_dir_path=root), ctx=ctx)
return net


Expand Down Expand Up @@ -245,7 +245,7 @@ def get_mobilenet_v2(multiplier, pretrained=False, ctx=cpu(),
if version_suffix in ('1.00', '0.50'):
version_suffix = version_suffix[:-1]
net.load_parameters(
get_model_file('mobilenetv2_%s' % version_suffix, root=root), ctx=ctx)
get_model_file('mobilenetv2_%s' % version_suffix, local_model_store_dir_path=root), ctx=ctx)
return net


Expand Down
4 changes: 2 additions & 2 deletions gluon/models/others/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,8 @@ def get_resnet(version, num_layers, pretrained=False, ctx=cpu(),
net = resnet_class(block_class, layers, channels, **kwargs)
if pretrained:
from ..model_store import get_model_file
net.load_parameters(get_model_file('resnet%d_v%d'%(num_layers, version),
root=root), ctx=ctx)
net.load_parameters(get_model_file('resnet%d_v%d' % (num_layers, version),
local_model_store_dir_path=root), ctx=ctx)
return net

def resnet18_v1(**kwargs):
Expand Down
2 changes: 1 addition & 1 deletion gluon/models/others/squeezenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def get_squeezenet(version, pretrained=False, ctx=cpu(),
net = SqueezeNet(version, **kwargs)
if pretrained:
from ..model_store import get_model_file
net.load_parameters(get_model_file('squeezenet%s'%version, root=root), ctx=ctx)
net.load_parameters(get_model_file('squeezenet%s' % version, local_model_store_dir_path=root), ctx=ctx)
return net

def squeezenet1_0(**kwargs):
Expand Down
26 changes: 21 additions & 5 deletions gluon/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
'resnet18_wd4', 'resnet34', 'resnet50', 'resnet50b', 'resnet101', 'resnet101b', 'resnet152', 'resnet152b',
'resnet200', 'resnet200b']

import os
from mxnet import cpu
from mxnet.gluon import nn, HybridBlock

Expand Down Expand Up @@ -402,6 +403,7 @@ def get_resnet(blocks,
width_scale=1.0,
pretrained=False,
ctx=cpu(),
root=os.path.join('~', '.mxnet', 'models'),
**kwargs):
"""
Create ResNet model with specific parameters.
Expand All @@ -418,6 +420,8 @@ def get_resnet(blocks,
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
root : str, default '~/.mxnet/models'
Location for keeping the model parameters.
"""

if blocks == 10:
Expand Down Expand Up @@ -458,16 +462,28 @@ def get_resnet(blocks,
channels = [[int(cij * width_scale) for cij in ci] for ci in channels]
init_block_channels = int(init_block_channels * width_scale)

if pretrained:
raise ValueError("Pretrained model is not supported")

return ResNet(
net = ResNet(
channels=channels,
init_block_channels=init_block_channels,
bottleneck=bottleneck,
conv1_stride=conv1_stride,
**kwargs)

if pretrained:
if blocks in [18]:
from .model_store import get_model_file
net.load_parameters(
filename=get_model_file(
model_name='resnet{}{}'.format(blocks, '' if conv1_stride else 'b'),
local_model_store_dir_path=root),
ctx=ctx)

else:
raise ValueError("Pretrained model is not supported")
#pass

return net


def resnet10(**kwargs):
"""
Expand Down Expand Up @@ -745,7 +761,7 @@ def _test():

for model in models:

net = model()
net = model(pretrained=True)

ctx = mx.cpu()
net.initialize(ctx=ctx)
Expand Down

0 comments on commit 75a9111

Please sign in to comment.