Skip to content

Commit

Permalink
update convert, add logvar, fix
Browse files Browse the repository at this point in the history
  • Loading branch information
andabi committed Mar 22, 2018
1 parent 94ec076 commit a4a12fd
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 117 deletions.
161 changes: 90 additions & 71 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,115 +6,134 @@

import argparse

from data_load import get_wav_batch, get_batch
from models import Model
from models import Net2
import numpy as np
from audio import spectrogram2wav, inv_preemphasis, db_to_amp
from audio import spec2wav, inv_preemphasis, db2amp, denormalize_db
import datetime
import tensorflow as tf
from hparam import hparam as hp
from utils import denormalize_0_1
from data_load import Net2DataFlow
from tensorpack.predict.base import OfflinePredictor
from tensorpack.predict.config import PredictConfig
from tensorpack.tfutils.sessinit import SaverRestore
from tensorpack.tfutils.sessinit import ChainInit


def convert(logdir, writer, queue=False, step=None):
def convert(model, mfccs, spec, mel, ckpt1=None, ckpt2=None):
session_inits = []
if ckpt2:
session_inits.append(SaverRestore(ckpt2))
if ckpt1:
session_inits.append(SaverRestore(ckpt1, ignore=['global_step']))

# Load graph
model = Model(mode="convert", batch_size=hp.convert.batch_size, queue=queue)
pred_conf = PredictConfig(
model=model,
input_names=get_eval_input_names(),
output_names=get_eval_output_names(),
session_init=ChainInit(session_inits))
predict_spec = OfflinePredictor(pred_conf)

pred_spec, y_spec, ppgs = predict_spec(mfccs, spec, mel)

return pred_spec, y_spec, ppgs


def get_eval_input_names():
return ['x_mfccs', 'y_spec', 'y_mel']

session_conf = tf.ConfigProto(
allow_soft_placement=True,
device_count={'CPU': 1, 'GPU': 0},
gpu_options=tf.GPUOptions(
allow_growth=True,
per_process_gpu_memory_fraction=0.6
),
)
with tf.Session(config=session_conf) as sess:
# Load trained model
sess.run(tf.global_variables_initializer())
model.load(sess, 'convert', logdir=logdir, step=step)

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
def get_eval_output_names():
return ['pred_spec', 'y_spec', 'net1/ppgs']

epoch, gs = Model.get_epoch_and_global_step(logdir, step=step)

if queue:
pred_spec, y_spec, ppgs = sess.run([model(), model.y_spec, model.ppgs])
else:
if hp.convert.one_full_wav:
mfcc, spec, mel = get_wav_batch(model.mode, model.batch_size)
else:
mfcc, spec, mel = get_batch(model.mode, model.batch_size)
def do_convert(args, logdir1, logdir2):

pred_spec, y_spec, ppgs = sess.run([model(), model.y_spec, model.ppgs], feed_dict={model.x_mfcc: mfcc, model.y_spec: spec, model.y_mel: mel})
# Load graph
model = Net2(batch_size=hp.convert.batch_size)

df = Net2DataFlow(hp.convert.data_path, hp.convert.batch_size)

# De-quantization
# bins = np.linspace(0, 1, hp.default.quantize_db)
# y_spec = bins[y_spec]
# samples
mfccs, spec, mel = df().get_data().next()

# Denormalizatoin
pred_spec = denormalize_0_1(pred_spec, hp.default.max_db, hp.default.min_db)
y_spec = denormalize_0_1(y_spec, hp.default.max_db, hp.default.min_db)
ckpt1 = tf.train.latest_checkpoint(logdir1)
ckpt2 = '{}/{}'.format(logdir2, args.ckpt) if args.ckpt else tf.train.latest_checkpoint(logdir2)

# Db to amp
pred_spec = db_to_amp(pred_spec)
y_spec = db_to_amp(y_spec)
pred_spec, y_spec, ppgs = convert(model, mfccs, spec, mel, ckpt1, ckpt2)
print(np.max(pred_spec))
print(np.min(pred_spec))

# Convert log of magnitude to magnitude
# pred_specs, y_specs = np.e ** pred_specs, np.e ** y_spec
# Denormalizatoin
pred_spec = denormalize_db(pred_spec, hp.default.max_db, hp.default.min_db)
y_spec = denormalize_db(y_spec, hp.default.max_db, hp.default.min_db)

# Emphasize the magnitude
pred_spec = np.power(pred_spec, hp.convert.emphasis_magnitude)
y_spec = np.power(y_spec, hp.convert.emphasis_magnitude)
# Db to amp
pred_spec = db2amp(pred_spec)
y_spec = db2amp(y_spec)

# Spectrogram to waveform
audio = np.array(map(lambda spec: spectrogram2wav(spec.T, hp.default.n_fft, hp.default.win_length, hp.default.hop_length, hp.default.n_iter), pred_spec))
y_audio = np.array(map(lambda spec: spectrogram2wav(spec.T, hp.default.n_fft, hp.default.win_length, hp.default.hop_length, hp.default.n_iter), y_spec))
# Emphasize the magnitude
pred_spec = np.power(pred_spec, hp.convert.emphasis_magnitude)
y_spec = np.power(y_spec, hp.convert.emphasis_magnitude)

# Apply inverse pre-emphasis
audio = inv_preemphasis(audio, coeff=hp.default.preemphasis)
y_audio = inv_preemphasis(y_audio, coeff=hp.default.preemphasis)
# Spectrogram to waveform
audio = np.array(map(lambda spec: spec2wav(spec.T, hp.default.n_fft, hp.default.win_length, hp.default.hop_length,
hp.default.n_iter), pred_spec))
y_audio = np.array(map(lambda spec: spec2wav(spec.T, hp.default.n_fft, hp.default.win_length, hp.default.hop_length,
hp.default.n_iter), y_spec))

