Merging AudioReader, TextReader and ImageReader #166

Open · wants to merge 33 commits into master
52 changes: 45 additions & 7 deletions README.md
@@ -52,7 +52,9 @@ In this repository, the network implementation can be found in <a href="./wavene
TensorFlow needs to be installed before running the training script.
TensorFlow 0.10 and the current `master` version are supported.

In addition, [librosa](https://github.com/librosa/librosa) must be installed for reading and writing audio.
In addition:
- [librosa](https://github.com/librosa/librosa) must be installed for reading and writing audio.
- [PIL (Python Imaging Library)](http://www.pythonware.com/products/pil/) must be installed for reading and writing images.

To install the required python packages (except TensorFlow), run
```bash
@@ -79,7 +81,42 @@ python train.py --help
You can find the configuration of the model parameters in [`wavenet_params.json`](./wavenet_params.json).
These need to stay the same between training and generation.

## Generating audio
You can train on and generate not only wav files but also text and images:

The only thing that has to be done is to change two parameters in [`wavenet_params.json`](./wavenet_params.json):

```
{
....
"raw_type": "Audio",
"file_ext": "*.wav"
}
```

In this way, you can train the model on text by simply copying a folder of text files and setting these two parameters in [`wavenet_params.json`](./wavenet_params.json) to:

```
{
....
"raw_type": "Text",
"file_ext": "*.txt"
}
```
- A version of this WaveNet text generator has been used for poem generation here: [Wavenet for Poem Generation](http://bdp.glia.ca/wavenet-for-poem-generation-preliminary-results/)
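
To make the text case concrete, here is a minimal sketch of one plausible encoding: treating each byte of a text file as an integer sample in `[0, quantization_channels)`. This is an illustrative assumption only; the actual `TextReader` added by this pull request is not shown in the diff below.

```python
import numpy as np

def load_text_as_samples(filename, quantization_channels=256):
    """Illustrative assumption: map each byte of a text file to an integer
    sample in [0, quantization_channels), the same range the network
    predicts over for audio."""
    with open(filename, 'rb') as f:
        data = np.frombuffer(f.read(), dtype=np.uint8)
    # Clamp into the quantization range and shape as (samples, 1).
    return (data % quantization_channels).astype(np.int32).reshape(-1, 1)
```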

For image training:

```
{
....
"raw_type": "Image",
"file_ext": "*.jpg"
}
```

The `file_ext` parameter can be changed to any file pattern, such as `"*.gif"` or `"*.mp3"`.
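
As a sketch of how such a pattern is applied, the readers recursively collect matching files, roughly like the (truncated) `find_files` in `wavenet/audio_reader.py` further down in this diff:

```python
import fnmatch
import os

def find_files(directory, pattern='*.wav'):
    '''Recursively finds all files under `directory` matching the glob pattern.
    Sketch consistent with the truncated find_files shown later in this diff.'''
    files = []
    for root, dirnames, filenames in os.walk(directory):
        for filename in fnmatch.filter(filenames, pattern):
            files.append(os.path.join(root, filename))
    return files

# For example, find_files('./data', '*.txt') collects every text file for training.
```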

## Generating outputs

[Example output](https://soundcloud.com/user-731806733/tensorflow-wavenet-500-msec-88k-train-steps)
generated by @jyegerlehner based on speaker 280 from the VCTK corpus.
@@ -94,15 +131,14 @@ where `model.ckpt-1000` needs to be a previously saved model.
You can find these in the `logdir`.
The `--samples` parameter specifies how many audio samples you would like to generate (16000 corresponds to 1 second by default).

The generated waveform can be played back using TensorBoard, or stored as a
`.wav` file by using the `--wav_out_path` parameter:
The generated waveform can be played back using TensorBoard, or stored as a file by using the `--file_out_path` parameter:
```
python generate.py --wav_out_path=generated.wav --samples 16000 model.ckpt-1000
python generate.py --file_out_path=generated.wav --samples 16000 model.ckpt-1000
```

Passing `--save_every` in addition to `--wav_out_path` will save the in-progress wav file every n samples.
Passing `--save_every` in addition to `--file_out_path` will save the in-progress file every n samples.
```
python generate.py --wav_out_path=generated.wav --save_every 2000 --samples 16000 model.ckpt-1000
python generate.py --file_out_path=generated.wav --save_every 2000 --samples 16000 model.ckpt-1000
```

Fast generation is enabled by default.
@@ -136,3 +172,5 @@ Currently, there is no conditioning on extra information like the speaker ID.

- [tex-wavenet](https://github.com/Zeta36/tensorflow-tex-wavenet), a WaveNet for text generation.
- [image-wavenet](https://github.com/Zeta36/tensorflow-image-wavenet), a WaveNet for image generation.
- [Wavenet-for-Poem-Generation](https://github.com/jhave/Wavenet-for-Poem-Generation), a WaveNet for poem generation.
63 changes: 22 additions & 41 deletions generate.py
@@ -10,7 +10,7 @@
import numpy as np
import tensorflow as tf

from wavenet import WaveNetModel, mu_law_decode, mu_law_encode, audio_reader
from wavenet import WaveNetModel, mu_law_decode, write_output, create_seed_audio

SAMPLES = 16000
TEMPERATURE = 1.0
@@ -66,15 +66,15 @@ def _ensure_positive_float(f):
default=WAVENET_PARAMS,
help='JSON file with the network parameters')
parser.add_argument(
'--wav_out_path',
'--file_out_path',
type=str,
default=None,
help='Path to output wav file')
help='Path to output generated file')
parser.add_argument(
'--save_every',
type=int,
default=SAVE_EVERY,
help='How many samples before saving in-progress wav')
help='How many samples before saving in-progress')
parser.add_argument(
'--fast_generation',
type=_str_to_bool,
@@ -87,29 +87,6 @@ def _ensure_positive_float(f):
help='The wav file to start generation from')
return parser.parse_args()


def write_wav(waveform, sample_rate, filename):
y = np.array(waveform)
librosa.output.write_wav(filename, y, sample_rate)
print('Updated wav file at {}'.format(filename))


def create_seed(filename,
sample_rate,
quantization_channels,
window_size=WINDOW,
silence_threshold=SILENCE_THRESHOLD):
audio, _ = librosa.load(filename, sr=sample_rate, mono=True)
audio = audio_reader.trim_silence(audio, silence_threshold)

quantized = mu_law_encode(audio, quantization_channels)
cut_index = tf.cond(tf.size(quantized) < tf.constant(window_size),
lambda: tf.size(quantized),
lambda: tf.constant(window_size))

return quantized[:cut_index]


def main():
args = get_arguments()
started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
@@ -150,18 +127,21 @@ def main():
print('Restoring model from {}'.format(args.checkpoint))
saver.restore(sess, args.checkpoint)

decode = mu_law_decode(samples, wavenet_params['quantization_channels'])
if wavenet_params['raw_type'] == "Audio":
decode = mu_law_decode(samples, wavenet_params['quantization_channels'])
else:
decode = samples

quantization_channels = wavenet_params['quantization_channels']
if args.wav_seed:
seed = create_seed(args.wav_seed,
seed = create_seed_audio(args.wav_seed,
wavenet_params['sample_rate'],
quantization_channels)
waveform = sess.run(seed).tolist()
else:
waveform = np.random.randint(quantization_channels, size=(1,)).tolist()

if args.fast_generation and args.wav_seed:
if args.fast_generation and args.wav_seed and wavenet_params['raw_type'] == "Audio":
# When using the incremental generation, we need to
# feed in all priming samples one by one before starting the
# actual generation.
@@ -218,27 +198,28 @@ def main():
last_sample_timestamp = current_sample_timestamp

# If we have partial writing, save the result so far.
if (args.wav_out_path and args.save_every and
if (args.file_out_path and args.save_every and
(step + 1) % args.save_every == 0):
out = sess.run(decode, feed_dict={samples: waveform})
write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)
write_output(out, args.file_out_path, wavenet_params['sample_rate'], raw_type=wavenet_params['raw_type'])

# Introduce a newline to clear the carriage return from the progress.
print()

# Save the result as an audio summary.
datestring = str(datetime.now()).replace(' ', 'T')
writer = tf.train.SummaryWriter(logdir)
tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
summaries = tf.merge_all_summaries()
summary_out = sess.run(summaries,
if wavenet_params['raw_type'] == "Audio":
# Save the result as an audio summary.
datestring = str(datetime.now()).replace(' ', 'T')
writer = tf.train.SummaryWriter(logdir)
tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
summaries = tf.merge_all_summaries()
summary_out = sess.run(summaries,
feed_dict={samples: np.reshape(waveform, [-1, 1])})
writer.add_summary(summary_out)
writer.add_summary(summary_out)

# Save the result as a wav file.
if args.wav_out_path:
if args.file_out_path:
out = sess.run(decode, feed_dict={samples: waveform})
write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)
write_output(out, args.file_out_path, wavenet_params['sample_rate'], raw_type=wavenet_params['raw_type'])

print('Finished generating. The result can be viewed in TensorBoard.')

1 change: 1 addition & 0 deletions requirements.txt
@@ -1 +1,2 @@
librosa>=0.4.3
Pillow>=2.2.1
5 changes: 3 additions & 2 deletions test/test_model.py
@@ -9,7 +9,7 @@
# import librosa

from wavenet import (WaveNetModel, time_to_batch, batch_to_time, causal_conv,
optimizer_factory, mu_law_decode)
optimizer_factory, mu_law_decode, mu_law_encode)

SAMPLE_RATE_HZ = 2000.0 # Hz
TRAIN_ITERATIONS = 400
@@ -144,7 +144,8 @@ def testEndToEndTraining(self):
# plt.show()

audio_tensor = tf.convert_to_tensor(audio, dtype=tf.float32)
loss = self.net.loss(audio_tensor)
encode_output = mu_law_encode(audio_tensor, QUANTIZATION_CHANNELS)
loss = self.net.loss(encode_output)
optimizer = optimizer_factory[self.optimizer_type](
learning_rate=self.learning_rate, momentum=self.momentum)
trainable = tf.trainable_variables()
34 changes: 17 additions & 17 deletions train.py
@@ -17,10 +17,10 @@
import tensorflow as tf
from tensorflow.python.client import timeline

from wavenet import WaveNetModel, AudioReader, optimizer_factory
from wavenet import WaveNetModel, FileReader, optimizer_factory

BATCH_SIZE = 1
DATA_DIRECTORY = './VCTK-Corpus'
DATA_DIRECTORY = './data'
LOGDIR_ROOT = './logdir'
CHECKPOINT_EVERY = 50
NUM_STEPS = int(1e5)
@@ -45,9 +45,9 @@ def _str_to_bool(s):

parser = argparse.ArgumentParser(description='WaveNet example network')
parser.add_argument('--batch_size', type=int, default=BATCH_SIZE,
help='How many wav files to process at once.')
help='How many raw files to process at once.')
parser.add_argument('--data_dir', type=str, default=DATA_DIRECTORY,
help='The directory containing the VCTK corpus.')
help='The directory containing the training data.')
parser.add_argument('--store_metadata', type=bool, default=False,
help='Whether to store advanced debugging information '
'(execution time, memory consumption) for use with '
@@ -202,19 +202,19 @@ def main():
# Create coordinator.
coord = tf.train.Coordinator()

# Load raw waveform from VCTK corpus.
# Load raw waveform files.
with tf.name_scope('create_inputs'):
# Allow silence trimming to be skipped by specifying a threshold near
# zero.
silence_threshold = args.silence_threshold if args.silence_threshold > \
EPSILON else None
reader = AudioReader(
args.data_dir,
coord,
sample_rate=wavenet_params['sample_rate'],
sample_size=args.sample_size,
silence_threshold=args.silence_threshold)
audio_batch = reader.dequeue(args.batch_size)
reader = FileReader(
args.data_dir,
coord,
sample_rate=wavenet_params['sample_rate'],
sample_size=args.sample_size,
silence_threshold=args.silence_threshold,
quantization_channels=wavenet_params['quantization_channels'],
pattern=wavenet_params['file_ext'],
EPSILON=EPSILON,
raw_type=wavenet_params['raw_type'])
input_batch = reader.dequeue(args.batch_size)

# Create network.
net = WaveNetModel(
@@ -231,7 +231,7 @@
histograms=args.histograms)
if args.l2_regularization_strength == 0:
args.l2_regularization_strength = None
loss = net.loss(audio_batch, args.l2_regularization_strength)
loss = net.loss(input_batch, args.l2_regularization_strength)
optimizer = optimizer_factory[args.optimizer](
learning_rate=args.learning_rate,
momentum=args.momentum)
7 changes: 5 additions & 2 deletions wavenet/__init__.py
@@ -1,4 +1,7 @@
from .model import WaveNetModel
from .audio_reader import AudioReader
from .ops import (mu_law_encode, mu_law_decode, time_to_batch,
batch_to_time, causal_conv, optimizer_factory)
from .text_reader import TextReader
from .image_reader import ImageReader
from .ops import (FileReader, mu_law_encode, mu_law_decode, time_to_batch,
batch_to_time, causal_conv, optimizer_factory, write_output,
create_seed_audio)
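
The `FileReader` imported here (and constructed in `train.py` above) is not shown in this diff. A hedged sketch of how such a factory could dispatch to the individual readers based on `raw_type` follows; the constructor signatures for the text and image readers are assumptions:

```python
def FileReader(data_dir, coord, sample_rate, sample_size, silence_threshold,
               quantization_channels, pattern, EPSILON, raw_type="Audio"):
    """Illustrative factory (assumption): choose a reader based on raw_type."""
    if raw_type == "Audio":
        from .audio_reader import AudioReader
        # Allow silence trimming to be skipped by a threshold near zero.
        threshold = silence_threshold if silence_threshold > EPSILON else None
        return AudioReader(data_dir, coord,
                           sample_rate=sample_rate,
                           sample_size=sample_size,
                           silence_threshold=threshold,
                           quantization_channels=quantization_channels,
                           pattern=pattern)
    elif raw_type == "Text":
        from .text_reader import TextReader  # signature assumed
        return TextReader(data_dir, coord, sample_size=sample_size, pattern=pattern)
    elif raw_type == "Image":
        from .image_reader import ImageReader  # signature assumed
        return ImageReader(data_dir, coord, sample_size=sample_size, pattern=pattern)
    raise ValueError("Unknown raw_type: '{}'".format(raw_type))
```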
29 changes: 19 additions & 10 deletions wavenet/audio_reader.py
@@ -2,13 +2,13 @@
import os
import re
import threading

import librosa
import numpy as np
import tensorflow as tf
from .ops import *


def find_files(directory, pattern='*.wav'):
def find_files(directory, pattern):
'''Recursively finds all files matching the pattern.'''
files = []
for root, dirnames, filenames in os.walk(directory):
@@ -17,9 +17,9 @@ def find_files(directory, pattern='*.wav'):
return files


def load_generic_audio(directory, sample_rate):
def load_generic_audio(directory, sample_rate, pattern):
'''Generator that yields audio waveforms from the directory.'''
files = find_files(directory)
files = find_files(directory, pattern)
for filename in files:
audio, _ = librosa.load(filename, sr=sample_rate, mono=True)
audio = audio.reshape(-1, 1)
@@ -59,8 +59,12 @@ def __init__(self,
sample_rate,
sample_size=None,
silence_threshold=None,
queue_size=256):
quantization_channels=256,
queue_size=256,
pattern='*.wav'):
self.audio_dir = audio_dir
self.pattern = pattern
self.quantization_channels = quantization_channels
self.sample_rate = sample_rate
self.coord = coord
self.sample_size = sample_size
@@ -73,21 +77,26 @@ def __init__(self,
self.enqueue = self.queue.enqueue([self.sample_placeholder])

# TODO Find a better way to check this.
# Checking inside the AudioReader's thread makes it hard to terminate
# the execution of the script, so we do it in the constructor for now.
if not find_files(audio_dir):
# Checking inside the AudioReader's thread makes it
# hard to terminate the execution of the script, so
# we do it in the constructor for now.
if not find_files(audio_dir, self.pattern):
raise ValueError("No audio files found in '{}'.".format(audio_dir))

def dequeue(self, num_elements):
output = self.queue.dequeue_many(num_elements)
return output
# We mu-law encode and quantize the input audioform.
encode_output = mu_law_encode(output, self.quantization_channels)
return encode_output

def thread_main(self, sess):
buffer_ = np.array([])
stop = False
# Go through the dataset multiple times
while not stop:
iterator = load_generic_audio(self.audio_dir, self.sample_rate)
iterator = load_generic_audio(self.audio_dir,
self.sample_rate,
self.pattern)
for audio, filename in iterator:
if self.coord.should_stop():
stop = True