-
Notifications
You must be signed in to change notification settings - Fork 146
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #343 from mil-tokyo/sentence-generation
Sentence generation
- Loading branch information
Showing
6 changed files
with
303 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#!/bin/sh
# Download the pretrained LSTM text-generation model and its metadata,
# then convert the Keras model to WebDNN graph descriptors.
# Abort on the first failing command so a partial/missing download is
# never fed into the converter.
set -e

mkdir -p output
wget https://github.com/mil-tokyo/webdnn-data/raw/master/models/lstm_text_generation/lstm_text_generation.h5 -O output/lstm_text_generation.h5
wget https://github.com/mil-tokyo/webdnn-data/raw/master/models/lstm_text_generation/model_setting.json -O output/model_setting.json
# Input shape: (batch=1, timesteps=40, alphabet size=57)
python ../../bin/convert_keras.py output/lstm_text_generation.h5 --input_shape '(1,40,57)' --out output --backend webgpu,webassembly
37 changes: 37 additions & 0 deletions
37
example/text_generation/descriptor_run_text_generation.html
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
<!DOCTYPE html>
<html>

<head>
<title>Text generation WebDNN example</title>
<meta charset="utf-8">
<!-- inflate.min.js is required by webdnn.js to decompress model weights -->
<script src="../../lib/inflate.min.js"></script>
<script src="../../dist/webdnn.js"></script>
<script src="descriptor_run_text_generation.js"></script>
</head>

<body>
<h1>Text generation WebDNN example</h1>
Generates text like Nietzsche after the seed text.<br>
You have to convert model with script before running this webpage.
<form>
<!-- Backend radio group: empty value lets WebDNN pick automatically -->
Backend:
<label><input type="radio" name="backend_name" value="" checked>auto</label>
<label><input type="radio" name="backend_name" value="webgpu">webgpu</label>
<label><input type="radio" name="backend_name" value="webassembly">webassembly</label><br>
Framework for model:
<label><input type="radio" name="framework_name" value="keras" checked>Keras</label><br>
<span>Seed text:</span>
<!-- Editable seed; read via textContent in descriptor_run_text_generation.js -->
<span id="seed_text" contenteditable
style="display: inline-block; border: 1px solid #ccc; width: auto; margin: 0 0.5em; padding: 0 0.5em; font-family: monospace; font-size: 28px; min-width: 10em; white-space: nowrap"></span>
<!-- Buttons start disabled; enabled once model_setting.json has loaded -->
<button id="run_button" type="button" onclick="run_entry(); return false;" disabled>Run</button>
<button id="change_seed" type="button" onclick="run_change_seed(); return false;" disabled>Change seed</button>
<br>
Generated text:<br>
<!-- Seed shown plain, generated continuation highlighted in pink -->
<div id="result" style="font-family: monospace; font-size: 28px;">
<span id="result_seed"></span><span id="result_generated" style="background-color: pink;"></span>
</div>
<div id="messages"></div>
</form>
</body>

</html>
115 changes: 115 additions & 0 deletions
115
example/text_generation/descriptor_run_text_generation.js
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
'use strict';

