Fix save_inference_model bug in paddlehub
rainyfly authored Sep 16, 2022
1 parent 196f7e6 commit 3fcdd77
Showing 1 changed file with 56 additions and 49 deletions.
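
In short, this commit moves the static-graph export path off the legacy keyword set (dirname, feeded_var_names, target_vars, plus the __model__/__params__ filename defaults) and onto the current paddle.static API, which loads and saves by path prefix and takes feed_vars/fetch_vars as Variable objects. A minimal sketch of the new-style round trip, with placeholder paths that are not part of this diff:

import os
import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

# Load an existing inference model; expects <prefix>.pdmodel / <prefix>.pdiparams on disk.
program, feed_names, fetch_vars = paddle.static.load_inference_model(
    '/path/to/pretrained/model', executor=exe)  # placeholder prefix

# The 2.x save API wants Variable objects for the inputs (not names) and a file
# prefix rather than a bare directory.
feed_vars = [program.global_block().var(name) for name in feed_names]
paddle.static.save_inference_model(
    os.path.join('/path/to/export', 'model'),  # writes model.pdmodel / model.pdiparams
    feed_vars=feed_vars,
    fetch_vars=fetch_vars,
    executor=exe,
    program=program)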
105 changes: 56 additions & 49 deletions paddlehub/module/module.py
@@ -37,7 +37,6 @@


class InvalidHubModule(Exception):

def __init__(self, directory: str):
self.directory = directory

@@ -200,11 +199,12 @@ def save_inference_model(self,
for key, _sub_module in self.sub_modules().items():
try:
sub_dirname = os.path.normpath(os.path.join(dirname, key))
_sub_module.save_inference_model(sub_dirname,
include_sub_modules=include_sub_modules,
model_filename=model_filename,
params_filename=params_filename,
combined=combined)
_sub_module.save_inference_model(
sub_dirname,
include_sub_modules=include_sub_modules,
model_filename=model_filename,
params_filename=params_filename,
combined=combined)
except:
utils.record_exception('Failed to save sub module {}'.format(_sub_module.name))

@@ -231,14 +231,11 @@ def save_inference_model(self,
if not self._pretrained_model_path:
raise RuntimeError('Module {} does not support exporting models in Paddle Inference format.'.format(
self.name))
elif not os.path.exists(self._pretrained_model_path):
elif not os.path.exists(
self._pretrained_model_path) and not os.path.exists(self._pretrained_model_path + '.pdmodel'):
log.logger.warning('The model path of Module {} does not exist.'.format(self.name))
return

model_filename = '__model__' if not model_filename else model_filename
if combined:
params_filename = '__params__' if not params_filename else params_filename

place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
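
The widened existence check above accepts either on-disk layout: a legacy model directory, or a 2.x path prefix where at least <prefix>.pdmodel exists (typically alongside <prefix>.pdiparams). As a standalone sketch, with a hypothetical helper name:

import os

def pretrained_model_present(path: str) -> bool:
    # Legacy layout: an existing directory (usually holding __model__ / __params__).
    # 2.x layout: a bare prefix with <prefix>.pdmodel next to it.
    return os.path.exists(path) or os.path.exists(path + '.pdmodel')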

@@ -253,21 +250,25 @@

if os.path.exists(os.path.join(self._pretrained_model_path, '__params__')):
_params_filename = '__params__'
if _model_filename is not None and _params_filename is not None:
program, feeded_var_names, target_vars = paddle.static.load_inference_model(
self._pretrained_model_path,
executor=exe,
model_filename=_model_filename,
params_filename=_params_filename,
)
else:
program, feeded_var_names, target_vars = paddle.static.load_inference_model(
self._pretrained_model_path, executor=exe)

program, feeded_var_names, target_vars = paddle.static.load_inference_model(
dirname=self._pretrained_model_path,
executor=exe,
model_filename=_model_filename,
params_filename=_params_filename,
)

paddle.static.save_inference_model(dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
global_block = program.global_block()
feed_vars = [global_block.var(item) for item in feeded_var_names]

path_prefix = dirname
if os.path.isdir(dirname):
path_prefix = os.path.join(dirname, 'model')
paddle.static.save_inference_model(
path_prefix, feed_vars=feed_vars, fetch_vars=target_vars, executor=exe, program=program)

log.logger.info('Paddle Inference model saved in {}.'.format(dirname))
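
Net effect of the hunk above: the feed names returned by load_inference_model are turned back into Variable objects, and when dirname is an existing directory the export is written under a 'model' prefix inside it, so the output is roughly dirname/model.pdmodel plus dirname/model.pdiparams instead of the legacy __model__/__params__ pair. A hedged usage sketch (the module name is only illustrative; any module backed by a pretrained static-graph model should behave the same way):

import paddlehub as hub

module = hub.Module(name='resnet50_vd_imagenet_ssld')  # illustrative module name
module.save_inference_model('./inference_output')
# Expected layout after this commit (roughly):
#   ./inference_output/model.pdmodel    - serialized program
#   ./inference_output/model.pdiparams  - combined parameters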

@@ -337,17 +338,19 @@ def export_onnx_model(self,

save_file = os.path.join(dirname, '{}.onnx'.format(self.name))

program, inputs, outputs = paddle.static.load_inference_model(dirname=self._pretrained_model_path,
model_filename=model_filename,
params_filename=params_filename,
executor=exe)
program, inputs, outputs = paddle.static.load_inference_model(
dirname=self._pretrained_model_path,
model_filename=model_filename,
params_filename=params_filename,
executor=exe)

paddle2onnx.program2onnx(program=program,
scope=paddle.static.global_scope(),
feed_var_names=inputs,
target_vars=outputs,
save_file=save_file,
**kwargs)
paddle2onnx.program2onnx(
program=program,
scope=paddle.static.global_scope(),
feed_var_names=inputs,
target_vars=outputs,
save_file=save_file,
**kwargs)
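
The ONNX branch above is only reformatted; it still loads the static program and forwards extra keyword arguments to paddle2onnx.program2onnx. A hedged usage sketch (output directory and opset value are illustrative; the positional argument follows the method's existing dirname parameter):

import paddlehub as hub

module = hub.Module(name='resnet50_vd_imagenet_ssld')  # illustrative module name
# Writes <module name>.onnx into the directory; opset_version is passed through to paddle2onnx.
module.export_onnx_model('./onnx_output', opset_version=12)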


class Module(object):
@@ -387,13 +390,14 @@ def __new__(cls,
from paddlehub.server.server import CacheUpdater
# This branch comes from hub.Module(name='xxx') or hub.Module(directory='xxx')
if name:
module = cls.init_with_name(name=name,
version=version,
source=source,
update=update,
branch=branch,
ignore_env_mismatch=ignore_env_mismatch,
**kwargs)
module = cls.init_with_name(
name=name,
version=version,
source=source,
update=update,
branch=branch,
ignore_env_mismatch=ignore_env_mismatch,
**kwargs)
CacheUpdater("update_cache", module=name, version=version).start()
elif directory:
module = cls.init_with_directory(directory=directory, **kwargs)
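
The two branches above correspond to the two supported constructors; a minimal illustration (module name and path are placeholders):

import paddlehub as hub

# Resolved through the local module manager (installed on demand).
m1 = hub.Module(name='lac')  # version=, source=, branch=, etc. are also accepted
# Loaded directly from a local directory containing a module.
m2 = hub.Module(directory='/path/to/local/module')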
@@ -485,12 +489,13 @@ def init_with_name(cls,
manager = LocalModuleManager()
user_module_cls = manager.search(name, source=source, branch=branch)
if not user_module_cls or not user_module_cls.version.match(version):
user_module_cls = manager.install(name=name,
version=version,
source=source,
update=update,
branch=branch,
ignore_env_mismatch=ignore_env_mismatch)
user_module_cls = manager.install(
name=name,
version=version,
source=source,
update=update,
branch=branch,
ignore_env_mismatch=ignore_env_mismatch)

directory = manager._get_normalized_path(user_module_cls.name)

@@ -555,7 +560,9 @@ def _wrapper(cls: Generic) -> Generic:
_bases.append(_b)
_bases.append(_meta)
_bases = tuple(_bases)
wrap_cls = builtins.type(cls.__name__, _bases, dict(cls.__dict__))
attr_dict = dict(cls.__dict__)
attr_dict.pop('__dict__', None)
wrap_cls = builtins.type(cls.__name__, _bases, attr_dict)

wrap_cls.name = name
wrap_cls.version = utils.Version(version)
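
The attr_dict.pop('__dict__', None) added above keeps the original class's __dict__ descriptor (which is bound to that class) from being copied into the class rebuilt with builtins.type; type() then creates a fresh descriptor for the new class. A small standalone illustration of the same pattern, with made-up class names:

import builtins

class Original:
    def greet(self):
        return 'hello'

attrs = dict(Original.__dict__)
attrs.pop('__dict__', None)  # drop the stale descriptor; type() will add a new one

Rebuilt = builtins.type('Rebuilt', (object,), attrs)
print(Rebuilt().greet())  # -> 'hello'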
