load model by checkpoint step
andabi committed Dec 4, 2017
1 parent b29f32a · commit c2d5200
Showing 4 changed files with 36 additions and 32 deletions.
convert.py (9 changes: 5 additions & 4 deletions)

@@ -16,7 +16,7 @@
 from hparam import Hparam


-def convert(logdir, writer, queue=False):
+def convert(logdir, step, writer, queue=False):
     hp = Hparam.get_global_hparam()

     # Load graph
@@ -33,12 +33,12 @@ def convert(logdir, writer, queue=False):
     with tf.Session(config=session_conf) as sess:
         # Load trained model
         sess.run(tf.global_variables_initializer())
-        model.load(sess, 'convert', logdir=logdir)
+        model.load(sess, 'convert', logdir=logdir, step=step)

         coord = tf.train.Coordinator()
         threads = tf.train.start_queue_runners(coord=coord)

-        epoch, gs = Model.get_epoch_and_global_step(logdir)
+        epoch, gs = Model.get_epoch_and_global_step(logdir, step=step)

         if queue:
             pred_log_specs, y_log_spec, ppgs = sess.run([model(), model.y_spec, model.ppgs])
@@ -93,6 +93,7 @@ def convert(logdir, writer, queue=False):
 def get_arguments():
     parser = argparse.ArgumentParser()
     parser.add_argument('case', type=str, help='experiment case name')
+    parser.add_argument('-step', type=int, help='checkpoint step to load', nargs='?', default=None)
     arguments = parser.parse_args()
     return arguments

@@ -108,7 +109,7 @@ def get_arguments():
     s = datetime.datetime.now()

     writer = tf.summary.FileWriter(logdir)
-    convert(logdir=logdir, writer=writer)
+    convert(logdir=logdir, step=args.step, writer=writer)
     writer.close()

     e = datetime.datetime.now()
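The net effect in convert.py: an optional `-step` flag flows from the CLI into `Model.load`, so conversion can target the checkpoint saved at a given training step, e.g. `python convert.py my_case -step 10000` (case name and step here are hypothetical); omitting the flag keeps the old behavior of loading the latest checkpoint. A minimal sketch of how this single-dash optional parses, mirroring the argparse lines above:

```python
import argparse

# Sketch of the parsing added in this commit. argparse accepts
# multi-character single-dash optionals such as '-step'; the dest
# defaults to 'step', so the value is read back as args.step.
parser = argparse.ArgumentParser()
parser.add_argument('case', type=str, help='experiment case name')
parser.add_argument('-step', type=int, help='checkpoint step to load', nargs='?', default=None)

args = parser.parse_args(['my_case', '-step', '10000'])  # hypothetical CLI input
print(args.case, args.step)  # my_case 10000

args = parser.parse_args(['my_case'])  # flag omitted
print(args.step)  # None -> fall back to the latest checkpoint
```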
hparams/default.yaml (4 changes: 2 additions & 2 deletions)

@@ -55,8 +55,8 @@ train2:

   # train
   batch_size: 32
-  lr: 0.0005
-  lr_cyclic_margin: 0.0004
+  lr: 0.0003
+  lr_cyclic_margin: 0.0002
   lr_cyclic_steps: 5000
   clip_value_max: 3.
   clip_value_min: -3.
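These names suggest a cyclic learning-rate schedule in which the rate oscillates around `lr` within `lr_cyclic_margin` over `lr_cyclic_steps` steps; under that reading, the commit lowers the schedule's ceiling from about 0.0009 to 0.0005. The schedule itself is not part of this diff, so the sketch below is only one plausible interpretation (a sinusoidal cycle), not the repository's confirmed formula:

```python
import math

def cyclic_lr(global_step, lr=0.0003, margin=0.0002, cycle_steps=5000):
    # Illustration only: oscillate within [lr - margin, lr + margin],
    # completing one full cycle every `cycle_steps` steps. The actual
    # schedule in the repo may use a different waveform.
    phase = 2.0 * math.pi * (global_step % cycle_steps) / cycle_steps
    return lr + margin * math.sin(phase)

print(cyclic_lr(0))     # 0.0003 (midpoint)
print(cyclic_lr(1250))  # 0.0005 (peak; was 0.0009 with the old values)
```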
models.py (54 changes: 28 additions & 26 deletions)

@@ -1,12 +1,12 @@
 # -*- coding: utf-8 -*-
 #!/usr/bin/env python

-import os
-
 import tensorflow as tf

 from data_load import get_batch_queue, load_vocab
 from modules import prenet, cbhg
+import glob, os
+from utils import split_path


 class Model:
@@ -125,52 +125,55 @@ def loss_net2(self):
         return loss

     @staticmethod
-    def load(sess, mode, logdir, logdir2=None):
-
-        def print_model_loaded(mode, logdir):
-            model_name = Model.get_model_name(logdir)
+    def load(sess, mode, logdir, logdir2=None, step=None):
+        def print_model_loaded(mode, logdir, step):
+            model_name = Model.get_model_name(logdir, step=step)
             print('Model loaded. mode: {}, model_name: {}'.format(mode, model_name))

         if mode in ['train1', 'test1']:
             var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'net/net1')
-            if Model._load_variables(sess, logdir, var_list=var_list):
-                print_model_loaded(mode, logdir)
+            if Model._load_variables(sess, logdir, var_list=var_list, step=step):
+                print_model_loaded(mode, logdir, step)

         elif mode == 'train2':
             var_list1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'net/net1')
-            if Model._load_variables(sess, logdir, var_list=var_list1):
-                print_model_loaded(mode, logdir)
+            if Model._load_variables(sess, logdir, var_list=var_list1, step=step):
+                print_model_loaded(mode, logdir, step)

             var_list2 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'net/net2')
-            if Model._load_variables(sess, logdir2, var_list=var_list2):
-                print_model_loaded(mode, logdir2)
+            if Model._load_variables(sess, logdir2, var_list=var_list2, step=step):
+                print_model_loaded(mode, logdir2, step)

         elif mode in ['test2', 'convert']:
-            if Model._load_variables(sess, logdir, var_list=None):  # Load all variables
-                print_model_loaded(mode, logdir)
+            if Model._load_variables(sess, logdir, var_list=None, step=step):  # Load all variables
+                print_model_loaded(mode, logdir, step)

     @staticmethod
-    def _load_variables(sess, logdir, var_list):
-        ckpt = tf.train.latest_checkpoint(logdir)
-        if ckpt:
+    def _load_variables(sess, logdir, var_list, step=None):
+        model_name = Model.get_model_name(logdir, step)
+        if model_name:
+            ckpt = os.path.join(logdir, model_name)
             tf.train.Saver(var_list=var_list).restore(sess, ckpt)
             return True
         else:
             return False

     @staticmethod
-    def get_model_name(logdir):
-        path = '{}/checkpoint'.format(logdir)
-        if os.path.exists(path):
-            ckpt_path = open(path, 'r').read().split('"')[1]
-            _, model_name = os.path.split(ckpt_path)
+    def get_model_name(logdir, step=None):
+        model_name = None
+        if step:
+            paths = glob.glob('{}/*step_{}.index*'.format(logdir, step))
+            if paths:
+                _, model_name, _ = split_path(paths[0])
         else:
-            model_name = None
+            ckpt = tf.train.latest_checkpoint(logdir)
+            if ckpt:
+                _, model_name = os.path.split(ckpt)
         return model_name

     @staticmethod
-    def get_epoch_and_global_step(logdir):
-        model_name = Model.get_model_name(logdir)
+    def get_epoch_and_global_step(logdir, step=None):
+        model_name = Model.get_model_name(logdir, step)
         if model_name:
             tokens = model_name.split('_')
             epoch, gs = int(tokens[1]), int(tokens[3])
@@ -180,7 +183,6 @@ def get_epoch_and_global_step(logdir):

     @staticmethod
     def all_model_names(logdir):
-        import glob, os
         path = '{}/*.meta'.format(logdir)
         model_names = map(lambda f: os.path.basename(f).replace('.meta', ''), glob.glob(path))
         return model_names
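In short, `get_model_name` now has two paths: given a `step`, it globs the log directory for a checkpoint whose `.index` file matches that step; otherwise it falls back to `tf.train.latest_checkpoint` as before (replacing the old manual parse of the `checkpoint` bookkeeping file). Note the guard is `if step:` rather than `if step is not None:`, so a step of 0 would also fall through to the latest checkpoint. Below is a TF-free sketch of the same lookup and of the name parsing; the `epoch_{e}_step_{gs}` naming is inferred from the `tokens[1]`/`tokens[3]` indexing, not stated in the diff:

```python
import glob
import os

def resolve_model_name(logdir, step=None):
    # Mirrors the new get_model_name(): TF checkpoints come as file
    # pairs like 'epoch_12_step_34000.index' / '...data-00000-of-00001',
    # so matching the .index sibling identifies the checkpoint prefix.
    if step:
        paths = glob.glob('{}/*step_{}.index*'.format(logdir, step))
        if paths:
            return os.path.splitext(os.path.basename(paths[0]))[0]
    return None  # caller falls back to tf.train.latest_checkpoint(logdir)

def epoch_and_global_step(model_name):
    # tokens[1]/tokens[3] imply names of the form 'epoch_{e}_step_{gs}'
    # (an inference from the parsing above, not stated in the diff).
    tokens = model_name.split('_')
    return int(tokens[1]), int(tokens[3])

print(epoch_and_global_step('epoch_12_step_34000'))  # (12, 34000)
```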
utils.py (1 change: 1 addition & 0 deletions)

@@ -2,6 +2,7 @@
 #!/usr/bin/env python

 import os
+import glob


 def split_path(path):
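`split_path` itself is not shown here, but the three-way unpacking at its call site in models.py (`_, model_name, _ = split_path(paths[0])`) implies it returns (directory, basename without extension, extension). A hedged sketch of such a helper:

```python
import os

def split_path(path):
    # Assumed behavior, inferred from the call site in models.py;
    # the real body lives in utils.py and is not part of this diff.
    parent, filename = os.path.split(path)
    basename, ext = os.path.splitext(filename)
    return parent, basename, ext

print(split_path('/logdir/epoch_12_step_34000.index'))
# ('/logdir', 'epoch_12_step_34000', '.index')
```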