// Model metadata (char/index maps, maxlen, n_chars, example sentences);
// populated asynchronously from output/model_setting.json on DOMContentLoaded.
var metadata = null;
|
||
// Click handler for the Run button: kicks off generation and reports
// success or failure to the on-page message log.
function run_entry() {
    run()
        .then(function () {
            log('Run finished');
        })
        .catch(function (error) {
            log('Error: ' + error);
        });
}
|
||
// Click handler for the "Change seed" button: pick a random example
// sentence from the model metadata and place it in the seed element.
// Fixes two bugs in the original:
//  1. It queried `input[name=seed_text]`, but the seed is a contenteditable
//     <span id="seed_text"> (see the HTML and run()), so querySelector
//     returned null and `.value` threw a TypeError.
//  2. `Math.random() * (n_sent + 1)` could floor to n_sent, indexing one
//     past the end of example_sentences and yielding undefined.
function run_change_seed() {
    let n_sent = metadata.example_sentences.length;
    let chosen = metadata.example_sentences[Math.floor(Math.random() * n_sent)];
    document.querySelector('#seed_text').textContent = chosen;
}
|
||
// Append one message line (preceded by a <br>) to the on-page #messages area.
function log(msg) {
    const container = document.getElementById('messages');
    const lineBreak = document.createElement('br');
    const textNode = document.createTextNode(msg);
    container.appendChild(lineBreak);
    container.appendChild(textNode);
}
|
||
// Cache of loaded WebDNN runners keyed by backend name, so re-running with
// the same backend does not reload the model.
let runners = {};
|
||
// Return a WebDNN runner for the currently selected backend, loading the
// model on first use and caching it in `runners` afterwards.
async function prepare_run() {
    const backend_name = document.querySelector('input[name=backend_name]:checked').value;
    const backend_key = backend_name;

    if (backend_key in runners) {
        log('Model is already loaded');
        return runners[backend_key];
    }

    log('Initializing and loading model');
    const runner = await WebDNN.load(`./output`, { backendOrder: backend_name });
    log(`Loaded backend: ${runner.backendName}`);
    runners[backend_key] = runner;
    return runners[backend_key];
}
|
||
// One-hot encode the last `metadata.maxlen` characters of `sentence` into a
// Float32Array laid out in NTC order (batch=1, time, channel). Characters
// missing from char_indices (including positions before the start of a short
// sentence) fall back to index 0.
function sentence_to_array(sentence) {
    const maxlen = metadata.maxlen;
    const n_chars = metadata.n_chars;
    const encoded = new Float32Array(maxlen * n_chars);

    for (let t = 0; t < maxlen; t++) {
        const ch = sentence[sentence.length - maxlen + t];
        let idx = metadata.char_indices[ch];
        if (idx === void 0) {
            idx = 0;
        }
        encoded[t * n_chars + idx] = 1.0;
    }

    return encoded;
}
|
||
// Sample the next character from the softmax `scores`, sharpened or
// flattened by `temperature` (weight_i = exp(log(score_i) / temperature)).
// Returns the character string looked up via metadata.indices_char.
function sample_next_char(scores, temperature) {
    const n = metadata.n_chars;
    const weights = new Float32Array(n);
    let total = 0.0;

    for (let i = 0; i < n; i++) {
        const w = Math.exp(Math.log(scores[i]) / temperature);
        total += w;
        weights[i] = w;
    }

    // Weighted draw; default to the last index to absorb rounding error.
    let picked = n - 1;
    let remaining = Math.random() * total;
    for (let i = 0; i < n; i++) {
        remaining -= weights[i];
        if (remaining < 0.0) {
            picked = i;
            break;
        }
    }

    return metadata.indices_char[String(picked)];
}
|
||
// Generate 100 characters following the user-supplied seed text, invoking
// the model once per character, then render seed and continuation.
async function run() {
    const runner = await prepare_run();

    const seed = document.querySelector('#seed_text').textContent;
    let sentence = seed;

    for (let step = 0; step < 100; step++) {
        // Feed the current sentence (one-hot, NTC) into the model input.
        runner.getInputViews()[0].set(sentence_to_array(sentence));

        // Predict the next character's probability distribution.
        await runner.run();
        const out_vec = runner.getOutputViews()[0].toActual();

        // Sample the next character at temperature 1.0 and append it.
        sentence += sample_next_char(out_vec, 1.0);
        console.log('output vector: ', out_vec);
    }

    document.getElementById('result_seed').textContent = seed;
    document.getElementById('result_generated').textContent = sentence.slice(seed.length);
}
|
||
// On page load, fetch the model metadata produced by the conversion script;
// only enable the UI buttons once it is available.
document.addEventListener('DOMContentLoaded', async (event) => {
    try {
        const response = await fetch('output/model_setting.json');
        if (!response.ok) {
            throw new Error('Metadata HTTP response is not OK');
        }
        metadata = await response.json();
        document.querySelector('#seed_text').textContent = metadata['example_sentences'][0];
        document.getElementById('run_button').disabled = false;
        document.getElementById('change_seed').disabled = false;
    } catch (error) {
        log('Failed to load metadata: ' + error);
    }
});
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
''' | ||
This is based on Keras's example. Feature of saving model and setting is added. | ||
https://raw.githubusercontent.com/fchollet/keras/master/examples/lstm_text_generation.py | ||
Trained model can be obtained from | ||
https://github.com/mil-tokyo/webdnn-data/raw/master/models/lstm_text_generation/lstm_text_generation.h5 | ||
Example script to generate text from Nietzsche's writings. | ||
At least 20 epochs are required before the generated text | ||
starts sounding coherent. | ||
It is recommended to run this script on GPU, as recurrent | ||
networks are quite computationally intensive. | ||
If you try this script on new data, make sure your corpus | ||
has at least ~100k characters. ~1M is better. | ||
''' | ||
|
||
from __future__ import print_function | ||
from keras.models import Sequential | ||
from keras.layers import Dense, Activation | ||
from keras.layers import LSTM | ||
from keras.optimizers import RMSprop | ||
from keras.utils.data_utils import get_file | ||
import numpy as np | ||
import random | ||
import sys | ||
import os | ||
import json | ||
import argparse | ||
|
||
# Command-line interface: --out selects the output directory for artifacts.
parser = argparse.ArgumentParser()
parser.add_argument("--out", default="output")
args = parser.parse_args()

