#!/usr/bin/env python
import tensorflow as tf

from rnn_gru_model import RnnGRUModel


# Model wrapper that provides utilities for generating plausible words with the
# model and for evaluating word plausibility using the learned or loaded weights.
class RnnWordPlausibilityEvaluator:
    def __init__(self, logger,
                 model=None,
                 ids_from_chars=None,
                 chars_from_ids=None,
                 temperature=0.5):
        self.logger = logger
        self.ids_from_chars = ids_from_chars
        self.chars_from_ids = chars_from_ids
        self.softmax = tf.keras.layers.Softmax()
        if not model:
            # No model supplied: load the saved weights into a fresh model and
            # rebuild the character <-> id lookup layers.
            self.model = RnnGRUModel()
            self.model.load_weights("base_model_saved_weights")
            self.ids_from_chars = tf.keras.layers.experimental.preprocessing.StringLookup(
                vocabulary=['*', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
                            'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'],
                mask_token=None)
            self.chars_from_ids = tf.keras.layers.experimental.preprocessing.StringLookup(
                vocabulary=self.ids_from_chars.get_vocabulary(), invert=True, mask_token=None)
        else:
            self.model = model
        # Lower temperature makes sampling more conservative when generating words.
        self.temperature = temperature

    def generate_one_step(self, inputs, states=None):
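        # Single sampling step: consume the characters in `inputs` (and the
        # previous RNN state, if any), sample the next character from the
        # temperature-scaled logits, and return it with the new state so the
        # caller can feed both back in a loop.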
        # Convert strings to token IDs.
        input_chars = tf.strings.unicode_split(inputs, 'UTF-8')
        input_ids = self.ids_from_chars(input_chars).to_tensor()
        # Run the model.
        # predicted_logits.shape is [batch, char, next_char_logits]
        predicted_logits, states = self.model(inputs=input_ids, states=states,
                                              return_state=True)
        # Only use the last prediction.
        predicted_logits = predicted_logits[:, -1, :]
        predicted_logits = predicted_logits / self.temperature
        # Sample the output logits to generate token IDs.
        predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
        predicted_ids = tf.squeeze(predicted_ids, axis=-1)
        # Convert from token IDs to characters.
        predicted_chars = self.chars_from_ids(predicted_ids)
        # Return the characters and model state.
        return predicted_chars, states

    def p_of_letter(self, normalized_probs, letter):
        # Look up the probability assigned to `letter` in a softmax-normalized
        # distribution over the character vocabulary.
        probs = tf.squeeze(normalized_probs)
        index = self.ids_from_chars(tf.convert_to_tensor(letter)).numpy()
        return probs[index].numpy()

    def is_plausible(self, word, show_work=False):
        threshold = 0.3 + 0.2 * max((len(word) - 5), 0)
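        # The threshold grows with word length: words of up to 5 letters need a
        # score of at least 0.3, while e.g. an 8-letter word needs
        # 0.3 + 0.2 * 3 = 0.9.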
        score = self.evaluate_word(word, show_work)
        if show_work:
            print("score for {0}={1}, threshold={2}".format(word, score, threshold))
        return score >= threshold

    def evaluate_word(self, word, show_work=False):
        states = None
        score = 0
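        # The score is the sum of the model's probability for each character of
        # the word (plus the '*' end-of-word marker), computed one character at
        # a time; very unlikely characters accumulate a penalty that divides the
        # final score below.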
        # Feed the '_' start-of-word marker into the model.
        seed_char = tf.constant(['_'])
        input_seed_char = tf.strings.unicode_split(seed_char, 'UTF-8')
        this_input_id = self.ids_from_chars(input_seed_char).to_tensor()
        predicted_logits, states = self.model(inputs=this_input_id, states=states,
                                              return_state=True)
        word = word + '*'
        penalty = 0
        for char in word:
            # Get the probability of the current char.
            predicted_logits = predicted_logits[:, -1, :]
            normalized_probs = self.softmax(predicted_logits).numpy()
            char_prob = self.p_of_letter(normalized_probs, char)
            # Update score and penalty.
            score += char_prob
            if show_work:
                print("prob of", char, "=", char_prob)
            if char_prob < 0.001:
                penalty += 10
            elif char_prob < 0.01:
                penalty += 1
            # Feed the current char into the model.
            this_char = tf.constant([char])
            input_char = tf.strings.unicode_split(this_char, 'UTF-8')
            this_input_id = self.ids_from_chars(input_char).to_tensor()
            predicted_logits, states = self.model(inputs=this_input_id, states=states,
                                                  return_state=True)
        if show_work:
            print("penalty=", penalty)
        if penalty > 0:
            score /= (penalty + 1)
        return score
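

# Example usage: a minimal sketch assuming the saved weights in
# "base_model_saved_weights" are available (the default constructor path above)
# and using the standard-library logging module as a stand-in for whatever
# logger the caller normally supplies; the candidate strings are arbitrary.
if __name__ == "__main__":
    import logging

    logging.basicConfig(level=logging.INFO)
    evaluator = RnnWordPlausibilityEvaluator(logging.getLogger(__name__))

    # Generate one plausible word: seed with the '_' start marker and sample
    # characters one step at a time until the '*' end marker (or a length cap).
    next_char = tf.constant(['_'])
    states = None
    word = ''
    for _ in range(20):
        next_char, states = evaluator.generate_one_step(next_char, states=states)
        char = next_char[0].numpy().decode('utf-8')
        if char == '*':
            break
        word += char
    print("generated word:", word)

    # Score a few candidate strings against the length-dependent threshold.
    for candidate in [word, "blorft", "xqzv"]:
        print(candidate, "plausible?", evaluator.is_plausible(candidate, show_work=True))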