if hp.convert.one_full_wav:
# Concatenate to a wav
y_audio = np.reshape(y_audio, (1, y_audio.size), order='C')
audio = np.reshape(audio, (1, audio.size), order='C')
# Apply inverse pre-emphasis
audio = inv_preemphasis(audio, coeff=hp.default.preemphasis)
y_audio = inv_preemphasis(y_audio, coeff=hp.default.preemphasis)

# Write the result
tf.summary.audio('A', y_audio, hp.default.sr, max_outputs=hp.convert.batch_size)
tf.summary.audio('B', audio, hp.default.sr, max_outputs=hp.convert.batch_size)
if hp.convert.one_full_wav:
# Concatenate to a wav
y_audio = np.reshape(y_audio, (1, y_audio.size), order='C')
audio = np.reshape(audio, (1, audio.size), order='C')

# Visualize PPGs
heatmap = np.expand_dims(ppgs, 3) # channel=1
tf.summary.image('PPG', heatmap, max_outputs=ppgs.shape[0])
# Write the result
tf.summary.audio('A', y_audio, hp.default.sr, max_outputs=hp.convert.batch_size)
tf.summary.audio('B', audio, hp.default.sr, max_outputs=hp.convert.batch_size)

writer.add_summary(sess.run(tf.summary.merge_all()), global_step=gs)
# Visualize PPGs
heatmap = np.expand_dims(ppgs, 3) # channel=1
tf.summary.image('PPG', heatmap, max_outputs=ppgs.shape[0])

coord.request_stop()
coord.join(threads)
writer = tf.summary.FileWriter(logdir2)
with tf.Session() as sess:
summ = sess.run(tf.summary.merge_all())
writer.add_summary(summ)
writer.close()

# session_conf = tf.ConfigProto(
# allow_soft_placement=True,
# device_count={'CPU': 1, 'GPU': 0},
# gpu_options=tf.GPUOptions(
# allow_growth=True,
# per_process_gpu_memory_fraction=0.6
# ),
# )


def get_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('case', type=str, help='experiment case name')
parser.add_argument('-step', type=int, help='checkpoint step to load', nargs='?', default=None)
parser.add_argument('case1', type=str, help='experiment case name of train1')
parser.add_argument('case2', type=str, help='experiment case name of train2')
parser.add_argument('-ckpt', help='checkpoint to load model.')
arguments = parser.parse_args()
return arguments


if __name__ == '__main__':
args = get_arguments()
hp.set_hparam_yaml(args.case)
logdir = '{}/train2'.format(hp.logdir)
hp.set_hparam_yaml(args.case2)
logdir_train1 = '{}/{}/train1'.format(hp.logdir_path, args.case1)
logdir_train2 = '{}/{}/train2'.format(hp.logdir_path, args.case2)

print('case: {}, logdir: {}'.format(args.case, logdir))
print('case1: {}, case2: {}, logdir1: {}, logdir2: {}'.format(args.case1, args.case2, logdir_train1, logdir_train2))

s = datetime.datetime.now()

writer = tf.summary.FileWriter(logdir)
convert(logdir=logdir, writer=writer, step=args.step)
writer.close()
do_convert(args, logdir1=logdir_train1, logdir2=logdir_train2)

e = datetime.datetime.now()
diff = e - s
Expand Down
18 changes: 5 additions & 13 deletions data_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from utils import normalize_0_1


class Net1DataFlow(RNGDataFlow):
class DataFlow(RNGDataFlow):

def __init__(self, data_path, batch_size):
self.batch_size = batch_size
Expand All @@ -27,24 +27,16 @@ def __call__(self, n_prefetch=1000, n_thread=1):
df = PrefetchData(df, n_prefetch, n_thread)
return df


class Net1DataFlow(DataFlow):

def get_data(self):
while True:
wav_file = random.choice(self.wav_files)
yield get_mfccs_and_phones(wav_file=wav_file)


class Net2DataFlow(RNGDataFlow):

def __init__(self, data_path, batch_size):
self.batch_size = batch_size
self.wav_files = glob.glob(data_path)

def __call__(self, n_prefetch=1000, n_thread=1):
df = self
if self.batch_size > 1:
df = BatchData(df, self.batch_size)
df = PrefetchData(df, n_prefetch, n_thread)
return df
class Net2DataFlow(DataFlow):

def get_data(self):
while True:
Expand Down
10 changes: 5 additions & 5 deletions hparams/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,15 @@ train2:

# train
batch_size: 32
lr: 0.0003
lr: 0.0001
lr_cyclic_margin: 0.
lr_cyclic_steps: 5000
clip_value_max: 3.
clip_value_min: -3.
clip_norm: 10
mol_step: 0.001
num_epochs: 10000
steps_per_epoch: 100
num_epochs: 100000
steps_per_epoch: 10
save_per_epoch: 50
num_gpu: 4
---
Expand All @@ -94,5 +94,5 @@ convert:

# convert
one_full_wav: False
batch_size: 6
emphasis_magnitude: 1.2
batch_size: 3
emphasis_magnitude: 1.5
17 changes: 17 additions & 0 deletions hparams/hparams.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,23 @@ default:
# lr: 0.0001
steps_per_epoch: 1
---
mix_1:
train2:
num_mixtures: 1
---
mix_2:
train2:
num_mixtures: 2
---
mix_10:
train2:
num_mixtures: 10
---
mix_5_256:
train2:
hidden_units: 256
num_mixtures: 10
---
mol_iu:
train2:
data_path: '/data/private/vc/datasets/IU/*_split/*.wav'
Expand Down
Loading

0 comments on commit a4a12fd

Please sign in to comment.