Skip to content

Commit

Permalink
major update: apply tensorpack
Browse files Browse the repository at this point in the history
  • Loading branch information
andabi committed Mar 3, 2018
1 parent 85d500a commit 3184a62
Show file tree
Hide file tree
Showing 10 changed files with 778 additions and 561 deletions.
371 changes: 338 additions & 33 deletions audio.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
def convert(logdir, writer, queue=False, step=None):

# Load graph
model = Model(mode="convert", batch_size=hp.convert.batch_size, hp=hp, queue=queue)
model = Model(mode="convert", batch_size=hp.convert.batch_size, queue=queue)

session_conf = tf.ConfigProto(
allow_soft_placement=True,
Expand Down
323 changes: 45 additions & 278 deletions data_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,43 @@
# /usr/bin/python2

import glob
import threading
from functools import wraps
from random import sample
import random

import tensorflow as tf
from tensorflow.python.platform import tf_logging as logging

from audio import preemphasis, amp_to_db
import numpy as np
import librosa
import numpy as np
from tensorpack.dataflow.base import RNGDataFlow
from tensorpack.dataflow.common import BatchData
from tensorpack.dataflow.prefetch import PrefetchData
from audio import read_wav, preemphasis, amp2db
from hparam import hparam as hp
from utils import normalize_0_1


def load_data(mode):
    '''Expand the wav-file glob pattern configured for `mode`.

    :param mode: name of the attribute of `hp` whose `data_path` holds the glob pattern.
    :return: list of paths produced by `glob.glob` for that pattern.
    '''
    pattern = getattr(hp, mode).data_path
    return glob.glob(pattern)


class Net1DataFlow(RNGDataFlow):
    '''Endless dataflow of (mfccs, phones) datapoints drawn at random from wav files.

    Calling the instance wraps it in `BatchData` (when `batch_size > 1`) and
    `PrefetchData` and returns the wrapped dataflow.
    '''

    def __init__(self, data_path, batch_size):
        '''
        :param data_path: glob pattern that selects the wav files.
        :param batch_size: number of datapoints per batch; batching is skipped when it is <= 1.
        '''
        self.batch_size = batch_size
        self.wav_files = glob.glob(data_path)

    def __call__(self, n_prefetch=1000, n_thread=1):
        '''Build the batched, prefetching dataflow.

        :param n_prefetch: passed to `PrefetchData` as its second argument.
        :param n_thread: passed to `PrefetchData` as its third argument.
        :return: the `PrefetchData` wrapping this dataflow.
        '''
        df = self
        if self.batch_size > 1:
            df = BatchData(df, self.batch_size)
        return PrefetchData(df, n_prefetch, n_thread)

    def get_data(self):
        '''Yield `get_mfccs_and_phones(wav_file=...)` for randomly chosen files, forever.

        :raises IndexError: if the glob pattern matched no files.
        '''
        if not self.wav_files:
            # Same exception type random.choice() would raise, but with a
            # message that points at the actual cause.
            raise IndexError("no wav files matched the data_path pattern")

        # Forked prefetch workers inherit the parent's global `random` state and
        # would all pick the same files. `self.rng` is only present after
        # reset_state() has run (expected to be per worker -- confirm against
        # the installed tensorpack); fall back to `random` otherwise.
        rng = getattr(self, 'rng', None)
        while True:
            if rng is None:
                wav_file = random.choice(self.wav_files)
            else:
                wav_file = self.wav_files[rng.randint(len(self.wav_files))]
            yield get_mfccs_and_phones(wav_file=wav_file)


def wav_random_crop(wav, sr, duration):
assert (wav.ndim <= 2)

Expand All @@ -30,17 +53,17 @@ def wav_random_crop(wav, sr, duration):
return wav


def get_mfccs_and_phones(wav_file, sr, length, trim=False, random_crop=True):
def get_mfccs_and_phones(wav_file, trim=False, random_crop=True):

'''This is applied in `train1` or `test1` phase.
'''

# Load
wav, sr = librosa.load(wav_file, sr=sr)
wav = read_wav(wav_file, sr=hp.default.sr)

mfccs, _, _ = _get_mfcc_log_spec_and_log_mel_spec(wav, hp.default.preemphasis, hp.default.n_fft,
hp.default.win_length,
hp.default.hop_length)
mfccs, _, _ = _get_mfcc_and_spec(wav, hp.default.preemphasis, hp.default.n_fft,
hp.default.win_length,
hp.default.hop_length)

# timesteps
num_timesteps = mfccs.shape[0]
Expand All @@ -64,16 +87,17 @@ def get_mfccs_and_phones(wav_file, sr, length, trim=False, random_crop=True):
assert (len(mfccs) == len(phns))

# Random crop
n_timesteps = (hp.default.duration * hp.default.sr) // hp.default.hop_length + 1
if random_crop:
start = np.random.choice(range(np.maximum(1, len(mfccs) - length)), 1)[0]
end = start + length
start = np.random.choice(range(np.maximum(1, len(mfccs) - n_timesteps)), 1)[0]
end = start + n_timesteps
mfccs = mfccs[start:end]
phns = phns[start:end]
assert (len(mfccs) == len(phns))

# Padding or crop
mfccs = librosa.util.fix_length(mfccs, length, axis=0)
phns = librosa.util.fix_length(phns, length, axis=0)
mfccs = librosa.util.fix_length(mfccs, n_timesteps, axis=0)
phns = librosa.util.fix_length(phns, n_timesteps, axis=0)

return mfccs, phns

Expand All @@ -98,11 +122,10 @@ def get_mfccs_and_spectrogram(wav_file, win_length, hop_length, trim=True, durat
length = hp.default.sr * duration
wav = librosa.util.fix_length(wav, length)

return _get_mfcc_log_spec_and_log_mel_spec(wav, hp.default.preemphasis, hp.default.n_fft, win_length, hop_length)

return _get_mfcc_and_spec(wav, hp.default.preemphasis, hp.default.n_fft, win_length, hop_length)

def _get_mfcc_log_spec_and_log_mel_spec(wav, preemphasis_coeff, n_fft, win_length, hop_length):

def _get_mfcc_and_spec(wav, preemphasis_coeff, n_fft, win_length, hop_length):

# Pre-emphasis
y_preem = preemphasis(wav, coeff=preemphasis_coeff)
Expand All @@ -116,269 +139,17 @@ def _get_mfcc_log_spec_and_log_mel_spec(wav, preemphasis_coeff, n_fft, win_lengt
mel = np.dot(mel_basis, mag) # (n_mels, t) # mel spectrogram

# Get mfccs, amp to db
mag_db = amp_to_db(mag)
mel_db = amp_to_db(mel)
mag_db = amp2db(mag)
mel_db = amp2db(mel)
mfccs = np.dot(librosa.filters.dct(hp.default.n_mfcc, mel_db.shape[0]), mel_db)

# Normalization (0 ~ 1)
mag_db = normalize_0_1(mag_db, hp.default.max_db, hp.default.min_db)
mel_db = normalize_0_1(mel_db, hp.default.max_db, hp.default.min_db)

# Quantization
# bins = np.linspace(0, 1, hp.default.quantize_db)
# mag_db = np.digitize(mag_db, bins)
# mel_db = np.digitize(mel_db, bins)

return mfccs.T, mag_db.T, mel_db.T # (t, n_mfccs), (t, 1+n_fft/2), (t, n_mels)


# Adapted from the `sugartensor` code.
# https://github.com/buriburisuri/sugartensor/blob/master/sugartensor/sg_queue.py
def producer_func(func):
    r"""Turn `func` into a factory for a queue fed by `func`'s outputs.

    Args:
      func: A function that maps the evaluated `inputs` to a sequence of
        columns, one per entry in `dtypes`.

    The decorated function takes `(inputs, dtypes, capacity, num_threads)`,
    none of which has a default, and returns the dequeue op of the new queue.
    """

    @wraps(func)
    def wrapper(inputs, dtypes, capacity, num_threads):
        # One placeholder per output column, then the FIFO queue they feed.
        feeds = [tf.placeholder(dtype=dt) for dt in dtypes]
        fifo = tf.FIFOQueue(capacity, dtypes=dtypes)
        push_op = fifo.enqueue(feeds)

        def enqueue_func(sess, op):
            # Evaluate the source tensors, transform them with `func`, and
            # feed the resulting columns into the enqueue op.
            columns = func(sess.run(inputs))
            sess.run(op, feed_dict=dict(zip(feeds, columns)))

        # One enqueue op per requested thread, registered in the global
        # queue-runner collection.
        runner = _FuncQueueRunner(enqueue_func, fifo, [push_op] * num_threads)
        tf.train.add_queue_runner(runner)

        return fifo.dequeue()

    return wrapper


@producer_func
def get_mfccs_and_phones_queue(wav_file):
    '''Extract mfccs (inputs) and phones (targets) from one `wav_file`.

    Wrapped by `producer_func`, so the pair returned here is what gets enqueued.
    This is applied in `train1` or `test1` phase.
    '''
    n_frames = int(hp.default.duration / hp.default.frame_shift + 1)
    features, labels = get_mfccs_and_phones(wav_file, hp.default.sr, length=n_frames)
    return features, labels


@producer_func
def get_mfccs_and_spectrogram_queue(wav_file):
    '''Extract mfccs, spectrogram and mel spectrogram from one `wav_file`.

    Wrapped by `producer_func`, so the triple returned here is what gets enqueued.
    This is applied in `train2` or `test2` phase.
    '''
    features, spectrogram, mel_spectrogram = get_mfccs_and_spectrogram(
        wav_file,
        duration=hp.default.duration,
        win_length=hp.default.win_length,
        hop_length=hp.default.hop_length)
    return features, spectrogram, mel_spectrogram


def get_batch_queue(mode, batch_size):
    '''Build queue-backed mini-batch tensors for `mode`.

    mode: A string. Either `train1` | `test1` | `train2` | `test2` | `convert`.
    batch_size: number of examples per dequeued batch.

    Returns `(mfcc, ppg, num_batch)` for `train1`/`test1`, otherwise
    `(mfcc, spec, mel, num_batch)`; `num_batch` is `len(wav_files) // batch_size`.
    Raises `Exception` for any other mode.
    '''
    if mode not in ('train1', 'test1', 'train2', 'test2', 'convert'):
        raise Exception("invalid mode={}".format(mode))

    with tf.device('/cpu:0'):
        paths = load_data(mode=mode)
        num_batch = len(paths) // batch_size

        # Shuffled producer that hands out one file path at a time.
        path_tensor = tf.convert_to_tensor(paths)
        wav_file, = tf.train.slice_input_producer([path_tensor, ], shuffle=True, capacity=128)

        if mode not in ('train1', 'test1'):
            # Phase 2 / convert: mfcc -> (spectrogram, mel spectrogram).
            mfcc, spec, mel = get_mfccs_and_spectrogram_queue(
                inputs=wav_file,
                dtypes=[tf.float32, tf.float32, tf.float32],
                capacity=2048,
                num_threads=64)
            mfcc, spec, mel = tf.train.batch(
                [mfcc, spec, mel],
                shapes=[(None, hp.default.n_mfcc),
                        (None, 1 + hp.default.n_fft // 2),
                        (None, hp.default.n_mels)],
                num_threads=64,
                batch_size=batch_size,
                capacity=batch_size * 64,
                dynamic_pad=True)
            return mfcc, spec, mel, num_batch

        # Phase 1: mfcc -> phones.
        mfcc, ppg = get_mfccs_and_phones_queue(
            inputs=wav_file,
            dtypes=[tf.float32, tf.int32],
            capacity=2048,
            num_threads=32)
        mfcc, ppg = tf.train.batch(
            [mfcc, ppg],
            shapes=[(None, hp.default.n_mfcc), (None,)],
            num_threads=32,
            batch_size=batch_size,
            capacity=batch_size * 32,
            dynamic_pad=True)
        return mfcc, ppg, num_batch


def get_batch(mode, batch_size):
    '''Load one zero-padded batch of features from randomly sampled wav files.

    mode: A string. Either `train1` | `test1` | `train2` | `test2` | `convert`.
    batch_size: number of wav files sampled (without replacement) for the batch.

    Returns `(mfcc, ppg)` for `train1`/`test1`, otherwise `(mfcc, spec, mel)`.
    Raises `Exception` for any other mode.
    '''
    if mode not in ('train1', 'test1', 'train2', 'test2', 'convert'):
        raise Exception("invalid mode={}".format(mode))

    with tf.device('/cpu:0'):
        chosen = sample(load_data(mode=mode), batch_size)

        if mode in ('train1', 'test1'):
            per_file = [
                get_mfccs_and_phones(w, hp.default.sr,
                                     length=int(hp.default.duration / hp.default.frame_shift + 1))
                for w in chosen
            ]
            # Transpose per-file tuples into per-feature columns, then pad each.
            mfcc, ppg = [_get_zero_padded(column) for column in zip(*per_file)]
            return mfcc, ppg

        per_file = [
            get_mfccs_and_spectrogram(wav_file, duration=hp.default.duration,
                                      win_length=hp.default.win_length,
                                      hop_length=hp.default.hop_length)
            for wav_file in chosen
        ]
        mfcc, spec, mel = [_get_zero_padded(column) for column in zip(*per_file)]
        return mfcc, spec, mel


# TODO generalize for all mode
def get_wav_batch(mode, batch_size):
    '''Cut one randomly chosen wav file into `batch_size` segments and featurize them.

    The waveform is fixed to `hp.default.sr * hp.default.duration * batch_size`
    samples with `librosa.util.fix_length`, reshaped to
    `(batch_size, hp.default.sr * hp.default.duration)`, and each row is
    converted to (mfcc, spec, mel).

    Returns `(mfcc, spec, mel)`, each zero-padded into one batch array.
    '''
    with tf.device('/cpu:0'):
        candidates = load_data(mode=mode)
        chosen = sample(candidates, 1)[0]
        signal, _ = librosa.load(chosen, sr=hp.default.sr)

        # Force the waveform to exactly batch_size segments' worth of samples.
        total_duration = hp.default.duration * batch_size
        signal = librosa.util.fix_length(signal, hp.default.sr * total_duration)

        segment_len = hp.default.sr * hp.default.duration
        segments = np.reshape(signal, (batch_size, segment_len))

        per_segment = [
            _get_mfcc_log_spec_and_log_mel_spec(seg, hp.default.preemphasis, hp.default.n_fft,
                                                hp.default.win_length, hp.default.hop_length)
            for seg in segments
        ]
        mfcc, spec, mel = [_get_zero_padded(column) for column in zip(*per_segment)]
        return mfcc, spec, mel


class _FuncQueueRunner(tf.train.QueueRunner):
    '''Queue runner whose worker loop calls an ad-hoc Python function.

    Instead of running the enqueue op directly, each worker repeatedly calls
    `func(sess, enqueue_op)`; the function is responsible for running the op.
    '''

    def __init__(self, func, queue=None, enqueue_ops=None, close_op=None,
                 cancel_op=None, queue_closed_exception_types=None,
                 queue_runner_def=None):
        # save ad-hoc function
        self.func = func
        # call super()
        super(_FuncQueueRunner, self).__init__(queue, enqueue_ops, close_op, cancel_op,
                                               queue_closed_exception_types, queue_runner_def)

    # pylint: disable=broad-except
    def _run(self, sess, enqueue_op, coord=None):
        '''Worker loop: call `self.func(sess, enqueue_op)` until stopped.

        The loop ends when `coord.should_stop()` is true, when a
        queue-closed exception is raised, or when any other exception
        escapes (reported to `coord` if given, otherwise logged, recorded
        and re-raised).
        '''

        if coord:
            coord.register_thread(threading.current_thread())
        # Tracks whether this worker already decremented the per-session run
        # counter, so the `finally` clause does not decrement it twice.
        decremented = False
        try:
            while True:
                if coord and coord.should_stop():
                    break
                try:
                    self.func(sess, enqueue_op)  # call enqueue function
                except self._queue_closed_exception_types:  # pylint: disable=catching-non-exception
                    # This exception indicates that a queue was closed.
                    with self._lock:
                        self._runs_per_session[sess] -= 1
                        decremented = True
                        # The last worker for this session runs the close op.
                        if self._runs_per_session[sess] == 0:
                            try:
                                sess.run(self._close_op)
                            except Exception as e:
                                # Intentionally ignore errors from close_op.
                                logging.vlog(1, "Ignored exception: %s", str(e))
                        return
        except Exception as e:
            # This catches all other exceptions.
            if coord:
                coord.request_stop(e)
            else:
                # No coordinator: log, record the exception, and re-raise.
                logging.error("Exception in QueueRunner: %s", str(e))
                with self._lock:
                    self._exceptions_raised.append(e)
                raise
        finally:
            # Make sure we account for all terminations: normal or errors.
            if not decremented:
                with self._lock:
                    self._runs_per_session[sess] -= 1


def _get_zero_padded(list_of_arrays):
'''
:param list_of_arrays
:return: zero padded array
'''
batch = []
max_len = 0
for d in list_of_arrays:
max_len = max(len(d), max_len)
for d in list_of_arrays:
num_pad = max_len - len(d)
pad_width = [(0, num_pad)]
for _ in range(d.ndim - 1):
pad_width.append((0, 0))
pad_width = tuple(pad_width)
padded = np.pad(d, pad_width=pad_width, mode="constant", constant_values=0)
batch.append(padded)
return np.array(batch)


phns = ['h#', 'aa', 'ae', 'ah', 'ao', 'aw', 'ax', 'ax-h', 'axr', 'ay', 'b', 'bcl',
'ch', 'd', 'dcl', 'dh', 'dx', 'eh', 'el', 'em', 'en', 'eng', 'epi',
'er', 'ey', 'f', 'g', 'gcl', 'hh', 'hv', 'ih', 'ix', 'iy', 'jh',
Expand All @@ -393,7 +164,3 @@ def load_vocab():
return phn2idx, idx2phn


def load_data(mode):
    '''Expand the wav-file glob pattern configured for `mode`.

    :param mode: name of the attribute of `hp` whose `data_path` holds the glob pattern.
    :return: list of paths produced by `glob.glob` for that pattern.
    '''
    pattern = getattr(hp, mode).data_path
    return glob.glob(pattern)
Loading

0 comments on commit 3184a62

Please sign in to comment.