-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathwavegan_generate.py
79 lines (60 loc) · 2.32 KB
/
wavegan_generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# -*- coding: utf-8 -*-
"""wavegan_generate.ipynb
"""
# Selects which pretrained model is downloaded; the download step only
# handles 'digits' and raises NotImplementedError for any other value.
dataset = 'digits'
from tensorflow.python.client import device_lib
def get_available_gpus():
    """Return the names of the local devices whose device type is 'GPU'."""
    gpu_names = []
    for device in device_lib.list_local_devices():
        if device.device_type == 'GPU':
            gpu_names.append(device.name)
    return gpu_names
# Warn when no GPU is visible. The message is printed four times,
# presumably so it stands out in the notebook output.
if not get_available_gpus():
    for _repeat in range(4):
        print('WARNING: Not running on a GPU! See above for faster generation')
# Download model
# The checkpoint files are fetched with the standard library rather than the
# IPython `!wget` shell escape, so this step is valid Python both inside and
# outside a notebook and raises immediately if a download fails.
import urllib.request

if dataset == 'digits':
    _model_base_url = 'https://s3.amazonaws.com/wavegan-v1/models/'
    # (remote file name, local file name) pairs; the local names are the ones
    # the restore step expects ('infer.meta' and the 'model.ckpt' prefix).
    _model_files = (
        ('sc09.ckpt.index', 'model.ckpt.index'),
        ('sc09.ckpt.data-00000-of-00001', 'model.ckpt.data-00000-of-00001'),
        ('sc09_infer.meta', 'infer.meta'),
    )
    for _remote_name, _local_name in _model_files:
        # Overwrites any existing local file, matching `wget -O`.
        urllib.request.urlretrieve(_model_base_url + _remote_name, _local_name)
else:
    raise NotImplementedError(
        'No pretrained model is available for dataset {!r}'.format(dataset)
    )
import tensorflow as tf
# Load the pretrained generator using the TF1-style graph/session API.
# Start from a fresh default graph and turn eager execution off before
# importing the graph definition.
tf.compat.v1.reset_default_graph()
tf.compat.v1.disable_eager_execution()
# Import the inference graph stored in 'infer.meta'; the returned saver is
# used below to restore the variables.
saver = tf.compat.v1.train.import_meta_graph('infer.meta')
graph = tf.compat.v1.get_default_graph()
sess = tf.compat.v1.InteractiveSession()
# Restore variable values from the checkpoint with prefix 'model.ckpt'
# (the 'model.ckpt.index' / 'model.ckpt.data-*' files).
saver.restore(sess, 'model.ckpt')
# Number of latent vectors to sample and run through the generator.
ngenerate = 64
# Number of generated examples to display (spectrogram image + audio player).
ndisplay = 4
import numpy as np
import PIL.Image
from IPython.display import display, Audio
import time as time
# Look up the latent input and the generator outputs in the imported graph.
z = graph.get_tensor_by_name('z:0')
G_z_spec = graph.get_tensor_by_name('G_z_spec:0')
# Keep only index 0 of the last axis of the generator output.
G_z = graph.get_tensor_by_name('G_z:0')[:, :, 0]

# Sample ngenerate latent vectors of size 100, rescaling np.random.rand's
# [0, 1) output to [-1, 1).
_z = 2. * np.random.rand(ngenerate, 100) - 1.

# Run the generator once for the whole batch and report the wall-clock time.
start = time.time()
_G_z, _G_z_spec = sess.run(fetches=[G_z, G_z_spec], feed_dict={z: _z})
elapsed = time.time() - start
print('Finished! (Took {} seconds)'.format(elapsed))
# Show the first ndisplay results: a separator line, the example index, the
# spectrogram rendered as an image, and an audio player at rate 16000.
separator = '-' * 80
for i in range(ndisplay):
    print(separator)
    print('Example {}'.format(i))
    spectrogram = PIL.Image.fromarray(_G_z_spec[i])
    display(spectrogram)
    waveform = _G_z[i]
    display(Audio(waveform, rate=16000))
# Linear interpolation between two of the sampled latent vectors.
interp_a = 0  # index into _z of the starting latent vector
interp_b = 1  # index into _z of the ending latent vector
interp_n = 3  # number of intermediate points between the two endpoints

_za = _z[interp_a]
_zb = _z[interp_b]

# interp_n intermediate points plus both endpoints gives interp_n + 2 vectors,
# with the blend weight stepping evenly from 0.0 to 1.0.
interp_steps = interp_n + 1
_z_interp = []
for i in range(interp_steps + 1):
    a = i / float(interp_steps)
    _z_interp.append(_za * (1 - a) + _zb * a)
# Run the generator on the interpolated latents and show the result as a
# single image and a single audio player.
flat_pad = graph.get_tensor_by_name('flat_pad:0')
# Keep only index 0 of the last axis of the 'G_z_flat' output.
G_z_flat = graph.get_tensor_by_name('G_z_flat:0')[:, 0]
# Pad the end of the last axis of each spectrogram with 128 extra entries
# (the three padding pairs imply G_z_spec is rank 3).
G_z_spec_padded = tf.pad(G_z_spec, [[0, 0], [0, 0], [0, 128]])
# Swap the last two axes, then collapse everything into rows of length 256.
# NOTE(review): assumes the axis not padded above has length 256 — confirm
# against the graph; otherwise the reshape would mix data across examples.
G_z_spec_padded = tf.transpose(G_z_spec_padded, [0, 2, 1])
G_z_spec_flat = tf.reshape(G_z_spec_padded, [-1, 256])
# Transpose to 256 rows and drop the final 128 columns (presumably the
# padding that follows the last example — confirm).
G_z_spec_flat = tf.transpose(G_z_spec_flat, [1, 0])[:, :-128]
# Feed the interpolated latents and a flat_pad value of 8192.
# NOTE(review): flat_pad's meaning is defined inside the imported graph;
# presumably it is padding between consecutive clips — TODO confirm.
# This rebinds _G_z_spec, replacing the batch result computed earlier.
_G_z_flat, _G_z_spec = sess.run([G_z_flat, G_z_spec_flat], {z: _z_interp, flat_pad: 8192})
display(PIL.Image.fromarray(_G_z_spec))
display(Audio(_G_z_flat, rate=16000))