apply tensorpack in train2
andabi committed Mar 7, 2018
1 parent 3184a62 commit 9fb6cbf
Showing 4 changed files with 177 additions and 126 deletions.
43 changes: 31 additions & 12 deletions data_load.py
@@ -14,13 +14,26 @@
 from utils import normalize_0_1
 
 
-def load_data(mode):
-    wav_files = glob.glob(getattr(hp, mode).data_path)
-
-    return wav_files
-
-
-class Net1DataFlow(RNGDataFlow):
+class Net1DataFlow(RNGDataFlow):
+
+    def __init__(self, data_path, batch_size):
+        self.batch_size = batch_size
+        self.wav_files = glob.glob(data_path)
+
+    def __call__(self, n_prefetch=1000, n_thread=1):
+        df = self
+        if self.batch_size > 1:
+            df = BatchData(df, self.batch_size)
+        df = PrefetchData(df, n_prefetch, n_thread)
+        return df
+
+    def get_data(self):
+        while True:
+            wav_file = random.choice(self.wav_files)
+            yield get_mfccs_and_phones(wav_file=wav_file)
+
+
+class Net2DataFlow(RNGDataFlow):
 
     def __init__(self, data_path, batch_size):
         self.batch_size = batch_size
@@ -36,7 +49,13 @@ def __call__(self, n_prefetch=1000, n_thread=1):
     def get_data(self):
         while True:
             wav_file = random.choice(self.wav_files)
-            yield get_mfccs_and_phones(wav_file=wav_file)
+            yield get_mfccs_and_spectrogram(wav_file)
+
+
+def load_data(mode):
+    wav_files = glob.glob(getattr(hp, mode).data_path)
+
+    return wav_files
 
 
 def wav_random_crop(wav, sr, duration):
@@ -102,7 +121,7 @@ def get_mfccs_and_phones(wav_file, trim=False, random_crop=True):
     return mfccs, phns
 
 
-def get_mfccs_and_spectrogram(wav_file, win_length, hop_length, trim=True, duration=None, random_crop=False):
+def get_mfccs_and_spectrogram(wav_file, trim=True, random_crop=False):
     '''This is applied in `train2`, `test2` or `convert` phase.
     '''
 
@@ -112,19 +131,19 @@ def get_mfccs_and_spectrogram(wav_file, trim=True, durat
 
     # Trim
     if trim:
-        wav, _ = librosa.effects.trim(wav, frame_length=win_length, hop_length=hop_length)
+        wav, _ = librosa.effects.trim(wav, frame_length=hp.default.win_length, hop_length=hp.default.hop_length)
 
     if random_crop:
-        wav = wav_random_crop(wav, hp.default.sr, duration)
+        wav = wav_random_crop(wav, hp.default.sr, hp.default.duration)
 
     # Padding or crop
-    if duration:
-        length = hp.default.sr * duration
-        wav = librosa.util.fix_length(wav, length)
+    length = hp.default.sr * hp.default.duration
+    wav = librosa.util.fix_length(wav, length)
 
-    return _get_mfcc_and_spec(wav, hp.default.preemphasis, hp.default.n_fft, win_length, hop_length)
+    return _get_mfcc_and_spec(wav, hp.default.preemphasis, hp.default.n_fft, hp.default.win_length, hp.default.hop_length)
 
 
 # TODO refactoring
 def _get_mfcc_and_spec(wav, preemphasis_coeff, n_fft, win_length, hop_length):
 
     # Pre-emphasis
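Editor's note: a minimal usage sketch, not part of commit 9fb6cbf, showing how the new Net2DataFlow would typically be iterated once wrapped by its __call__. The fields hp.train2.data_path / hp.train2.batch_size, the hparam import, and the (mfccs, spec, mel) tuple order are assumptions inferred from how load_data and Net2 use these values elsewhere in this diff.

# Usage sketch only -- not code from this commit.
from data_load import Net2DataFlow
from hparam import hparam as hp  # assumed project import; hp.set_hparam_yaml(case) must run first

df = Net2DataFlow(data_path=hp.train2.data_path, batch_size=hp.train2.batch_size)
df = df(n_prefetch=1000, n_thread=4)  # __call__ wraps the flow in BatchData + PrefetchData
df.reset_state()                      # required before iterating a tensorpack DataFlow directly

for mfccs, spec, mel in df.get_data():
    # assumed order: get_mfccs_and_spectrogram() yields (mfccs, spec, mel),
    # matching Net2._get_inputs (x_mfccs, y_spec, y_mel)
    print(mfccs.shape, spec.shape, mel.shape)
    break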
54 changes: 37 additions & 17 deletions models.py
@@ -22,9 +22,8 @@ def _get_inputs(self):
 
     def _build_graph(self, inputs):
         self.x_mfccs, self.y_ppgs = inputs
-        self.is_training = get_current_tower_context().is_training
         with tf.variable_scope('net1'):
-            self.ppgs, self.preds, self.logits = self.network(self.x_mfccs)
+            self.ppgs, self.preds, self.logits = self.network(self.x_mfccs, get_current_tower_context().is_training)
         self.cost = self.loss()
 
         # summaries
@@ -36,16 +35,16 @@ def _get_optimizer(self):
         return tf.train.AdamOptimizer(lr)
 
     @auto_reuse_variable_scope
-    def network(self, x_mfcc):
+    def network(self, x_mfcc, is_training):
         # Pre-net
         prenet_out = prenet(x_mfcc,
                             num_units=[hp.train1.hidden_units, hp.train1.hidden_units // 2],
                             dropout_rate=hp.train1.dropout_rate,
-                            is_training=self.is_training)  # (N, T, E/2)
+                            is_training=is_training)  # (N, T, E/2)
 
         # CBHG
         out = cbhg(prenet_out, hp.train1.num_banks, hp.train1.hidden_units // 2,
-                   hp.train1.num_highway_blocks, hp.train1.norm_type, self.is_training)
+                   hp.train1.num_highway_blocks, hp.train1.norm_type, is_training)
 
         # Final linear projection
         logits = tf.layers.dense(out, len(phns))  # (N, T, V)
@@ -69,52 +68,73 @@ def acc(self):
         acc = num_hits / num_targets
         return acc
 
 
 class Net2(ModelDesc):
 
     def __init__(self):
         self.net1 = Net1()
 
     def _get_inputs(self):
         return [InputDesc(tf.float32, (hp.train2.batch_size, None, hp.default.n_mfcc), 'x_mfccs'),
-                InputDesc(tf.float32, (hp.train2.batch_size, None, hp.default.n_mels), 'y_mel'),
-                InputDesc(tf.float32, (hp.train2.batch_size, None, hp.default.n_fft//2 + 1), 'y_spec')]
+                InputDesc(tf.float32, (hp.train2.batch_size, None, hp.default.n_fft // 2 + 1), 'y_spec'),
+                InputDesc(tf.float32, (hp.train2.batch_size, None, hp.default.n_mels), 'y_mel'),]
 
     def _build_graph(self, inputs):
-        self.x_mfcc, self.y_mel, self.y_spec = inputs
+        self.x_mfcc, self.y_spec, self.y_mel = inputs
+
+        is_training = get_current_tower_context().is_training
 
         # build net1
         with tf.variable_scope('net1'):
-            self.ppgs, _, _ = self.net1.network(self.x_mfcc)
+            self.ppgs, _, _ = self.net1.network(self.x_mfcc, is_training)
 
         # build net2
         with tf.variable_scope('net2'):
-            self.pred_spec_mu, self.pred_spec_phi = self.network(self.ppgs)
+            self.pred_spec_mu, self.pred_spec_phi = self.network(self.ppgs, is_training)
 
         self.cost = self.loss()
 
-    def _get_optimizer(self):
-        lr = tf.get_variable('learning_rate', initializer=hp.train2.lr, trainable=False)
-        return tf.train.AdamOptimizer(lr)
+        # summaries
+        tf.summary.scalar('net2/train/loss', self.cost)
+        # tf.summary.scalar('net2/train/lr', lr)
+        tf.summary.histogram('net2/train/mu', self.pred_spec_mu)
+        tf.summary.histogram('net2/train/phi', self.pred_spec_phi)
+        # TODO remove
+        tf.summary.scalar('net2/prob_min', self.prob_min)
+
+    # def _get_train_op(self):
+    #     lr = tf.get_variable('learning_rate', initializer=hp.train2.lr, trainable=False)
+    #     optimizer = tf.train.AdamOptimizer(learning_rate=lr)
+    #     with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
+    #         var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'net2')
+    #
+    #         # Gradient clipping to prevent loss explosion
+    #         gvs = optimizer.compute_gradients(loss_op, var_list=var_list)
+    #         gvs = [(tf.clip_by_value(grad, hp.train2.clip_value_min, hp.train2.clip_value_max), var) for grad, var in
+    #                gvs]
+    #         gvs = [(tf.clip_by_norm(grad, hp.train2.clip_norm), var) for grad, var in gvs]
+    #
+    #         return optimizer.apply_gradients(gvs)
 
     @auto_reuse_variable_scope
-    def network(self, ppgs):
+    def network(self, ppgs, is_training):
         # Pre-net
         prenet_out = prenet(ppgs,
                             num_units=[hp.train2.hidden_units, hp.train2.hidden_units // 2],
                             dropout_rate=hp.train2.dropout_rate,
-                            is_training=self.is_training)  # (N, T, E/2)
+                            is_training=is_training)  # (N, T, E/2)
 
         # CBHG1: mel-scale
         # pred_mel = cbhg(prenet_out, hp.train2.num_banks, hp.train2.hidden_units // 2,
-        #                 hp.train2.num_highway_blocks, hp.train2.norm_type, self.is_training,
+        #                 hp.train2.num_highway_blocks, hp.train2.norm_type, is_training,
         #                 scope="cbhg_mel")
         # pred_mel = tf.layers.dense(pred_mel, self.y_mel.shape[-1])  # (N, T, n_mels)
         pred_mel = prenet_out
 
         # CBHG2: linear-scale
         pred_spec = tf.layers.dense(pred_mel, hp.train2.hidden_units // 2)  # (N, T, n_mels)
         pred_spec = cbhg(pred_spec, hp.train2.num_banks, hp.train2.hidden_units // 2,
-                         hp.train2.num_highway_blocks, hp.train2.norm_type, self.is_training,
+                         hp.train2.num_highway_blocks, hp.train2.norm_type, is_training,
                          scope="cbhg_linear")
         pred_spec = tf.layers.dense(pred_spec, self.y_spec.shape[-1])  # (N, T, 1+hp.n_fft//2)
         pred_spec = tf.expand_dims(pred_spec, axis=-1)
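Editor's note on the is_training change above (illustration only, not code from this commit): tensorpack builds the same ModelDesc graph once per tower, and the training/inference flag lives on the tower context rather than on the model object, so passing it as an argument keeps Net1.network() reusable when Net2 rebuilds it under the 'net1' scope. A minimal sketch of the pattern, using tf.layers.dropout as a stand-in for the prenet/CBHG layers:

# Sketch of the is_training threading pattern -- not code from this commit.
import tensorflow as tf
from tensorpack.tfutils.tower import get_current_tower_context


def network(x, is_training):
    # dropout is only active when the trainer builds the training tower
    return tf.layers.dropout(x, rate=0.5, training=is_training)


def build_graph(x):
    ctx = get_current_tower_context()           # set by the tensorpack trainer
    is_training = bool(ctx and ctx.is_training)
    return network(x, is_training)              # flag is threaded, not stored on self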
10 changes: 8 additions & 2 deletions train1.py
@@ -19,7 +19,9 @@
 from models import Net1
 import tensorflow as tf
 
 
 def train(args, logdir):
+
     # model
     model = Net1()
@@ -63,12 +65,16 @@ def get_arguments():
     parser.add_argument('case', type=str, help='experiment case name')
     parser.add_argument('-ckpt', help='checkpoint to load model.')
     parser.add_argument('-gpu', help='comma separated list of GPU(s) to use.')
+    parser.add_argument('-r', action='store_true', help='start training from the beginning.')
     arguments = parser.parse_args()
     return arguments
 
 if __name__ == '__main__':
     args = get_arguments()
     hp.set_hparam_yaml(args.case)
-    train(args, logdir='{}/train1'.format(hp.logdir))
+    logdir_train1 = '{}/train1'.format(hp.logdir)
+
+    print('case: {}, logdir: {}'.format(args.case, logdir_train1))
+
+    train(args, logdir=logdir_train1)
 
     print("Done")