
Commit

update hparam
andabi committed Feb 22, 2018
1 parent e978364 commit 3b891b0
Showing 9 changed files with 51 additions and 105 deletions.
11 changes: 4 additions & 7 deletions convert.py
@@ -10,15 +10,13 @@
 from models import Model
 import numpy as np
 from audio import spectrogram2wav, inv_preemphasis, db_to_amp
-from hparam import logdir_path
 import datetime
 import tensorflow as tf
-from hparam import Hparam
+from hparam import hparam as hp
 from utils import denormalize_0_1
 
 
 def convert(logdir, writer, queue=False, step=None):
-    hp = Hparam.get_global_hparam()
 
     # Load graph
     model = Model(mode="convert", batch_size=hp.convert.batch_size, hp=hp, queue=queue)
@@ -107,11 +105,10 @@ def get_arguments():
 
 if __name__ == '__main__':
     args = get_arguments()
-    case = args.case
-    logdir = '{}/{}/train2'.format(logdir_path, case)
-    Hparam(case).set_as_global_hparam()
+    hp.set_hparam_yaml(args.case)
+    logdir = '{}/train2'.format(hp.logdir)
 
-    print('case: {}, logdir: {}'.format(case, logdir))
+    print('case: {}, logdir: {}'.format(args.case, logdir))
 
     s = datetime.datetime.now()
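
Note on the new entry-point pattern above: the per-script `logdir_path` constant is gone, and every script boots by loading the shared `hparam` singleton. A minimal sketch of the pattern, grounded in this diff (the case name 'my_case' is hypothetical):

# Sketch of the new boot sequence; 'my_case' is a hypothetical case name.
from hparam import hparam as hp

hp.set_hparam_yaml('my_case')            # load hparams/default.yaml, overlay hparams/hparams.yaml['my_case']
logdir = '{}/train2'.format(hp.logdir)   # hp.logdir is derived as '<logdir_path>/my_case'
batch_size = hp.convert.batch_size       # nested YAML sections resolve via attribute access
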
35 changes: 3 additions & 32 deletions data_load.py
@@ -9,11 +9,10 @@
 import tensorflow as tf
 from tensorflow.python.platform import tf_logging as logging
 
-from hparam import Hparam
 from audio import preemphasis, amp_to_db
 import numpy as np
 import librosa
-from hparam import data_path_base
+from hparam import hparam as hp
 from utils import normalize_0_1
 
 
@@ -32,7 +31,6 @@ def wav_random_crop(wav, sr, duration):
 
 
 def get_mfccs_and_phones(wav_file, sr, length, trim=False, random_crop=True):
-    hp = Hparam.get_global_hparam()
 
     '''This is applied in `train1` or `test1` phase.
     '''
@@ -84,7 +82,6 @@ def get_mfccs_and_spectrogram(wav_file, win_length, hop_length, trim=True, durat
     '''This is applied in `train2`, `test2` or `convert` phase.
     '''
 
-    hp = Hparam.get_global_hparam()
 
     # Load
     wav, _ = librosa.load(wav_file, sr=hp.default.sr)
@@ -106,7 +103,6 @@ def get_mfccs_and_spectrogram(wav_file, win_length, hop_length, trim=True, durat
 
 def _get_mfcc_log_spec_and_log_mel_spec(wav, preemphasis_coeff, n_fft, win_length, hop_length):
 
-    hp = Hparam.get_global_hparam()
 
     # Pre-emphasis
     y_preem = preemphasis(wav, coeff=preemphasis_coeff)
@@ -195,7 +191,6 @@ def get_mfccs_and_phones_queue(wav_file):
     extracts mfccs (inputs), and phones (target), then enqueue them again.
     This is applied in `train1` or `test1` phase.
     '''
-    hp = Hparam.get_global_hparam()
 
     mfccs, phns = get_mfccs_and_phones(wav_file, hp.default.sr, length=int(hp.default.duration / hp.default.frame_shift + 1))
     return mfccs, phns
@@ -207,7 +202,6 @@ def get_mfccs_and_spectrogram_queue(wav_file):
     extracts mfccs and spectrogram, then enqueue them again.
     This is applied in `train2` or `test2` phase.
     '''
-    hp = Hparam.get_global_hparam()
 
     mfccs, spec, mel = get_mfccs_and_spectrogram(wav_file, duration=hp.default.duration,
                                                  win_length=hp.default.win_length, hop_length=hp.default.hop_length)
@@ -218,7 +212,6 @@ def get_batch_queue(mode, batch_size):
     '''Loads data and put them in mini batch queues.
     mode: A string. Either `train1` | `test1` | `train2` | `test2` | `convert`.
     '''
-    hp = Hparam.get_global_hparam()
 
     if mode not in ('train1', 'test1', 'train2', 'test2', 'convert'):
         raise Exception("invalid mode={}".format(mode))
@@ -273,7 +266,6 @@ def get_batch(mode, batch_size):
     '''Loads data.
     mode: A string. Either `train1` | `test1` | `train2` | `test2` | `convert`.
     '''
-    hp = Hparam.get_global_hparam()
 
     if mode not in ('train1', 'test1', 'train2', 'test2', 'convert'):
         raise Exception("invalid mode={}".format(mode))
@@ -297,7 +289,6 @@
 
 # TODO generalize for all mode
 def get_wav_batch(mode, batch_size):
-    hp = Hparam.get_global_hparam()
 
     with tf.device('/cpu:0'):
         # Load data
@@ -401,26 +392,6 @@ def load_vocab():
 
 
 def load_data(mode):
-    '''Loads the list of sound files.
-    mode: A string. One of the phases below:
-      `train1`: TIMIT TRAIN waveform -> mfccs (inputs) -> PGGs -> phones (target) (ce loss)
-      `test1`: TIMIT TEST waveform -> mfccs (inputs) -> PGGs -> phones (target) (accuracy)
-      `train2`: ARCTIC SLT waveform -> mfccs -> PGGs (inputs) -> spectrogram (target)(l2 loss)
-      `test2`: ARCTIC SLT waveform -> mfccs -> PGGs (inputs) -> spectrogram (target)(accuracy)
-      `convert`: ARCTIC BDL waveform -> mfccs (inputs) -> PGGs -> spectrogram -> waveform (output)
-    '''
-    hp = Hparam.get_global_hparam()
-
-    if mode == "train1":
-        wav_files = glob.glob('{}/{}'.format(data_path_base, hp.train1.data_path))
-    elif mode == "test1":
-        wav_files = glob.glob('{}/{}'.format(data_path_base, hp.test1.data_path))
-    elif mode == "train2":
-        testset_size = hp.test2.batch_size * 4
-        wav_files = glob.glob('{}/{}'.format(data_path_base, hp.train2.data_path))[testset_size:]
-    elif mode == "test2":
-        testset_size = hp.test2.batch_size * 4
-        wav_files = glob.glob('{}/{}'.format(data_path_base, hp.train2.data_path))[:testset_size]
-    elif mode == "convert":
-        wav_files = glob.glob('{}/{}'.format(data_path_base, hp.convert.data_path))
+    wav_files = glob.glob(getattr(hp, mode).data_path)
 
     return wav_files
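
The per-mode if/elif chain in `load_data` collapses to a single `getattr` lookup because each mode's YAML section now carries its own absolute `data_path` glob. Note this also drops the old behavior of carving a `testset_size`-sized test split out of the train2 file list; train2 and test2 now glob their paths independently. A usage sketch of the collapsed lookup, with 'train1' as an example mode:

# Equivalent lookup, spelled out; assumes every mode section defines 'data_path'.
import glob
from hparam import hparam as hp

section = getattr(hp, 'train1')           # the Dotdict parsed from the 'train1:' YAML section
wav_files = glob.glob(section.data_path)  # same as load_data('train1')
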
Binary file modified datasets/arctic/.DS_Store
11 changes: 3 additions & 8 deletions eval1.py
@@ -6,15 +6,12 @@
 import argparse
 
 import tensorflow as tf
-from hparam import Hparam
 from data_load import get_batch
-from hparam import logdir_path
+from hparam import hparam as hp
 from models import Model
 
 
 def eval(logdir, writer, queue=False):
-    hp = Hparam.get_global_hparam()
-
     # Load graph
     model = Model(mode="test1", batch_size=hp.test1.batch_size, hp=hp, queue=queue)
 
@@ -70,10 +67,8 @@ def get_arguments():
 
 if __name__ == '__main__':
     args = get_arguments()
-    case = args.case
-    logdir = '{}/{}/train1'.format(logdir_path, case)
-    Hparam(case).set_as_global_hparam()
-
+    hp.set_hparam_yaml(args.case)
+    logdir = '{}/train1'.format(hp.logdir)
     writer = tf.summary.FileWriter(logdir)
     eval(logdir=logdir, writer=writer)
     writer.close()
9 changes: 3 additions & 6 deletions eval2.py
@@ -7,12 +7,10 @@
 from data_load import get_batch
 from models import Model
 import argparse
-from hparam import logdir_path
-from hparam import Hparam
+from hparam import hparam as hp
 
 
 def eval(logdir, writer, queue=True):
-    hp = Hparam.get_global_hparam()
 
     # Load graph
     model = Model(mode="test2", batch_size=hp.test2.batch_size, hp=hp, queue=queue)
@@ -63,9 +61,8 @@ def get_arguments():
 
 if __name__ == '__main__':
     args = get_arguments()
-    case = args.case
-    logdir = '{}/{}/train2'.format(logdir_path, case)
-    Hparam(case).set_as_global_hparam()
+    hp.set_hparam_yaml(args.case)
+    logdir = '{}/train2'.format(hp.logdir)
 
     writer = tf.summary.FileWriter(logdir)
     eval(logdir=logdir, writer=writer)
55 changes: 21 additions & 34 deletions hparam.py
@@ -2,17 +2,6 @@
 #!/usr/bin/env python
 
 import yaml
-# import pprint
-
-# path
-## local
-data_path_base = './datasets'
-# logdir_path = './logdir'
-
-## remote
-# data_path_base = '/data/private/vc/datasets'
-logdir_path = '/data/private/vc/logdir'
 
 
 def load_hparam(filename):
@@ -47,37 +36,35 @@ class Dotdict(dict):
     __setattr__ = dict.__setitem__
     __delattr__ = dict.__delitem__
 
-    def __init__(self, dct):
+    def __init__(self, dct=None):
+        dct = dict() if not dct else dct
         for key, value in dct.items():
             if hasattr(value, 'keys'):
                 value = Dotdict(value)
             self[key] = value
 
 
-class Hparam():
-    default_hparam_file = 'hparams/default.yaml'
-    user_hparams_file = 'hparams/hparams.yaml'
-    global_hparam = None
+class Hparam(Dotdict):
+
+    def __init__(self):
+        super(Dotdict, self).__init__()
 
-    def __init__(self, case):
-        default_hp = load_hparam(Hparam.default_hparam_file)
-        user_hp = load_hparam(Hparam.user_hparams_file)
-        if case in user_hp:
-            hp = merge_dict(user_hp[case], default_hp)
-        else:
-            hp = default_hp
-        self.hparam = Dotdict(hp)
+    __getattr__ = Dotdict.__getitem__
+    __setattr__ = Dotdict.__setitem__
+    __delattr__ = Dotdict.__delitem__
 
-    def __call__(self):
-        return self.hparam
+    def set_hparam_yaml(self, case, default_file='hparams/default.yaml', user_file='hparams/hparams.yaml'):
+        default_hp = load_hparam(default_file)
+        user_hp = load_hparam(user_file)
+        hp_dict = Dotdict(merge_dict(user_hp[case], default_hp) if case in user_hp else default_hp)
+        for k, v in hp_dict.items():
+            setattr(self, k, v)
+        self._auto_setting(case)
 
-    def set_as_global_hparam(self):
-        Hparam.global_hparam = self.hparam
-        return Hparam.global_hparam
+    def _auto_setting(self, case):
+        setattr(self, 'case', case)
 
-    @staticmethod
-    def get_global_hparam():
-        return Hparam.global_hparam
+        # logdir for a case is automatically set to [logdir_path]/[case]
+        setattr(self, 'logdir', '{}/{}'.format(hparam.logdir_path, case))
 
-# pp = pprint.PrettyPrinter(indent=4)
-# pp.pprint(hparam)
+
+hparam = Hparam()
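
This is the core of the commit: `Hparam` is now itself a `Dotdict`, so the loaded YAML tree becomes attribute-accessible, and the module-level `hparam` instance replaces the old `set_as_global_hparam`/`get_global_hparam` plumbing. A small sketch of the resulting behavior, with hypothetical case name and keys:

# Sketch only; 'my_case' and the nested keys are hypothetical.
from hparam import Dotdict, hparam

d = Dotdict({'train1': {'batch_size': 32}})
assert d.train1.batch_size == 32   # nested mappings are wrapped recursively

hparam.set_hparam_yaml('my_case')  # merged with hparams.yaml['my_case'] if that case exists
print(hparam.case)                 # 'my_case', set by _auto_setting
print(hparam.logdir)               # '<logdir_path>/my_case'
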
13 changes: 9 additions & 4 deletions hparams/default.yaml
@@ -25,7 +25,9 @@ default:
 
   # train
   batch_size: 32
----
+
+logdir_path: '/data/private/vc/logdir'
+
 train1:
   # path
   data_path: 'timit/TIMIT/TRAIN/*/*/*.wav'
@@ -46,7 +48,7 @@ train1:
 ---
 train2:
   # path
-  data_path: 'arctic/slt/*.wav'
+  data_path: '/data/private/vc/datasets/arctic/slt/*.wav'
 
   # model
   hidden_units: 512 # alias: E
@@ -69,18 +71,21 @@ train2:
 ---
 test1:
   # path
-  data_path: 'timit/TIMIT/TEST/*/*/*.wav'
+  data_path: '/data/private/vc/datasets/timit/TIMIT/TEST/*/*/*.wav'
 
   # test
   batch_size: 32
 ---
 test2:
+  # path
+  data_path: '/data/private/vc/datasets/arctic/slt/*.wav'
+
   # test
   batch_size: 32
 ---
 convert:
   # path
-  data_path: 'arctic/bdl/*.wav'
+  data_path: '/data/private/vc/datasets/arctic/bdl/*.wav'
 
   # convert
   one_full_wav: False
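
The file keeps YAML's `---` document separators between the remaining sections, so `load_hparam` presumably reads it with something like `yaml.load_all` and folds the documents into a single dict. Its body is outside this diff, so the following is only a plausible sketch, not the repository's actual implementation:

# Hypothetical reading of load_hparam; the real body is not shown in this commit.
import yaml

def load_hparam(filename):
    with open(filename) as f:
        merged = {}
        for doc in yaml.load_all(f):   # one dict per '---'-separated document
            if doc:
                merged.update(doc)
        return merged
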
9 changes: 3 additions & 6 deletions train1.py
@@ -2,19 +2,17 @@
 # /usr/bin/python2
 
 from __future__ import print_function
-from hparam import logdir_path
 from tqdm import tqdm
 
 from modules import *
 from models import Model
 import eval1
 from data_load import get_batch
 import argparse
-from hparam import Hparam
+from hparam import hparam as hp
 
 
 def train(logdir, queue=True):
-    hp = Hparam.get_global_hparam()
 
     model = Model(mode="train1", batch_size=hp.train1.batch_size, hp=hp, queue=queue)
 
@@ -88,9 +86,8 @@ def get_arguments():
 
 if __name__ == '__main__':
     args = get_arguments()
-    case = args.case
-    logdir = '{}/{}/train1'.format(logdir_path, case)
-    Hparam(case).set_as_global_hparam()
+    hp.set_hparam_yaml(args.case)
+    logdir = '{}/train1'.format(hp.logdir)
 
     train(logdir=logdir)
     print("Done")
13 changes: 5 additions & 8 deletions train2.py
@@ -3,22 +3,20 @@
 
 from __future__ import print_function
 
-from hparam import logdir_path
 from tqdm import tqdm
 
 import tensorflow as tf
 from models import Model
 import convert, eval2
 from data_load import get_batch
 import argparse
-from hparam import Hparam
+from hparam import hparam as hp
 import math
 import os
 from utils import remove_all_files
 
 
 def train(logdir1, logdir2, queue=True):
-    hp = Hparam.get_global_hparam()
 
     model = Model(mode="train2", batch_size=hp.train2.batch_size, hp=hp, queue=queue)
 
@@ -116,10 +114,9 @@ def get_arguments():
 
 if __name__ == '__main__':
     args = get_arguments()
-    case1, case2 = args.case1, args.case2
-    logdir1 = '{}/{}/train1'.format(logdir_path, case1)
-    logdir2 = '{}/{}/train2'.format(logdir_path, case2)
-    hp = Hparam(case2).set_as_global_hparam()
+    hp.set_hparam_yaml(args.case2)
+    logdir1 = '{}/{}/train1'.format(hp.logdir_path, args.case1)
+    logdir2 = '{}/train2'.format(hp.logdir)
 
     if args.r:
         ckpt = '{}/checkpoint'.format(os.path.join(logdir2))
@@ -128,7 +125,7 @@ def get_arguments():
         remove_all_files(os.path.join(hp.logdir, 'events.out'))
         remove_all_files(os.path.join(hp.logdir, 'epoch_'))
 
-    print('case1: {}, case2: {}, logdir1: {}, logdir2: {}'.format(case1, case2, logdir1, logdir2))
+    print('case1: {}, case2: {}, logdir1: {}, logdir2: {}'.format(args.case1, args.case2, logdir1, logdir2))
 
     train(logdir1=logdir1, logdir2=logdir2)
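
Note how train2 mixes two cases: the net1 checkpoint dir is rebuilt from `hp.logdir_path` plus `args.case1`, while net2's comes from `hp.logdir`, which `_auto_setting` derived from `args.case2`. In sketch form, with hypothetical case names:

# Sketch of the two-case logdir layout in train2.
from hparam import hparam as hp

hp.set_hparam_yaml('net2_case')                               # hp.logdir -> '<logdir_path>/net2_case'
logdir1 = '{}/{}/train1'.format(hp.logdir_path, 'net1_case')  # net1 checkpoints, case joined manually
logdir2 = '{}/train2'.format(hp.logdir)                       # net2 checkpoints, case-derived
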
