Commit 2cd690b
update prune_paddle_model.py (#1209)
Zheng-Bicheng authored Mar 26, 2024
1 parent cfd3242 commit 2cd690b
Showing 1 changed file: tools/paddle/prune_paddle_model.py (28 additions, 14 deletions)

@@ -6,15 +6,13 @@
 import os
 
 
-def new_prepend_feed_ops(inference_program,
-                         feed_target_names,
-                         feed_holder_name='feed'):
+def prepend_feed_ops(program, feed_target_names):
     if len(feed_target_names) == 0:
         return
 
-    global_block = inference_program.global_block()
+    global_block = program.global_block()
     feed_var = global_block.create_var(
-        name=feed_holder_name,
+        name='feed',
         type=core.VarDesc.VarType.FEED_MINIBATCH,
         persistable=True)
 
@@ -33,13 +31,13 @@ def new_prepend_feed_ops(inference_program,
             attrs={'col': i})
 
 
-def append_fetch_ops(program, fetch_target_names, fetch_holder_name='fetch'):
+def append_fetch_ops(program, fetch_target_names):
     """
     In this palce, we will add the fetch op
     """
     global_block = program.global_block()
     fetch_var = global_block.create_var(
-        name=fetch_holder_name,
+        name='fetch',
         type=core.VarDesc.VarType.FETCH_LIST,
         persistable=True)
     print("the len of fetch_target_names:%d" % (len(fetch_target_names)))
@@ -51,16 +49,20 @@ def append_fetch_ops(program, fetch_target_names, fetch_holder_name='fetch'):
             attrs={'col': i})
 
 
-def insert_fetch(program, fetchs, fetch_holder_name="fetch"):
+def insert_by_op_type(program, op_names, op_type):
     global_block = program.global_block()
     need_to_remove_op_index = list()
     for i, op in enumerate(global_block.ops):
-        if op.type == 'fetch':
+        if op.type == op_type:
             need_to_remove_op_index.append(i)
     for index in need_to_remove_op_index[::-1]:
         global_block._remove_op(index)
     program.desc.flush()
-    append_fetch_ops(program, fetchs, fetch_holder_name)
+
+    if op_type == "feed":
+        prepend_feed_ops(program, op_names)
+    else:
+        append_fetch_ops(program, op_names)
 
 
 def parse_arguments():
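
The new insert_by_op_type generalizes the old fetch-only insert_fetch: it deletes every existing op of the requested type and then rebuilds the feed or fetch ops for the given tensor names. The op indices are collected first and removed back-to-front, because deleting an earlier op would shift the indices of all later ones. A minimal standalone sketch of that removal pattern, with a plain Python list standing in for global_block.ops (the op names here are made up for illustration):

    # Deleting by index in reverse keeps the not-yet-removed indices valid.
    ops = ['feed', 'conv2d', 'relu', 'fetch', 'fetch']
    need_to_remove_op_index = [i for i, op in enumerate(ops) if op == 'fetch']  # [3, 4]
    for index in need_to_remove_op_index[::-1]:
        del ops[index]
    print(ops)  # ['feed', 'conv2d', 'relu']
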
@@ -73,6 +75,10 @@ def parse_arguments():
         '--model_filename', required=True, help='The input model file name.')
     parser.add_argument(
         '--params_filename', required=True, help='The parameters file name.')
+    parser.add_argument(
+        '--input_names',
+        nargs='+',
+        help='The inputs of pruned model.')
     parser.add_argument(
         '--output_names',
         required=True,
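
With the optional --input_names flag, the pruned sub-graph can now start from chosen tensors as well as end at them. A hypothetical invocation for reference (the directories, file names, and tensor names are placeholders, and --save_dir is assumed to be defined in a part of parse_arguments() not shown in this diff):

    python tools/paddle/prune_paddle_model.py \
        --model_dir ./inference_model \
        --model_filename model.pdmodel \
        --params_filename model.pdiparams \
        --input_names x \
        --output_names relu_2.tmp_0 \
        --save_dir ./pruned_model
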
@@ -94,17 +100,25 @@
         sys.exit(-1)
 
     paddle.enable_static()
-    paddle.static.io.prepend_feed_ops = new_prepend_feed_ops
     print("Start to load paddle model...")
     exe = static.Executor(paddle.CPUPlace())
     [program, feed_target_names, fetch_targets] = static.io.load_inference_model(
         args.model_dir,
         exe,
         model_filename=args.model_filename,
         params_filename=args.params_filename)
-    insert_fetch(program, args.output_names)
-    feed_vars = [program.global_block().var(name) for name in feed_target_names]
-    fetch_vars = [program.global_block().var(out_name) for out_name in args.output_names]
+
+    if args.input_names is not None:
+        insert_by_op_type(program, args.input_names, 'feed')
+        feed_vars = [program.global_block().var(name) for name in args.input_names]
+    else:
+        feed_vars = [program.global_block().var(name) for name in feed_target_names]
+
+    if args.output_names is not None:
+        insert_by_op_type(program, args.output_names, 'fetch')
+        fetch_vars = [program.global_block().var(out_name) for out_name in args.output_names]
+    else:
+        fetch_vars = [out_var for out_var in fetch_targets]
 
     model_name = args.model_filename.split(".")[0]
     path_prefix = os.path.join(args.save_dir, model_name)
(The remaining lines of the file are collapsed in this view.)
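
Because the tail of the file is collapsed, the actual save call is not visible here. As an assumption only, a pruning script like this would typically finish by serializing the rewired program together with the selected feed/fetch vars, roughly:

    # Assumed ending, not quoted from the commit: persist the pruned program so
    # that path_prefix + '.pdmodel' / '.pdiparams' contain only the sub-graph
    # between feed_vars and fetch_vars.
    static.save_inference_model(
        path_prefix, feed_vars, fetch_vars, exe, program=program)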
