diff --git a/convert.py b/convert.py index 0a9209c4..f99034e7 100644 --- a/convert.py +++ b/convert.py @@ -16,7 +16,7 @@ from hparam import Hparam -def convert(logdir, writer, queue=False): +def convert(logdir, step, writer, queue=False): hp = Hparam.get_global_hparam() # Load graph @@ -33,12 +33,12 @@ def convert(logdir, writer, queue=False): with tf.Session(config=session_conf) as sess: # Load trained model sess.run(tf.global_variables_initializer()) - model.load(sess, 'convert', logdir=logdir) + model.load(sess, 'convert', logdir=logdir, step=step) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) - epoch, gs = Model.get_epoch_and_global_step(logdir) + epoch, gs = Model.get_epoch_and_global_step(logdir, step=step) if queue: pred_log_specs, y_log_spec, ppgs = sess.run([model(), model.y_spec, model.ppgs]) @@ -93,6 +93,7 @@ def convert(logdir, writer, queue=False): def get_arguments(): parser = argparse.ArgumentParser() parser.add_argument('case', type=str, help='experiment case name') + parser.add_argument('-step', type=int, help='checkpoint step to load', nargs='?', default=None) arguments = parser.parse_args() return arguments @@ -108,7 +109,7 @@ def get_arguments(): s = datetime.datetime.now() writer = tf.summary.FileWriter(logdir) - convert(logdir=logdir, writer=writer) + convert(logdir=logdir, step=args.step, writer=writer) writer.close() e = datetime.datetime.now() diff --git a/hparams/default.yaml b/hparams/default.yaml index 25fed975..4361a85e 100644 --- a/hparams/default.yaml +++ b/hparams/default.yaml @@ -55,8 +55,8 @@ train2: # train batch_size: 32 - lr: 0.0005 - lr_cyclic_margin: 0.0004 + lr: 0.0003 + lr_cyclic_margin: 0.0002 lr_cyclic_steps: 5000 clip_value_max: 3. clip_value_min: -3. diff --git a/models.py b/models.py index 58635f64..51e826e8 100644 --- a/models.py +++ b/models.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- #!/usr/bin/env python -import os - import tensorflow as tf from data_load import get_batch_queue, load_vocab from modules import prenet, cbhg +import glob, os +from utils import split_path class Model: @@ -125,52 +125,55 @@ def loss_net2(self): return loss @staticmethod - def load(sess, mode, logdir, logdir2=None): - - def print_model_loaded(mode, logdir): - model_name = Model.get_model_name(logdir) + def load(sess, mode, logdir, logdir2=None, step=None): + def print_model_loaded(mode, logdir, step): + model_name = Model.get_model_name(logdir, step=step) print('Model loaded. mode: {}, model_name: {}'.format(mode, model_name)) if mode in ['train1', 'test1']: var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'net/net1') - if Model._load_variables(sess, logdir, var_list=var_list): - print_model_loaded(mode, logdir) + if Model._load_variables(sess, logdir, var_list=var_list, step=step): + print_model_loaded(mode, logdir, step) elif mode == 'train2': var_list1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'net/net1') - if Model._load_variables(sess, logdir, var_list=var_list1): - print_model_loaded(mode, logdir) + if Model._load_variables(sess, logdir, var_list=var_list1, step=step): + print_model_loaded(mode, logdir, step) var_list2 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'net/net2') - if Model._load_variables(sess, logdir2, var_list=var_list2): - print_model_loaded(mode, logdir2) + if Model._load_variables(sess, logdir2, var_list=var_list2, step=step): + print_model_loaded(mode, logdir2, step) elif mode in ['test2', 'convert']: - if Model._load_variables(sess, logdir, var_list=None): # Load all variables - print_model_loaded(mode, logdir) + if Model._load_variables(sess, logdir, var_list=None, step=step): # Load all variables + print_model_loaded(mode, logdir, step) @staticmethod - def _load_variables(sess, logdir, var_list): - ckpt = tf.train.latest_checkpoint(logdir) - if ckpt: + def _load_variables(sess, logdir, var_list, step=None): + model_name = Model.get_model_name(logdir, step) + if model_name: + ckpt = os.path.join(logdir, model_name) tf.train.Saver(var_list=var_list).restore(sess, ckpt) return True else: return False @staticmethod - def get_model_name(logdir): - path = '{}/checkpoint'.format(logdir) - if os.path.exists(path): - ckpt_path = open(path, 'r').read().split('"')[1] - _, model_name = os.path.split(ckpt_path) + def get_model_name(logdir, step=None): + model_name = None + if step: + paths = glob.glob('{}/*step_{}.index*'.format(logdir, step)) + if paths: + _, model_name, _ = split_path(paths[0]) else: - model_name = None + ckpt = tf.train.latest_checkpoint(logdir) + if ckpt: + _, model_name = os.path.split(ckpt) return model_name @staticmethod - def get_epoch_and_global_step(logdir): - model_name = Model.get_model_name(logdir) + def get_epoch_and_global_step(logdir, step=None): + model_name = Model.get_model_name(logdir, step) if model_name: tokens = model_name.split('_') epoch, gs = int(tokens[1]), int(tokens[3]) @@ -180,7 +183,6 @@ def get_epoch_and_global_step(logdir): @staticmethod def all_model_names(logdir): - import glob, os path = '{}/*.meta'.format(logdir) model_names = map(lambda f: os.path.basename(f).replace('.meta', ''), glob.glob(path)) return model_names diff --git a/utils.py b/utils.py index 53111b7f..999b5bb8 100644 --- a/utils.py +++ b/utils.py @@ -2,6 +2,7 @@ #!/usr/bin/env python import os +import glob def split_path(path):