diff --git a/convert.py b/convert.py
index bdf747ff..e5034f51 100644
--- a/convert.py
+++ b/convert.py
@@ -10,15 +10,13 @@
 from models import Model
 import numpy as np
 from audio import spectrogram2wav, inv_preemphasis, db_to_amp
-from hparam import logdir_path
 import datetime
 import tensorflow as tf
-from hparam import Hparam
+from hparam import hparam as hp
 from utils import denormalize_0_1
 
 
 def convert(logdir, writer, queue=False, step=None):
-    hp = Hparam.get_global_hparam()
 
     # Load graph
     model = Model(mode="convert", batch_size=hp.convert.batch_size, hp=hp, queue=queue)
@@ -107,11 +105,10 @@ def get_arguments():
 
 if __name__ == '__main__':
     args = get_arguments()
-    case = args.case
-    logdir = '{}/{}/train2'.format(logdir_path, case)
-    Hparam(case).set_as_global_hparam()
+    hp.set_hparam_yaml(args.case)
+    logdir = '{}/train2'.format(hp.logdir)
 
-    print('case: {}, logdir: {}'.format(case, logdir))
+    print('case: {}, logdir: {}'.format(args.case, logdir))
 
     s = datetime.datetime.now()
 
diff --git a/data_load.py b/data_load.py
index efb5f314..939ad4c9 100644
--- a/data_load.py
+++ b/data_load.py
@@ -9,11 +9,10 @@ import tensorflow as tf
 
 from tensorflow.python.platform import tf_logging as logging
-from hparam import Hparam
 from audio import preemphasis, amp_to_db
 import numpy as np
 import librosa
-from hparam import data_path_base
+from hparam import hparam as hp
 from utils import normalize_0_1
 
 
 
@@ -32,7 +31,6 @@ def wav_random_crop(wav, sr, duration):
 
 
 def get_mfccs_and_phones(wav_file, sr, length, trim=False, random_crop=True):
-    hp = Hparam.get_global_hparam()
     '''This is applied in `train1` or `test1` phase.
     '''
 
@@ -84,7 +82,6 @@ def get_mfccs_and_spectrogram(wav_file, win_length, hop_length, trim=True, durat
     '''This is applied in `train2`, `test2` or `convert` phase.
     '''
-    hp = Hparam.get_global_hparam()
 
     # Load
     wav, _ = librosa.load(wav_file, sr=hp.default.sr)
 
@@ -106,7 +103,6 @@ def get_mfccs_and_spectrogram(wav_file, win_length, hop_length, trim=True, durat
 
 
 def _get_mfcc_log_spec_and_log_mel_spec(wav, preemphasis_coeff, n_fft, win_length, hop_length):
-    hp = Hparam.get_global_hparam()
 
     # Pre-emphasis
     y_preem = preemphasis(wav, coeff=preemphasis_coeff)
@@ -195,7 +191,6 @@ def get_mfccs_and_phones_queue(wav_file):
     extracts mfccs (inputs), and phones (target), then enqueue them again.
     This is applied in `train1` or `test1` phase.
     '''
-    hp = Hparam.get_global_hparam()
     mfccs, phns = get_mfccs_and_phones(wav_file, hp.default.sr,
                                        length=int(hp.default.duration / hp.default.frame_shift + 1))
     return mfccs, phns
@@ -207,7 +202,6 @@ def get_mfccs_and_spectrogram_queue(wav_file):
     extracts mfccs and spectrogram, then enqueue them again.
     This is applied in `train2` or `test2` phase.
     '''
-    hp = Hparam.get_global_hparam()
     mfccs, spec, mel = get_mfccs_and_spectrogram(wav_file, duration=hp.default.duration,
                                                  win_length=hp.default.win_length,
                                                  hop_length=hp.default.hop_length)
@@ -218,7 +212,6 @@ def get_batch_queue(mode, batch_size):
     '''Loads data and put them in mini batch queues.
     mode: A string. Either `train1` | `test1` | `train2` | `test2` | `convert`.
     '''
-    hp = Hparam.get_global_hparam()
 
     if mode not in ('train1', 'test1', 'train2', 'test2', 'convert'):
         raise Exception("invalid mode={}".format(mode))
@@ -273,7 +266,6 @@ def get_batch(mode, batch_size):
     '''Loads data.
     mode: A string. Either `train1` | `test1` | `train2` | `test2` | `convert`.
     '''
-    hp = Hparam.get_global_hparam()
 
     if mode not in ('train1', 'test1', 'train2', 'test2', 'convert'):
         raise Exception("invalid mode={}".format(mode))
@@ -297,7 +289,6 @@ def get_batch(mode, batch_size):
 
 # TODO generalize for all mode
 def get_wav_batch(mode, batch_size):
-    hp = Hparam.get_global_hparam()
 
     with tf.device('/cpu:0'):
         # Load data
@@ -401,26 +392,6 @@ def load_vocab():
 
 
 def load_data(mode):
-    '''Loads the list of sound files.
-    mode: A string. One of the phases below:
-      `train1`: TIMIT TRAIN waveform -> mfccs (inputs) -> PGGs -> phones (target) (ce loss)
-      `test1`: TIMIT TEST waveform -> mfccs (inputs) -> PGGs -> phones (target) (accuracy)
-      `train2`: ARCTIC SLT waveform -> mfccs -> PGGs (inputs) -> spectrogram (target)(l2 loss)
-      `test2`: ARCTIC SLT waveform -> mfccs -> PGGs (inputs) -> spectrogram (target)(accuracy)
-      `convert`: ARCTIC BDL waveform -> mfccs (inputs) -> PGGs -> spectrogram -> waveform (output)
-    '''
-    hp = Hparam.get_global_hparam()
-
-    if mode == "train1":
-        wav_files = glob.glob('{}/{}'.format(data_path_base, hp.train1.data_path))
-    elif mode == "test1":
-        wav_files = glob.glob('{}/{}'.format(data_path_base, hp.test1.data_path))
-    elif mode == "train2":
-        testset_size = hp.test2.batch_size * 4
-        wav_files = glob.glob('{}/{}'.format(data_path_base, hp.train2.data_path))[testset_size:]
-    elif mode == "test2":
-        testset_size = hp.test2.batch_size * 4
-        wav_files = glob.glob('{}/{}'.format(data_path_base, hp.train2.data_path))[:testset_size]
-    elif mode == "convert":
-        wav_files = glob.glob('{}/{}'.format(data_path_base, hp.convert.data_path))
+    wav_files = glob.glob(getattr(hp, mode).data_path)
+
     return wav_files
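The `load_data` rewrite above collapses the five-way if/elif chain into a single `getattr` dispatch: the mode string names an hparam section, and each section now carries a fully qualified `data_path` glob (see the YAML changes below). A standalone sketch of that dispatch, with a hypothetical stub standing in for the real hparam object:

```python
import glob

class Section(object):
    # Stand-in for one hparam section; the real values come from YAML.
    def __init__(self, data_path):
        self.data_path = data_path

class StubHparam(object):
    train1 = Section('/tmp/timit/TRAIN/*/*/*.wav')  # illustrative paths only
    convert = Section('/tmp/arctic/bdl/*.wav')

hp = StubHparam()

def load_data(mode):
    # getattr(hp, mode) resolves e.g. mode='train1' to hp.train1, so adding
    # a new mode only requires a matching config section, not a new branch.
    return glob.glob(getattr(hp, mode).data_path)

print(load_data('train1'))  # [] unless such files actually exist
```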
diff --git a/datasets/arctic/.DS_Store b/datasets/arctic/.DS_Store
index 5a93c133..afbd1241 100644
Binary files a/datasets/arctic/.DS_Store and b/datasets/arctic/.DS_Store differ
diff --git a/eval1.py b/eval1.py
index 7ff5b142..a47da8d5 100644
--- a/eval1.py
+++ b/eval1.py
@@ -6,15 +6,12 @@
 import argparse
 
 import tensorflow as tf
-from hparam import Hparam
 
 from data_load import get_batch
-from hparam import logdir_path
+from hparam import hparam as hp
 from models import Model
 
 
 def eval(logdir, writer, queue=False):
-    hp = Hparam.get_global_hparam()
-
     # Load graph
     model = Model(mode="test1", batch_size=hp.test1.batch_size, hp=hp, queue=queue)
@@ -70,10 +67,8 @@ def get_arguments():
 
 if __name__ == '__main__':
     args = get_arguments()
-    case = args.case
-    logdir = '{}/{}/train1'.format(logdir_path, case)
-    Hparam(case).set_as_global_hparam()
-
+    hp.set_hparam_yaml(args.case)
+    logdir = '{}/train1'.format(hp.logdir)
     writer = tf.summary.FileWriter(logdir)
     eval(logdir=logdir, writer=writer)
     writer.close()
diff --git a/eval2.py b/eval2.py
index bab4f716..48fdae98 100644
--- a/eval2.py
+++ b/eval2.py
@@ -7,12 +7,10 @@
 from data_load import get_batch
 from models import Model
 import argparse
-from hparam import logdir_path
-from hparam import Hparam
+from hparam import hparam as hp
 
 
 def eval(logdir, writer, queue=True):
-    hp = Hparam.get_global_hparam()
 
     # Load graph
     model = Model(mode="test2", batch_size=hp.test2.batch_size, hp=hp, queue=queue)
@@ -63,9 +61,8 @@ def get_arguments():
 
 if __name__ == '__main__':
     args = get_arguments()
-    case = args.case
-    logdir = '{}/{}/train2'.format(logdir_path, case)
-    Hparam(case).set_as_global_hparam()
+    hp.set_hparam_yaml(args.case)
+    logdir = '{}/train2'.format(hp.logdir)
 
     writer = tf.summary.FileWriter(logdir)
     eval(logdir=logdir, writer=writer)
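Both eval scripts now follow the same two-line entry-point pattern (`hp.set_hparam_yaml(args.case)` followed by formatting `hp.logdir`), and the per-function `Hparam.get_global_hparam()` calls disappear because `hp` is a module-level singleton imported once at the top. A minimal standalone sketch of that pattern, with illustrative names rather than the repo's:

```python
class _Config(dict):
    # Same trick as Dotdict in hparam.py: attribute access backed by a dict.
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__

config = _Config()  # module-level singleton, analogous to `hparam` in hparam.py

def set_up():
    # An entry point populates the shared instance once,
    # analogous to hp.set_hparam_yaml(case).
    config.sr = 16000

def consumer():
    # Library code just reads attributes off the imported singleton;
    # no get-the-global boilerplate inside every function.
    return config.sr

set_up()
assert consumer() == 16000
```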
diff --git a/hparam.py b/hparam.py
index 310e0637..05879daa 100644
--- a/hparam.py
+++ b/hparam.py
@@ -2,17 +2,6 @@
 #!/usr/bin/env python
 
 import yaml
-# import pprint
-
-# path
-## local
-data_path_base = './datasets'
-# logdir_path = './logdir'
-
-## remote
-# data_path_base = '/data/private/vc/datasets'
-logdir_path = '/data/private/vc/logdir'
-
 
 
 def load_hparam(filename):
@@ -47,37 +36,35 @@ class Dotdict(dict):
     __setattr__ = dict.__setitem__
     __delattr__ = dict.__delitem__
 
-    def __init__(self, dct):
+    def __init__(self, dct=None):
+        dct = dict() if not dct else dct
         for key, value in dct.items():
             if hasattr(value, 'keys'):
                 value = Dotdict(value)
             self[key] = value
 
 
-class Hparam():
-    default_hparam_file = 'hparams/default.yaml'
-    user_hparams_file = 'hparams/hparams.yaml'
-    global_hparam = None
+class Hparam(Dotdict):
+
+    def __init__(self):
+        super(Dotdict, self).__init__()
 
-    def __init__(self, case):
-        default_hp = load_hparam(Hparam.default_hparam_file)
-        user_hp = load_hparam(Hparam.user_hparams_file)
-        if case in user_hp:
-            hp = merge_dict(user_hp[case], default_hp)
-        else:
-            hp = default_hp
-        self.hparam = Dotdict(hp)
+    __getattr__ = Dotdict.__getitem__
+    __setattr__ = Dotdict.__setitem__
+    __delattr__ = Dotdict.__delitem__
 
-    def __call__(self):
-        return self.hparam
+    def set_hparam_yaml(self, case, default_file='hparams/default.yaml', user_file='hparams/hparams.yaml'):
+        default_hp = load_hparam(default_file)
+        user_hp = load_hparam(user_file)
+        hp_dict = Dotdict(merge_dict(user_hp[case], default_hp) if case in user_hp else default_hp)
+        for k, v in hp_dict.items():
+            setattr(self, k, v)
+        self._auto_setting(case)
 
-    def set_as_global_hparam(self):
-        Hparam.global_hparam = self.hparam
-        return Hparam.global_hparam
+    def _auto_setting(self, case):
+        setattr(self, 'case', case)
 
-    @staticmethod
-    def get_global_hparam():
-        return Hparam.global_hparam
+        # logdir for a case is automatically set to [logdir_path]/[case]
+        setattr(self, 'logdir', '{}/{}'.format(hparam.logdir_path, case))
 
-# pp = pprint.PrettyPrinter(indent=4)
-# pp.pprint(hparam)
\ No newline at end of file
+hparam = Hparam()
\ No newline at end of file
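With `hparam = Hparam()` exported at module level, configuration happens once per process and every importer sees the populated values. Assuming the file layout above (`hparams/default.yaml` plus an optional per-case section in `hparams/hparams.yaml`), usage looks roughly like this (`my_case` is a made-up case name):

```python
from hparam import hparam as hp

# Merges the 'my_case' section of hparams/hparams.yaml over default.yaml;
# falls back to the defaults alone if the case is absent.
hp.set_hparam_yaml('my_case')

print(hp.case)               # 'my_case', set by _auto_setting
print(hp.logdir)             # '<logdir_path>/my_case', also set by _auto_setting
print(hp.train2.batch_size)  # nested YAML keys resolve through Dotdict attributes
```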
diff --git a/hparams/default.yaml b/hparams/default.yaml
index 6918bc0f..8763d1b9 100644
--- a/hparams/default.yaml
+++ b/hparams/default.yaml
@@ -25,7 +25,9 @@ default:
 
   # train
   batch_size: 32
----
+
+logdir_path: '/data/private/vc/logdir'
+
 train1:
   # path
   data_path: 'timit/TIMIT/TRAIN/*/*/*.wav'
@@ -46,7 +48,7 @@ train1:
 ---
 train2:
   # path
-  data_path: 'arctic/slt/*.wav'
+  data_path: '/data/private/vc/datasets/arctic/slt/*.wav'
 
   # model
   hidden_units: 512 # alias: E
@@ -69,18 +71,21 @@ train2:
 ---
 test1:
   # path
-  data_path: 'timit/TIMIT/TEST/*/*/*.wav'
+  data_path: '/data/private/vc/datasets/timit/TIMIT/TEST/*/*/*.wav'
 
   # test
   batch_size: 32
 ---
 test2:
+  # path
+  data_path: '/data/private/vc/datasets/arctic/slt/*.wav'
+
   # test
   batch_size: 32
 ---
 convert:
   # path
-  data_path: 'arctic/bdl/*.wav'
+  data_path: '/data/private/vc/datasets/arctic/bdl/*.wav'
 
   # convert
   one_full_wav: False
diff --git a/train1.py b/train1.py
index 0e7d0ef6..169a3cba 100644
--- a/train1.py
+++ b/train1.py
@@ -2,7 +2,6 @@
 # /usr/bin/python2
 
 from __future__ import print_function
-from hparam import logdir_path
 from tqdm import tqdm
 
 from modules import *
@@ -10,11 +9,10 @@
 import eval1
 from data_load import get_batch
 import argparse
-from hparam import Hparam
+from hparam import hparam as hp
 
 
 def train(logdir, queue=True):
-    hp = Hparam.get_global_hparam()
 
     model = Model(mode="train1", batch_size=hp.train1.batch_size, hp=hp, queue=queue)
 
@@ -88,9 +86,8 @@ def get_arguments():
 
 if __name__ == '__main__':
     args = get_arguments()
-    case = args.case
-    logdir = '{}/{}/train1'.format(logdir_path, case)
-    Hparam(case).set_as_global_hparam()
+    hp.set_hparam_yaml(args.case)
+    logdir = '{}/train1'.format(hp.logdir)
 
     train(logdir=logdir)
     print("Done")
\ No newline at end of file
diff --git a/train2.py b/train2.py
index 50d853be..3e745ae2 100644
--- a/train2.py
+++ b/train2.py
@@ -3,7 +3,6 @@
 
 from __future__ import print_function
 
-from hparam import logdir_path
 from tqdm import tqdm
 
 import tensorflow as tf
@@ -11,14 +10,13 @@
 import convert, eval2
 from data_load import get_batch
 import argparse
-from hparam import Hparam
+from hparam import hparam as hp
 import math
 import os
 
 from utils import remove_all_files
 
 
 def train(logdir1, logdir2, queue=True):
-    hp = Hparam.get_global_hparam()
 
     model = Model(mode="train2", batch_size=hp.train2.batch_size, hp=hp, queue=queue)
@@ -116,10 +114,9 @@ def get_arguments():
 
 if __name__ == '__main__':
     args = get_arguments()
-    case1, case2 = args.case1, args.case2
-    logdir1 = '{}/{}/train1'.format(logdir_path, case1)
-    logdir2 = '{}/{}/train2'.format(logdir_path, case2)
-    hp = Hparam(case2).set_as_global_hparam()
+    hp.set_hparam_yaml(args.case2)
+    logdir1 = '{}/{}/train1'.format(hp.logdir_path, args.case1)
+    logdir2 = '{}/train2'.format(hp.logdir)
 
     if args.r:
         ckpt = '{}/checkpoint'.format(os.path.join(logdir2))
@@ -128,7 +125,7 @@ def get_arguments():
             remove_all_files(os.path.join(hp.logdir, 'events.out'))
             remove_all_files(os.path.join(hp.logdir, 'epoch_'))
 
-    print('case1: {}, case2: {}, logdir1: {}, logdir2: {}'.format(case1, case2, logdir1, logdir2))
+    print('case1: {}, case2: {}, logdir1: {}, logdir2: {}'.format(args.case1, args.case2, logdir1, logdir2))
 
     train(logdir1=logdir1, logdir2=logdir2)
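For reference, the case-override semantics the entry points rely on: `set_hparam_yaml` merges the named case's section over the defaults via `merge_dict`. That function is unchanged and not shown in this diff, so the sketch below is an assumption about its behavior (recursive merge, user values win), not a copy of the repo's code:

```python
def merge_dict(user, default):
    # Assumed semantics of hparam.merge_dict: fill in missing keys from
    # `default`, recursing into nested dicts; values already in `user` win.
    if isinstance(user, dict) and isinstance(default, dict):
        for k, v in default.items():
            if k not in user:
                user[k] = v
            else:
                user[k] = merge_dict(user[k], v)
    return user

default = {'train2': {'batch_size': 32, 'lr': 0.0003}}
user = {'train2': {'batch_size': 16}}
assert merge_dict(user, default) == {'train2': {'batch_size': 16, 'lr': 0.0003}}
```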