path = get_file('nietzsche.txt', origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt')
# Use a context manager so the corpus file handle is closed deterministically.
with open(path) as corpus_file:
    text = corpus_file.read().lower()
print('corpus length:', len(text))

chars = sorted(list(set(text)))
print('total chars:', len(chars))
char_indices = dict((c, i) for i, c in enumerate(chars))
indices_char = dict((i, c) for i, c in enumerate(chars))

# cut the text in semi-redundant sequences of maxlen characters
maxlen = 40
step = 3
sentences = []
next_chars = []
for i in range(0, len(text) - maxlen, step):
    sentences.append(text[i: i + maxlen])
    next_chars.append(text[i + maxlen])
print('nb sequences:', len(sentences))

print('Vectorization...')
# Use the builtin `bool`: the `np.bool` alias was deprecated in NumPy 1.20
# and removed in NumPy 1.24, where it raises AttributeError.
X = np.zeros((len(sentences), maxlen, len(chars)), dtype=bool)
y = np.zeros((len(sentences), len(chars)), dtype=bool)
for i, sentence in enumerate(sentences):
    for t, char in enumerate(sentence):
        X[i, t, char_indices[char]] = 1
    y[i, char_indices[next_chars[i]]] = 1

# saves char-index mapping (consumed by the browser demo to encode/decode text)
os.makedirs(args.out, exist_ok=True)
with open(os.path.join(args.out, "model_setting.json"), "w") as f:
    json.dump({"char_indices": char_indices,
               "indices_char": indices_char,
               "maxlen": maxlen,
               "n_chars": len(chars),
               # Cap at the corpus size: random.sample raises ValueError
               # when asked for more items than the population holds.
               "example_sentences": random.sample(sentences, min(100, len(sentences)))}, f)
|
||
# build the model: a single LSTM followed by a softmax over the alphabet
print('Build model...')
model = Sequential([
    LSTM(128, input_shape=(maxlen, len(chars))),
    Dense(len(chars)),
    Activation('softmax'),
])

# RMSprop with a relatively high learning rate, as in the upstream example.
model.compile(loss='categorical_crossentropy', optimizer=RMSprop(lr=0.01))
|
||
|
||
def sample(preds, temperature=1.0):
    """Draw one index from ``preds`` treated as a categorical distribution.

    ``temperature`` rescales log-probabilities before renormalizing:
    values < 1 sharpen the distribution, values > 1 flatten it.
    Returns the sampled index as an integer.
    """
    scores = np.asarray(preds).astype('float64')
    scaled = np.exp(np.log(scores) / temperature)
    normalized = scaled / np.sum(scaled)
    draw = np.random.multinomial(1, normalized, 1)
    return np.argmax(draw)
|
||
|
||
# train the model, output generated text after each iteration | ||
for iteration in range(1, 60): | ||
print() | ||
print('-' * 50) | ||
print('Iteration', iteration) | ||
model.fit(X, y, | ||
batch_size=128, | ||
epochs=1) | ||
|
||
model.save(os.path.join(args.out, "lstm_text_generation.h5")) | ||
|
||
start_index = random.randint(0, len(text) - maxlen - 1) | ||
|
||
for diversity in [0.2, 0.5, 1.0, 1.2]: | ||
print() | ||
print('----- diversity:', diversity) | ||
|
||
generated = '' | ||
sentence = text[start_index: start_index + maxlen] | ||
generated += sentence | ||
print('----- Generating with seed: "' + sentence + '"') | ||
sys.stdout.write(generated) | ||
|
||
for i in range(400): | ||
x = np.zeros((1, maxlen, len(chars))) | ||
for t, char in enumerate(sentence): | ||
x[0, t, char_indices[char]] = 1. | ||
|
||
preds = model.predict(x, verbose=0)[0] | ||
next_index = sample(preds, diversity) | ||
next_char = indices_char[next_index] | ||
|
||
generated += next_char | ||
sentence = sentence[1:] + next_char | ||
|
||
sys.stdout.write(next_char) | ||
sys.stdout.flush() | ||
print() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters