
Commit

apply cyclic learning rate
andabi committed Dec 2, 2017
1 parent 5d27586 commit d67b308
Showing 4 changed files with 58 additions and 54 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -4,4 +4,6 @@ logdir*
datasets
.DS_Store
outputs
__*
hparams/hparams.yaml
tools
3 changes: 3 additions & 0 deletions hparams/default.yaml
@@ -56,8 +56,11 @@ train2:
# train
batch_size: 32
lr: 0.0005
lr_cyclic_margin: 0.0004
lr_cyclic_steps: 5000
clip_value_max: 3.
clip_value_min: -3.
clip_norm: 100
num_epochs: 10000
save_per_epoch: 50
---
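
A quick sanity check on the schedule these new values drive (see the get_cyclic_lr helper added to train2.py below): the learning rate oscillates sinusoidally around lr with amplitude lr_cyclic_margin and a period of lr_cyclic_steps global steps. A minimal standalone sketch, assuming the default values above (the function name and defaults here are inlined for illustration):

import math

def cyclic_lr(step, lr=0.0005, margin=0.0004, steps=5000):
    # Same formula as get_cyclic_lr in train2.py, with the defaults inlined.
    return lr + margin * math.sin(2. * math.pi / steps * step)

print(cyclic_lr(0))      # 0.0005 (midpoint)
print(cyclic_lr(1250))   # ~0.0009 (peak, lr + margin)
print(cyclic_lr(3750))   # ~0.0001 (trough, lr - margin)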
50 changes: 37 additions & 13 deletions train2.py
@@ -12,6 +12,9 @@
from data_load import get_batch
import argparse
from hparam import Hparam
import math
import os
from utils import remove_all_files


def train(logdir1, logdir2, queue=True):
@@ -26,17 +29,24 @@ def train(logdir1, logdir2, queue=True):
epoch, gs = Model.get_epoch_and_global_step(logdir2)
global_step = tf.Variable(gs, name='global_step', trainable=False)

optimizer = tf.train.AdamOptimizer(learning_rate=hp.train2.lr)
lr = tf.placeholder(tf.float32, shape=[])
optimizer = tf.train.AdamOptimizer(learning_rate=lr)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'net/net2')

# Gradient clipping to prevent loss explosion
gvs = optimizer.compute_gradients(loss_op, var_list=var_list)
clipped_gvs = [(tf.clip_by_value(grad, hp.train2.clip_value_min, hp.train2.clip_value_max), var) for grad, var in gvs]
train_op = optimizer.apply_gradients(clipped_gvs, global_step=global_step)
gvs = [(tf.clip_by_value(grad, hp.train2.clip_value_min, hp.train2.clip_value_max), var) for grad, var in gvs]
gvs = [(tf.clip_by_norm(grad, hp.train2.clip_norm), var) for grad, var in gvs]

train_op = optimizer.apply_gradients(gvs, global_step=global_step)

# Summary
summ_op = summaries(loss_op)
# for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'net/net2'):
# tf.summary.histogram(v.name, v)
tf.summary.scalar('net2/train/loss', loss_op)
tf.summary.scalar('net2/train/lr', lr)
summ_op = tf.summary.merge_all()

session_conf = tf.ConfigProto(
gpu_options=tf.GPUOptions(
@@ -56,14 +66,21 @@ def train(logdir1, logdir2, queue=True):

for epoch in range(epoch + 1, hp.train2.num_epochs + 1):
for _ in tqdm(range(model.num_batch), total=model.num_batch, ncols=70, leave=False, unit='b'):
gs = sess.run(global_step)

# Cyclic learning rate
feed_dict = {lr: get_cyclic_lr(gs)}
if queue:
sess.run(train_op)
sess.run(train_op, feed_dict=feed_dict)
else:
mfcc, spec, mel = get_batch(model.mode, model.batch_size)
sess.run(train_op, feed_dict={model.x_mfcc: mfcc, model.y_spec: spec, model.y_mel: mel})
feed_dict.update({model.x_mfcc: mfcc, model.y_spec: spec, model.y_mel: mel})
sess.run(train_op, feed_dict=feed_dict)

# Write checkpoint files at every epoch
summ, gs = sess.run([summ_op, global_step])
gs = sess.run(global_step)
feed_dict = {lr: get_cyclic_lr(gs)}
summ = sess.run(summ_op, feed_dict=feed_dict)

if epoch % hp.train2.save_per_epoch == 0:
tf.train.Saver().save(sess, '{}/epoch_{}_step_{}'.format(logdir2, epoch, gs))
@@ -83,17 +100,17 @@ def train(logdir1, logdir2, queue=True):
coord.join(threads)


def summaries(loss):
# for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'net/net2'):
# tf.summary.histogram(v.name, v)
tf.summary.scalar('net2/train/loss', loss)
return tf.summary.merge_all()
def get_cyclic_lr(step):
lr_margin = hp.train2.lr_cyclic_margin * math.sin(2. * math.pi / hp.train2.lr_cyclic_steps * step)
lr = hp.train2.lr + lr_margin
return lr


def get_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('case1', type=str, help='experiment case name of train1')
parser.add_argument('case2', type=str, help='experiment case name of train2')
parser.add_argument('-r', action='store_true', help='start training from the beginning.')
arguments = parser.parse_args()
return arguments

@@ -102,7 +119,14 @@ def get_arguments():
case1, case2 = args.case1, args.case2
logdir1 = '{}/{}/train1'.format(logdir_path, case1)
logdir2 = '{}/{}/train2'.format(logdir_path, case2)
Hparam(case2).set_as_global_hparam()
hp = Hparam(case2).set_as_global_hparam()

if args.r:
ckpt = '{}/checkpoint'.format(os.path.join(logdir2))
if os.path.exists(ckpt):
os.remove(ckpt)
remove_all_files(os.path.join(hp.logdir, 'events.out'))
remove_all_files(os.path.join(hp.logdir, 'epoch_'))

print('case1: {}, case2: {}, logdir1: {}, logdir2: {}'.format(case1, case2, logdir1, logdir2))

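For reference, the gradient handling above now clips in two stages: each gradient tensor is first clipped elementwise to [clip_value_min, clip_value_max] and then rescaled so its L2 norm does not exceed clip_norm. A minimal numpy sketch of the same two steps (the function name and defaults are illustrative, not part of the commit):

import numpy as np

def clip_gradient(grad, value_min=-3.0, value_max=3.0, clip_norm=100.0):
    # Stage 1: elementwise value clipping, like tf.clip_by_value.
    g = np.clip(grad, value_min, value_max)
    # Stage 2: per-tensor norm clipping, like tf.clip_by_norm:
    # scale the tensor down only if its L2 norm exceeds clip_norm.
    norm = np.linalg.norm(g)
    if norm > clip_norm:
        g = g * (clip_norm / norm)
    return g
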
55 changes: 15 additions & 40 deletions utils.py
@@ -1,46 +1,21 @@
# -*- coding: utf-8 -*-
#/usr/bin/python2
#!/usr/bin/env python

from __future__ import print_function
import glob  # needed by remove_all_files below
import os

import librosa
import numpy as np
from scipy import signal

def split_path(path):
'''
'a/b/c.wav' => ('a/b', 'c', 'wav')
:param path: filepath = 'a/b/c.wav'
:return: basename, filename, and extension = ('a/b', 'c', 'wav')
'''
basepath, filename = os.path.split(path)
filename, extension = os.path.splitext(filename)
return basepath, filename, extension

def spectrogram2wav(mag, n_fft, win_length, hop_length, num_iters, phase_angle=None, length=None):
assert(num_iters > 0)
if phase_angle is None:
phase_angle = np.pi * np.random.rand(*mag.shape)
spec = mag * np.exp(1.j * phase_angle)
for i in range(num_iters):
wav = librosa.istft(spec, win_length=win_length, hop_length=hop_length, length=length)
if i != num_iters - 1:
spec = librosa.stft(wav, n_fft=n_fft, win_length=win_length, hop_length=hop_length)
_, phase = librosa.magphase(spec)
phase_angle = np.angle(phase)
spec = mag * np.exp(1.j * phase_angle)
return wav


def preemphasis(x, coeff=0.97):
return signal.lfilter([1, -coeff], [1], x)


def inv_preemphasis(x, coeff=0.97):
return signal.lfilter([1], [1, -coeff], x)


def wav_random_crop(wav, sr, duration):
assert (wav.ndim <= 2)

target_len = sr * duration
wav_len = wav.shape[-1]
start = np.random.choice(range(np.maximum(1, wav_len - target_len)), 1)[0]
end = start + target_len
if wav.ndim == 1:
wav = wav[start:end]
else:
wav = wav[:, start:end]
return wav

def remove_all_files(prefix):
files = glob.glob(prefix + '*')
for f in files:
os.remove(f)
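
The new remove_all_files helper is what the -r flag added to train2.py relies on when restarting training: it deletes every file whose name starts with the given prefix. A small usage sketch, with an illustrative log directory path (not from the commit):

logdir2 = 'logdir/my_case/train2'  # illustrative path
# Wipe TensorBoard event files and saved epoch checkpoints before retraining.
remove_all_files(os.path.join(logdir2, 'events.out'))
remove_all_files(os.path.join(logdir2, 'epoch_'))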
